单词拼写校正原理及实现(贝叶斯推断)

 

做网页检索系统的时候,当时还想加给它加上一个拼写纠错的功能(当然,只适合英文...),后来找到了个效果不错的方法给大家分享一下。

拼写纠错、校正是一个提高搜索引擎用户体验的很关键的一项能力。如我们在Google搜索 saferi,然后会看到它自动帮助我们纠正了输错的单词。

仔细想一想为什么它会知道我们输错的单词就是safari,我们会发现主要原因有两个;

1. 错写的单词与正确单词的拼写相似,容易错写;这里safari是否容易错写成saferi需要统计数据的支持;为了简化问题,我们认为字形越相近的错写率越高,用编辑距离来表示。字形相近要求单词之间编辑距离小于等于2,这里saferi与safari编辑距离为1,后面我们再具体了解编辑距离的定义。

2. 正确单词有很多,除去语义因素外最有可能的单词,也就是这个单词的使用频率了。所以我们确认的标准还有一项就是,单词使用频率。

下面介绍一个机器学习拼写检查方法,基于贝叶斯定理的拼写检查法,主要思想就是上面2条,列举所有可能的正确拼写,根据编辑距离以及词频从中选取可能性最大的用于校正。

原理:

用户输入的错误的单词记做w,用户想要输入的拼写正确的单词记做c,则

    P(c | w) : 用户输错成w时,想要的单词是c的概率。

    P(w | c) :   用户将c错写成w的概率,与编辑距离有关。

    P(c) :    正确词是c的概率,可以认为是c的使用频率,需要数据训练。

根据贝叶斯公式 

    P(c | w) = P(w | c) * P(c) / P(w)

因为同一次纠正中w是不变的,所以公式中我们不必理会P(w),它是一个常量。比较 P(c | w) 就是比较 P(w | c) * P(c) 的大小。

 

一、P(c)

P(c)替换成“使用频率”,我们从足够大的文本库(词典)点击打开链接中统计出各个单词的出现频率,也可以将频率归一化缩小方便比较。    

二、P(w | c)

P(w | c)替换成常数lambda * editDist

editDist编辑距离只计算editDist = 1与editDist = 2的,

editDist1,编辑距离为1的有下面几种情况:

 

(1)splits:将word依次按照每一位分割成前后两半。比如,'abc'会被分割成 [('', 'abc'), ('a', 'bc'), ('ab', 'c'), ('abc', '')] 。

  (2)beletes:依次删除word的每一位后、所形成的所有新词。比如,'abc'对应的deletes就是 ['bc', 'ac', 'ab'] 。

  (3)transposes:依次交换word的邻近两位,所形成的所有新词。比如,'abc'对应的transposes就是 ['bac', 'acb'] 。

  (4)replaces:将word的每一位依次替换成其他25个字母,所形成的所有新词。比如,'abc'对应的replaces就是 ['abc', 'bbc', 'cbc', ... , 'abx', ' aby', 'abz' ] ,一共包含78个词(26 × 3)。

  (5)inserts:在word的邻近两位之间依次插入一个字母,所形成的所有新词。比如,'abc' 对应的inserts就是['aabc', 'babc', 'cabc', ..., 'abcx', 'abcy', 'abcz'],一共包含104个词(26 × 4)。

 

 

 

editDist2则是在editDist1得到的单词集合的基础上再对它们作以上五种变换,得到所有编辑距离为2的单词(无论是否存在,在词典中不存在的记P(c) = 1)。

 

三、纠正规则

1. 如果拼写的单词在词典中出现的,直接返回。

2. 如果词典中不存在的,返回编辑距离为1和2的所有单词中,P(c | w) * P(c)最大的单词。

 

Java实现代码:

参考链接:https://github.com/toufuChew/webSearch/tree/master/src/QueryCorrection

词频统计(训练数据)

package QueryCorrection;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import IKAnalyze.IKAnalyze;
/**
 * 生成训练集
 * @author cq
 *
 */
public class Statistics {
	private static final String dict = "/Users/cq/Desktop/dict.txt";
	private static final String big = "/Users/cq/Downloads/big.txt";
	public Statistics(){
		File dictfile = new File(dict);
		if (dictfile.exists() && dictfile.length() != 0)
			return;
		HashMap<String, Integer> map = new HashMap<String, Integer>();
		try {
			BufferedInputStream bin = new BufferedInputStream(new FileInputStream(new File(big)));
			byte[] b = new byte[65535];
			int ch;
			try {
				while((ch = bin.read(b)) != -1){
					String txt = new String(b, "utf-8");
					String[] arr = IKAnalyze.CNAnalyzerBStr(txt); //会过滤停用词,如and,the...
					for (int i = 0; i < arr.length; i++) {
						if (map.containsKey(arr[i]))
							map.put(arr[i], map.get(arr[i]) + 1);
						else
							map.put(arr[i], 1);
					}
				}
				Iterator<Map.Entry<String, Integer>> it = map.entrySet().iterator();
				BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(dictfile)));
				while(it.hasNext()){
					Map.Entry<String, Integer> entry = it.next();
					bw.write(entry.getKey() + " " + entry.getValue() + "\n");
				}
				bw.flush();
				bw.close();
				System.out.println("已完成统计!");
			} catch (IOException e) {
				e.printStackTrace();
			}
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}	
	}
	public static HashMap<String, Integer> words(){
		try {
			BufferedReader br = new BufferedReader(new InputStreamReader(new BufferedInputStream(new FileInputStream(dict))));
			String row;
			HashMap<String, Integer> map = new HashMap<String, Integer>();
			try {
				while((row = br.readLine()) != null){
					String[] s = row.split(" ");
					map.put(s[0], Integer.valueOf(s[1]));
				}
				br.close();
				return map;
			} catch (IOException e) {
				e.printStackTrace();
			}
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
		return null;
	}
	public static void main(String[] args){
		new Statistics();
		
	}
}

编辑距离处理及单词校正

package QueryCorrection;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.regex.Pattern;

public class Correction{
	private HashMap<String, Integer> map = null; //字典
	private static final String alphabet = "abcdefghijklmnopqrstuvwxyz";
	private static final double pow = 0.8;
	public Correction(){
		map = Statistics.words();
	}
	/**
	 * 编辑距离为1的所有可能的词
	 * @param word
	 * @return
	 */
	private ArrayList<String> editDist1(String word){
		ArrayList<String> words = new ArrayList<String>();
		BinaryWord[] binary = new BinaryWord[word.length()];
		//split
		for (int i = 0; i < word.length(); i++)
			binary[i] = new BinaryWord(word.substring(0, i), word.substring(i));
		//deletes
		for (int i = 0; i < binary.length; i++)
			if (binary[i].rw.length() > 1)
				words.add(binary[i].lw + binary[i].rw.substring(1));
			else
				words.add(binary[i].lw);
		//transposes
		for (int i = 0; i < binary.length; i++)
			if (binary[i].rw.length() > 1)
				words.add(binary[i].lw + binary[i].rw.charAt(1) + binary[i].rw.charAt(0) + binary[i].rw.substring(2));
		//replaces & inserts
		for (int i = 0; i < binary.length; i++)
			for (int j = 0; j < alphabet.length(); j++){
				words.add(binary[i].lw + alphabet.charAt(j) + binary[i].rw.substring(1));
				words.add(binary[i].lw + alphabet.charAt(j) + binary[i].rw); //inserts
			}
		//last inserts
		for (int i = 0; i < alphabet.length(); i++)
			words.add(word + alphabet.charAt(i));
		return words;
	}
	/**
	 * 编辑距离为2
	 * @param word
	 * @return
	 */
	private ArrayList<String> editDist2(String word){
		ArrayList<String> dist1 = editDist1(word);
		ArrayList<String> dist2 = new ArrayList<String>(); //只加入编辑距离为2且出现在字典中的词
		for (int i = 0; i < dist1.size(); i++){
			ArrayList<String> temp = editDist1(dist1.get(i));
			for (int j = 0; j < temp.size(); j++)
				if(map.containsKey(temp.get(j)))
					dist2.add(temp.get(j));
		}
		
		return dist2;
	}
	/**
	 * 获得最大概率的词
	 * @param arr
	 * @param islegal 是否都是合法单词
	 * @return
	 */
	public String maxPx(ArrayList<String> arr, boolean islegal){
		int frq = 0;
		String result = null;
		if (islegal){
			for (int i = 0; i < arr.size(); i++)
				if (frq < map.get(arr.get(i))){
					frq = map.get(arr.get(i));
					result = arr.get(i);
				}
		}
		else
			for (int i = 0; i < arr.size(); i++){
				if (map.containsKey(arr.get(i)))
					if (frq < map.get(arr.get(i))){
						frq = map.get(arr.get(i));
						result = arr.get(i);
					}
			}
		return result;
	}
	/**
	 * 是否是合法单词(需要纠正)
	 * @param word
	 * @return
	 */
	private boolean legalSpell(String word){
		if (map.containsKey(word))
			return true;
		return false;
	}
	/**
	 * 英文词纠正
	 * 纠正数:1
	 * new Correction(String word).correct()
	 * @param null
	 * @return
	 */
	public String correct(String wd){
		if (legalSpell(wd))
			return wd;
		String r1 = maxPx(editDist1(wd), false);
		int score1 = 0;
		if (r1 != null)
			score1 = map.get(r1);
		String r2 = maxPx(editDist2(wd), true);
		int score2 = 0;
		if (r2 != null)
			score2 = map.get(r2);
		if (score2 > score1 * pow)
			return r2;
		if (score2 == score1 * pow)
			return wd;
		return r1;
	}
	/**
	 * 对多个词进行纠错
	 * @param queryterm
	 * @return false if no illegal word
	 */
	public boolean correct(String[] queryterm){
		boolean hascorrect = false;
		//String regex = "[\\u4e00-\\u9fa5]"; //非中文
		String regex = "[\\u4e00-\\u9fa5]+";
		for (int i = 0; i < queryterm.length; i++){
			if (!queryterm[i].matches(regex)){
				String temp = correct(queryterm[i]);
				if (temp.compareTo(queryterm[i]) != 0){
					hascorrect = true;
					queryterm[i] = temp;
				}
			}
		}
		return hascorrect;
	}
	public static void main(String[] args){
		String[] query = {"calendar", "conclsion", "ture","canlendae", "conclsion", "ture","中文"};
		System.out.println(new Correction().correct(query));
		long st = System.currentTimeMillis();
		for (int i = 0; i < query.length; i++)
			System.out.println(query[i]);
		System.out.println((System.currentTimeMillis() - st)/ 1000.0 + "s");
	}
}

class BinaryWord{
	String lw = null;
	String rw = null;
	public BinaryWord(String lw, String rw){
		this.lw = lw;
		this.rw = rw;
	}
}

效果图:

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章