單詞拼寫校正原理及實現(貝葉斯推斷)

 

做網頁檢索系統的時候,當時還想加給它加上一個拼寫糾錯的功能(當然,只適合英文...),後來找到了個效果不錯的方法給大家分享一下。

拼寫糾錯、校正是一個提高搜索引擎用戶體驗的很關鍵的一項能力。如我們在Google搜索 saferi,然後會看到它自動幫助我們糾正了輸錯的單詞。

仔細想一想爲什麼它會知道我們輸錯的單詞就是safari,我們會發現主要原因有兩個;

1. 錯寫的單詞與正確單詞的拼寫相似,容易錯寫;這裏safari是否容易錯寫成saferi需要統計數據的支持;爲了簡化問題,我們認爲字形越相近的錯寫率越高,用編輯距離來表示。字形相近要求單詞之間編輯距離小於等於2,這裏saferi與safari編輯距離爲1,後面我們再具體瞭解編輯距離的定義。

2. 正確單詞有很多,除去語義因素外最有可能的單詞,也就是這個單詞的使用頻率了。所以我們確認的標準還有一項就是,單詞使用頻率。

下面介紹一個機器學習拼寫檢查方法,基於貝葉斯定理的拼寫檢查法,主要思想就是上面2條,列舉所有可能的正確拼寫,根據編輯距離以及詞頻從中選取可能性最大的用於校正。

原理:

用戶輸入的錯誤的單詞記做w,用戶想要輸入的拼寫正確的單詞記做c,則

    P(c | w) : 用戶輸錯成w時,想要的單詞是c的概率。

    P(w | c) :   用戶將c錯寫成w的概率,與編輯距離有關。

    P(c) :    正確詞是c的概率,可以認爲是c的使用頻率,需要數據訓練。

根據貝葉斯公式 

    P(c | w) = P(w | c) * P(c) / P(w)

因爲同一次糾正中w是不變的,所以公式中我們不必理會P(w),它是一個常量。比較 P(c | w) 就是比較 P(w | c) * P(c) 的大小。

 

一、P(c)

P(c)替換成“使用頻率”,我們從足夠大的文本庫(詞典)點擊打開鏈接中統計出各個單詞的出現頻率,也可以將頻率歸一化縮小方便比較。    

二、P(w | c)

P(w | c)替換成常數lambda * editDist

editDist編輯距離只計算editDist = 1與editDist = 2的,

editDist1,編輯距離爲1的有下面幾種情況:

 

(1)splits:將word依次按照每一位分割成前後兩半。比如,'abc'會被分割成 [('', 'abc'), ('a', 'bc'), ('ab', 'c'), ('abc', '')] 。

  (2)beletes:依次刪除word的每一位後、所形成的所有新詞。比如,'abc'對應的deletes就是 ['bc', 'ac', 'ab'] 。

  (3)transposes:依次交換word的鄰近兩位,所形成的所有新詞。比如,'abc'對應的transposes就是 ['bac', 'acb'] 。

  (4)replaces:將word的每一位依次替換成其他25個字母,所形成的所有新詞。比如,'abc'對應的replaces就是 ['abc', 'bbc', 'cbc', ... , 'abx', ' aby', 'abz' ] ,一共包含78個詞(26 × 3)。

  (5)inserts:在word的鄰近兩位之間依次插入一個字母,所形成的所有新詞。比如,'abc' 對應的inserts就是['aabc', 'babc', 'cabc', ..., 'abcx', 'abcy', 'abcz'],一共包含104個詞(26 × 4)。

 

 

 

editDist2則是在editDist1得到的單詞集合的基礎上再對它們作以上五種變換,得到所有編輯距離爲2的單詞(無論是否存在,在詞典中不存在的記P(c) = 1)。

 

三、糾正規則

1. 如果拼寫的單詞在詞典中出現的,直接返回。

2. 如果詞典中不存在的,返回編輯距離爲1和2的所有單詞中,P(c | w) * P(c)最大的單詞。

 

Java實現代碼:

參考鏈接:https://github.com/toufuChew/webSearch/tree/master/src/QueryCorrection

詞頻統計(訓練數據)

package QueryCorrection;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import IKAnalyze.IKAnalyze;
/**
 * 生成訓練集
 * @author cq
 *
 */
public class Statistics {
	private static final String dict = "/Users/cq/Desktop/dict.txt";
	private static final String big = "/Users/cq/Downloads/big.txt";
	public Statistics(){
		File dictfile = new File(dict);
		if (dictfile.exists() && dictfile.length() != 0)
			return;
		HashMap<String, Integer> map = new HashMap<String, Integer>();
		try {
			BufferedInputStream bin = new BufferedInputStream(new FileInputStream(new File(big)));
			byte[] b = new byte[65535];
			int ch;
			try {
				while((ch = bin.read(b)) != -1){
					String txt = new String(b, "utf-8");
					String[] arr = IKAnalyze.CNAnalyzerBStr(txt); //會過濾停用詞,如and,the...
					for (int i = 0; i < arr.length; i++) {
						if (map.containsKey(arr[i]))
							map.put(arr[i], map.get(arr[i]) + 1);
						else
							map.put(arr[i], 1);
					}
				}
				Iterator<Map.Entry<String, Integer>> it = map.entrySet().iterator();
				BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(dictfile)));
				while(it.hasNext()){
					Map.Entry<String, Integer> entry = it.next();
					bw.write(entry.getKey() + " " + entry.getValue() + "\n");
				}
				bw.flush();
				bw.close();
				System.out.println("已完成統計!");
			} catch (IOException e) {
				e.printStackTrace();
			}
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}	
	}
	public static HashMap<String, Integer> words(){
		try {
			BufferedReader br = new BufferedReader(new InputStreamReader(new BufferedInputStream(new FileInputStream(dict))));
			String row;
			HashMap<String, Integer> map = new HashMap<String, Integer>();
			try {
				while((row = br.readLine()) != null){
					String[] s = row.split(" ");
					map.put(s[0], Integer.valueOf(s[1]));
				}
				br.close();
				return map;
			} catch (IOException e) {
				e.printStackTrace();
			}
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
		return null;
	}
	public static void main(String[] args){
		new Statistics();
		
	}
}

編輯距離處理及單詞校正

package QueryCorrection;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.regex.Pattern;

public class Correction{
	private HashMap<String, Integer> map = null; //字典
	private static final String alphabet = "abcdefghijklmnopqrstuvwxyz";
	private static final double pow = 0.8;
	public Correction(){
		map = Statistics.words();
	}
	/**
	 * 編輯距離爲1的所有可能的詞
	 * @param word
	 * @return
	 */
	private ArrayList<String> editDist1(String word){
		ArrayList<String> words = new ArrayList<String>();
		BinaryWord[] binary = new BinaryWord[word.length()];
		//split
		for (int i = 0; i < word.length(); i++)
			binary[i] = new BinaryWord(word.substring(0, i), word.substring(i));
		//deletes
		for (int i = 0; i < binary.length; i++)
			if (binary[i].rw.length() > 1)
				words.add(binary[i].lw + binary[i].rw.substring(1));
			else
				words.add(binary[i].lw);
		//transposes
		for (int i = 0; i < binary.length; i++)
			if (binary[i].rw.length() > 1)
				words.add(binary[i].lw + binary[i].rw.charAt(1) + binary[i].rw.charAt(0) + binary[i].rw.substring(2));
		//replaces & inserts
		for (int i = 0; i < binary.length; i++)
			for (int j = 0; j < alphabet.length(); j++){
				words.add(binary[i].lw + alphabet.charAt(j) + binary[i].rw.substring(1));
				words.add(binary[i].lw + alphabet.charAt(j) + binary[i].rw); //inserts
			}
		//last inserts
		for (int i = 0; i < alphabet.length(); i++)
			words.add(word + alphabet.charAt(i));
		return words;
	}
	/**
	 * 編輯距離爲2
	 * @param word
	 * @return
	 */
	private ArrayList<String> editDist2(String word){
		ArrayList<String> dist1 = editDist1(word);
		ArrayList<String> dist2 = new ArrayList<String>(); //只加入編輯距離爲2且出現在字典中的詞
		for (int i = 0; i < dist1.size(); i++){
			ArrayList<String> temp = editDist1(dist1.get(i));
			for (int j = 0; j < temp.size(); j++)
				if(map.containsKey(temp.get(j)))
					dist2.add(temp.get(j));
		}
		
		return dist2;
	}
	/**
	 * 獲得最大概率的詞
	 * @param arr
	 * @param islegal 是否都是合法單詞
	 * @return
	 */
	public String maxPx(ArrayList<String> arr, boolean islegal){
		int frq = 0;
		String result = null;
		if (islegal){
			for (int i = 0; i < arr.size(); i++)
				if (frq < map.get(arr.get(i))){
					frq = map.get(arr.get(i));
					result = arr.get(i);
				}
		}
		else
			for (int i = 0; i < arr.size(); i++){
				if (map.containsKey(arr.get(i)))
					if (frq < map.get(arr.get(i))){
						frq = map.get(arr.get(i));
						result = arr.get(i);
					}
			}
		return result;
	}
	/**
	 * 是否是合法單詞(需要糾正)
	 * @param word
	 * @return
	 */
	private boolean legalSpell(String word){
		if (map.containsKey(word))
			return true;
		return false;
	}
	/**
	 * 英文詞糾正
	 * 糾正數:1
	 * new Correction(String word).correct()
	 * @param null
	 * @return
	 */
	public String correct(String wd){
		if (legalSpell(wd))
			return wd;
		String r1 = maxPx(editDist1(wd), false);
		int score1 = 0;
		if (r1 != null)
			score1 = map.get(r1);
		String r2 = maxPx(editDist2(wd), true);
		int score2 = 0;
		if (r2 != null)
			score2 = map.get(r2);
		if (score2 > score1 * pow)
			return r2;
		if (score2 == score1 * pow)
			return wd;
		return r1;
	}
	/**
	 * 對多個詞進行糾錯
	 * @param queryterm
	 * @return false if no illegal word
	 */
	public boolean correct(String[] queryterm){
		boolean hascorrect = false;
		//String regex = "[\\u4e00-\\u9fa5]"; //非中文
		String regex = "[\\u4e00-\\u9fa5]+";
		for (int i = 0; i < queryterm.length; i++){
			if (!queryterm[i].matches(regex)){
				String temp = correct(queryterm[i]);
				if (temp.compareTo(queryterm[i]) != 0){
					hascorrect = true;
					queryterm[i] = temp;
				}
			}
		}
		return hascorrect;
	}
	public static void main(String[] args){
		String[] query = {"calendar", "conclsion", "ture","canlendae", "conclsion", "ture","中文"};
		System.out.println(new Correction().correct(query));
		long st = System.currentTimeMillis();
		for (int i = 0; i < query.length; i++)
			System.out.println(query[i]);
		System.out.println((System.currentTimeMillis() - st)/ 1000.0 + "s");
	}
}

class BinaryWord{
	String lw = null;
	String rw = null;
	public BinaryWord(String lw, String rw){
		this.lw = lw;
		this.rw = rw;
	}
}

效果圖:

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章