做網頁檢索系統的時候,當時還想加給它加上一個拼寫糾錯的功能(當然,只適合英文...),後來找到了個效果不錯的方法給大家分享一下。
拼寫糾錯、校正是一個提高搜索引擎用戶體驗的很關鍵的一項能力。如我們在Google搜索 saferi,然後會看到它自動幫助我們糾正了輸錯的單詞。
仔細想一想爲什麼它會知道我們輸錯的單詞就是safari,我們會發現主要原因有兩個;
1. 錯寫的單詞與正確單詞的拼寫相似,容易錯寫;這裏safari是否容易錯寫成saferi需要統計數據的支持;爲了簡化問題,我們認爲字形越相近的錯寫率越高,用編輯距離來表示。字形相近要求單詞之間編輯距離小於等於2,這裏saferi與safari編輯距離爲1,後面我們再具體瞭解編輯距離的定義。
2. 正確單詞有很多,除去語義因素外最有可能的單詞,也就是這個單詞的使用頻率了。所以我們確認的標準還有一項就是,單詞使用頻率。
下面介紹一個機器學習拼寫檢查方法,基於貝葉斯定理的拼寫檢查法,主要思想就是上面2條,列舉所有可能的正確拼寫,根據編輯距離以及詞頻從中選取可能性最大的用於校正。
原理:
用戶輸入的錯誤的單詞記做w,用戶想要輸入的拼寫正確的單詞記做c,則
P(c | w) : 用戶輸錯成w時,想要的單詞是c的概率。
P(w | c) : 用戶將c錯寫成w的概率,與編輯距離有關。
P(c) : 正確詞是c的概率,可以認爲是c的使用頻率,需要數據訓練。
根據貝葉斯公式
P(c | w) = P(w | c) * P(c) / P(w)
因爲同一次糾正中w是不變的,所以公式中我們不必理會P(w),它是一個常量。比較 P(c | w) 就是比較 P(w | c) * P(c) 的大小。
一、P(c)
P(c)替換成“使用頻率”,我們從足夠大的文本庫(詞典)點擊打開鏈接中統計出各個單詞的出現頻率,也可以將頻率歸一化縮小方便比較。
二、P(w | c)
P(w | c)替換成常數lambda * editDist
editDist編輯距離只計算editDist = 1與editDist = 2的,
editDist1,編輯距離爲1的有下面幾種情況:
(1)splits:將word依次按照每一位分割成前後兩半。比如,'abc'會被分割成 [('', 'abc'), ('a', 'bc'), ('ab', 'c'), ('abc', '')] 。
(2)beletes:依次刪除word的每一位後、所形成的所有新詞。比如,'abc'對應的deletes就是 ['bc', 'ac', 'ab'] 。
(3)transposes:依次交換word的鄰近兩位,所形成的所有新詞。比如,'abc'對應的transposes就是 ['bac', 'acb'] 。
(4)replaces:將word的每一位依次替換成其他25個字母,所形成的所有新詞。比如,'abc'對應的replaces就是 ['abc', 'bbc', 'cbc', ... , 'abx', ' aby', 'abz' ] ,一共包含78個詞(26 × 3)。
(5)inserts:在word的鄰近兩位之間依次插入一個字母,所形成的所有新詞。比如,'abc' 對應的inserts就是['aabc', 'babc', 'cabc', ..., 'abcx', 'abcy', 'abcz'],一共包含104個詞(26 × 4)。
editDist2則是在editDist1得到的單詞集合的基礎上再對它們作以上五種變換,得到所有編輯距離爲2的單詞(無論是否存在,在詞典中不存在的記P(c) = 1)。
三、糾正規則
1. 如果拼寫的單詞在詞典中出現的,直接返回。
2. 如果詞典中不存在的,返回編輯距離爲1和2的所有單詞中,P(c | w) * P(c)最大的單詞。
Java實現代碼:
參考鏈接:https://github.com/toufuChew/webSearch/tree/master/src/QueryCorrection
詞頻統計(訓練數據)
package QueryCorrection;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import IKAnalyze.IKAnalyze;
/**
* 生成訓練集
* @author cq
*
*/
public class Statistics {
private static final String dict = "/Users/cq/Desktop/dict.txt";
private static final String big = "/Users/cq/Downloads/big.txt";
public Statistics(){
File dictfile = new File(dict);
if (dictfile.exists() && dictfile.length() != 0)
return;
HashMap<String, Integer> map = new HashMap<String, Integer>();
try {
BufferedInputStream bin = new BufferedInputStream(new FileInputStream(new File(big)));
byte[] b = new byte[65535];
int ch;
try {
while((ch = bin.read(b)) != -1){
String txt = new String(b, "utf-8");
String[] arr = IKAnalyze.CNAnalyzerBStr(txt); //會過濾停用詞,如and,the...
for (int i = 0; i < arr.length; i++) {
if (map.containsKey(arr[i]))
map.put(arr[i], map.get(arr[i]) + 1);
else
map.put(arr[i], 1);
}
}
Iterator<Map.Entry<String, Integer>> it = map.entrySet().iterator();
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(dictfile)));
while(it.hasNext()){
Map.Entry<String, Integer> entry = it.next();
bw.write(entry.getKey() + " " + entry.getValue() + "\n");
}
bw.flush();
bw.close();
System.out.println("已完成統計!");
} catch (IOException e) {
e.printStackTrace();
}
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
public static HashMap<String, Integer> words(){
try {
BufferedReader br = new BufferedReader(new InputStreamReader(new BufferedInputStream(new FileInputStream(dict))));
String row;
HashMap<String, Integer> map = new HashMap<String, Integer>();
try {
while((row = br.readLine()) != null){
String[] s = row.split(" ");
map.put(s[0], Integer.valueOf(s[1]));
}
br.close();
return map;
} catch (IOException e) {
e.printStackTrace();
}
} catch (FileNotFoundException e) {
e.printStackTrace();
}
return null;
}
public static void main(String[] args){
new Statistics();
}
}
編輯距離處理及單詞校正
package QueryCorrection;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.regex.Pattern;
public class Correction{
private HashMap<String, Integer> map = null; //字典
private static final String alphabet = "abcdefghijklmnopqrstuvwxyz";
private static final double pow = 0.8;
public Correction(){
map = Statistics.words();
}
/**
* 編輯距離爲1的所有可能的詞
* @param word
* @return
*/
private ArrayList<String> editDist1(String word){
ArrayList<String> words = new ArrayList<String>();
BinaryWord[] binary = new BinaryWord[word.length()];
//split
for (int i = 0; i < word.length(); i++)
binary[i] = new BinaryWord(word.substring(0, i), word.substring(i));
//deletes
for (int i = 0; i < binary.length; i++)
if (binary[i].rw.length() > 1)
words.add(binary[i].lw + binary[i].rw.substring(1));
else
words.add(binary[i].lw);
//transposes
for (int i = 0; i < binary.length; i++)
if (binary[i].rw.length() > 1)
words.add(binary[i].lw + binary[i].rw.charAt(1) + binary[i].rw.charAt(0) + binary[i].rw.substring(2));
//replaces & inserts
for (int i = 0; i < binary.length; i++)
for (int j = 0; j < alphabet.length(); j++){
words.add(binary[i].lw + alphabet.charAt(j) + binary[i].rw.substring(1));
words.add(binary[i].lw + alphabet.charAt(j) + binary[i].rw); //inserts
}
//last inserts
for (int i = 0; i < alphabet.length(); i++)
words.add(word + alphabet.charAt(i));
return words;
}
/**
* 編輯距離爲2
* @param word
* @return
*/
private ArrayList<String> editDist2(String word){
ArrayList<String> dist1 = editDist1(word);
ArrayList<String> dist2 = new ArrayList<String>(); //只加入編輯距離爲2且出現在字典中的詞
for (int i = 0; i < dist1.size(); i++){
ArrayList<String> temp = editDist1(dist1.get(i));
for (int j = 0; j < temp.size(); j++)
if(map.containsKey(temp.get(j)))
dist2.add(temp.get(j));
}
return dist2;
}
/**
* 獲得最大概率的詞
* @param arr
* @param islegal 是否都是合法單詞
* @return
*/
public String maxPx(ArrayList<String> arr, boolean islegal){
int frq = 0;
String result = null;
if (islegal){
for (int i = 0; i < arr.size(); i++)
if (frq < map.get(arr.get(i))){
frq = map.get(arr.get(i));
result = arr.get(i);
}
}
else
for (int i = 0; i < arr.size(); i++){
if (map.containsKey(arr.get(i)))
if (frq < map.get(arr.get(i))){
frq = map.get(arr.get(i));
result = arr.get(i);
}
}
return result;
}
/**
* 是否是合法單詞(需要糾正)
* @param word
* @return
*/
private boolean legalSpell(String word){
if (map.containsKey(word))
return true;
return false;
}
/**
* 英文詞糾正
* 糾正數:1
* new Correction(String word).correct()
* @param null
* @return
*/
public String correct(String wd){
if (legalSpell(wd))
return wd;
String r1 = maxPx(editDist1(wd), false);
int score1 = 0;
if (r1 != null)
score1 = map.get(r1);
String r2 = maxPx(editDist2(wd), true);
int score2 = 0;
if (r2 != null)
score2 = map.get(r2);
if (score2 > score1 * pow)
return r2;
if (score2 == score1 * pow)
return wd;
return r1;
}
/**
* 對多個詞進行糾錯
* @param queryterm
* @return false if no illegal word
*/
public boolean correct(String[] queryterm){
boolean hascorrect = false;
//String regex = "[\\u4e00-\\u9fa5]"; //非中文
String regex = "[\\u4e00-\\u9fa5]+";
for (int i = 0; i < queryterm.length; i++){
if (!queryterm[i].matches(regex)){
String temp = correct(queryterm[i]);
if (temp.compareTo(queryterm[i]) != 0){
hascorrect = true;
queryterm[i] = temp;
}
}
}
return hascorrect;
}
public static void main(String[] args){
String[] query = {"calendar", "conclsion", "ture","canlendae", "conclsion", "ture","中文"};
System.out.println(new Correction().correct(query));
long st = System.currentTimeMillis();
for (int i = 0; i < query.length; i++)
System.out.println(query[i]);
System.out.println((System.currentTimeMillis() - st)/ 1000.0 + "s");
}
}
class BinaryWord{
String lw = null;
String rw = null;
public BinaryWord(String lw, String rw){
this.lw = lw;
this.rw = rw;
}
}
效果圖: