做网页检索系统的时候,当时还想加给它加上一个拼写纠错的功能(当然,只适合英文...),后来找到了个效果不错的方法给大家分享一下。
拼写纠错、校正是一个提高搜索引擎用户体验的很关键的一项能力。如我们在Google搜索 saferi,然后会看到它自动帮助我们纠正了输错的单词。
仔细想一想为什么它会知道我们输错的单词就是safari,我们会发现主要原因有两个;
1. 错写的单词与正确单词的拼写相似,容易错写;这里safari是否容易错写成saferi需要统计数据的支持;为了简化问题,我们认为字形越相近的错写率越高,用编辑距离来表示。字形相近要求单词之间编辑距离小于等于2,这里saferi与safari编辑距离为1,后面我们再具体了解编辑距离的定义。
2. 正确单词有很多,除去语义因素外最有可能的单词,也就是这个单词的使用频率了。所以我们确认的标准还有一项就是,单词使用频率。
下面介绍一个机器学习拼写检查方法,基于贝叶斯定理的拼写检查法,主要思想就是上面2条,列举所有可能的正确拼写,根据编辑距离以及词频从中选取可能性最大的用于校正。
原理:
用户输入的错误的单词记做w,用户想要输入的拼写正确的单词记做c,则
P(c | w) : 用户输错成w时,想要的单词是c的概率。
P(w | c) : 用户将c错写成w的概率,与编辑距离有关。
P(c) : 正确词是c的概率,可以认为是c的使用频率,需要数据训练。
根据贝叶斯公式
P(c | w) = P(w | c) * P(c) / P(w)
因为同一次纠正中w是不变的,所以公式中我们不必理会P(w),它是一个常量。比较 P(c | w) 就是比较 P(w | c) * P(c) 的大小。
一、P(c)
P(c)替换成“使用频率”,我们从足够大的文本库(词典)点击打开链接中统计出各个单词的出现频率,也可以将频率归一化缩小方便比较。
二、P(w | c)
P(w | c)替换成常数lambda * editDist
editDist编辑距离只计算editDist = 1与editDist = 2的,
editDist1,编辑距离为1的有下面几种情况:
(1)splits:将word依次按照每一位分割成前后两半。比如,'abc'会被分割成 [('', 'abc'), ('a', 'bc'), ('ab', 'c'), ('abc', '')] 。
(2)beletes:依次删除word的每一位后、所形成的所有新词。比如,'abc'对应的deletes就是 ['bc', 'ac', 'ab'] 。
(3)transposes:依次交换word的邻近两位,所形成的所有新词。比如,'abc'对应的transposes就是 ['bac', 'acb'] 。
(4)replaces:将word的每一位依次替换成其他25个字母,所形成的所有新词。比如,'abc'对应的replaces就是 ['abc', 'bbc', 'cbc', ... , 'abx', ' aby', 'abz' ] ,一共包含78个词(26 × 3)。
(5)inserts:在word的邻近两位之间依次插入一个字母,所形成的所有新词。比如,'abc' 对应的inserts就是['aabc', 'babc', 'cabc', ..., 'abcx', 'abcy', 'abcz'],一共包含104个词(26 × 4)。
editDist2则是在editDist1得到的单词集合的基础上再对它们作以上五种变换,得到所有编辑距离为2的单词(无论是否存在,在词典中不存在的记P(c) = 1)。
三、纠正规则
1. 如果拼写的单词在词典中出现的,直接返回。
2. 如果词典中不存在的,返回编辑距离为1和2的所有单词中,P(c | w) * P(c)最大的单词。
Java实现代码:
参考链接:https://github.com/toufuChew/webSearch/tree/master/src/QueryCorrection
词频统计(训练数据)
package QueryCorrection;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import IKAnalyze.IKAnalyze;
/**
* 生成训练集
* @author cq
*
*/
public class Statistics {
private static final String dict = "/Users/cq/Desktop/dict.txt";
private static final String big = "/Users/cq/Downloads/big.txt";
public Statistics(){
File dictfile = new File(dict);
if (dictfile.exists() && dictfile.length() != 0)
return;
HashMap<String, Integer> map = new HashMap<String, Integer>();
try {
BufferedInputStream bin = new BufferedInputStream(new FileInputStream(new File(big)));
byte[] b = new byte[65535];
int ch;
try {
while((ch = bin.read(b)) != -1){
String txt = new String(b, "utf-8");
String[] arr = IKAnalyze.CNAnalyzerBStr(txt); //会过滤停用词,如and,the...
for (int i = 0; i < arr.length; i++) {
if (map.containsKey(arr[i]))
map.put(arr[i], map.get(arr[i]) + 1);
else
map.put(arr[i], 1);
}
}
Iterator<Map.Entry<String, Integer>> it = map.entrySet().iterator();
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(dictfile)));
while(it.hasNext()){
Map.Entry<String, Integer> entry = it.next();
bw.write(entry.getKey() + " " + entry.getValue() + "\n");
}
bw.flush();
bw.close();
System.out.println("已完成统计!");
} catch (IOException e) {
e.printStackTrace();
}
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
public static HashMap<String, Integer> words(){
try {
BufferedReader br = new BufferedReader(new InputStreamReader(new BufferedInputStream(new FileInputStream(dict))));
String row;
HashMap<String, Integer> map = new HashMap<String, Integer>();
try {
while((row = br.readLine()) != null){
String[] s = row.split(" ");
map.put(s[0], Integer.valueOf(s[1]));
}
br.close();
return map;
} catch (IOException e) {
e.printStackTrace();
}
} catch (FileNotFoundException e) {
e.printStackTrace();
}
return null;
}
public static void main(String[] args){
new Statistics();
}
}
编辑距离处理及单词校正
package QueryCorrection;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.regex.Pattern;
public class Correction{
private HashMap<String, Integer> map = null; //字典
private static final String alphabet = "abcdefghijklmnopqrstuvwxyz";
private static final double pow = 0.8;
public Correction(){
map = Statistics.words();
}
/**
* 编辑距离为1的所有可能的词
* @param word
* @return
*/
private ArrayList<String> editDist1(String word){
ArrayList<String> words = new ArrayList<String>();
BinaryWord[] binary = new BinaryWord[word.length()];
//split
for (int i = 0; i < word.length(); i++)
binary[i] = new BinaryWord(word.substring(0, i), word.substring(i));
//deletes
for (int i = 0; i < binary.length; i++)
if (binary[i].rw.length() > 1)
words.add(binary[i].lw + binary[i].rw.substring(1));
else
words.add(binary[i].lw);
//transposes
for (int i = 0; i < binary.length; i++)
if (binary[i].rw.length() > 1)
words.add(binary[i].lw + binary[i].rw.charAt(1) + binary[i].rw.charAt(0) + binary[i].rw.substring(2));
//replaces & inserts
for (int i = 0; i < binary.length; i++)
for (int j = 0; j < alphabet.length(); j++){
words.add(binary[i].lw + alphabet.charAt(j) + binary[i].rw.substring(1));
words.add(binary[i].lw + alphabet.charAt(j) + binary[i].rw); //inserts
}
//last inserts
for (int i = 0; i < alphabet.length(); i++)
words.add(word + alphabet.charAt(i));
return words;
}
/**
* 编辑距离为2
* @param word
* @return
*/
private ArrayList<String> editDist2(String word){
ArrayList<String> dist1 = editDist1(word);
ArrayList<String> dist2 = new ArrayList<String>(); //只加入编辑距离为2且出现在字典中的词
for (int i = 0; i < dist1.size(); i++){
ArrayList<String> temp = editDist1(dist1.get(i));
for (int j = 0; j < temp.size(); j++)
if(map.containsKey(temp.get(j)))
dist2.add(temp.get(j));
}
return dist2;
}
/**
* 获得最大概率的词
* @param arr
* @param islegal 是否都是合法单词
* @return
*/
public String maxPx(ArrayList<String> arr, boolean islegal){
int frq = 0;
String result = null;
if (islegal){
for (int i = 0; i < arr.size(); i++)
if (frq < map.get(arr.get(i))){
frq = map.get(arr.get(i));
result = arr.get(i);
}
}
else
for (int i = 0; i < arr.size(); i++){
if (map.containsKey(arr.get(i)))
if (frq < map.get(arr.get(i))){
frq = map.get(arr.get(i));
result = arr.get(i);
}
}
return result;
}
/**
* 是否是合法单词(需要纠正)
* @param word
* @return
*/
private boolean legalSpell(String word){
if (map.containsKey(word))
return true;
return false;
}
/**
* 英文词纠正
* 纠正数:1
* new Correction(String word).correct()
* @param null
* @return
*/
public String correct(String wd){
if (legalSpell(wd))
return wd;
String r1 = maxPx(editDist1(wd), false);
int score1 = 0;
if (r1 != null)
score1 = map.get(r1);
String r2 = maxPx(editDist2(wd), true);
int score2 = 0;
if (r2 != null)
score2 = map.get(r2);
if (score2 > score1 * pow)
return r2;
if (score2 == score1 * pow)
return wd;
return r1;
}
/**
* 对多个词进行纠错
* @param queryterm
* @return false if no illegal word
*/
public boolean correct(String[] queryterm){
boolean hascorrect = false;
//String regex = "[\\u4e00-\\u9fa5]"; //非中文
String regex = "[\\u4e00-\\u9fa5]+";
for (int i = 0; i < queryterm.length; i++){
if (!queryterm[i].matches(regex)){
String temp = correct(queryterm[i]);
if (temp.compareTo(queryterm[i]) != 0){
hascorrect = true;
queryterm[i] = temp;
}
}
}
return hascorrect;
}
public static void main(String[] args){
String[] query = {"calendar", "conclsion", "ture","canlendae", "conclsion", "ture","中文"};
System.out.println(new Correction().correct(query));
long st = System.currentTimeMillis();
for (int i = 0; i < query.length; i++)
System.out.println(query[i]);
System.out.println((System.currentTimeMillis() - st)/ 1000.0 + "s");
}
}
class BinaryWord{
String lw = null;
String rw = null;
public BinaryWord(String lw, String rw){
this.lw = lw;
this.rw = rw;
}
}
效果图: