BK樹或者稱爲Burkhard-Keller樹,是一種基於樹的數據結構。用於快速查找近似字符串匹配,比方說拼寫糾錯,或模糊查找,當搜索”aeek”時能返回與其最相似的字符串”seek”和”peek”。
在構建BK樹之前,我們需要定義一種用於比較字符串相似度的度量方法。通常都是採用編輯距離(Levenshtein Distance),這是一種用於表示兩個字符串相互轉換需要進行的最少編輯步數。
在確定度量方法後,可以構建出基於該比較方法的度量空間,該空間具有以下3種特性:
假設存在字符串A、B、C,d(A,B)表示兩字符串的編輯距離
1.如果d(A,B)=0,那麼表示A,B字符串相等
2.d(A,B)與d(B,A)相等
3.d(A,C)>= d(A,B)+d(B,C)
最後一條又叫做三角不等式,表示A與C的編輯距離一定大於A變爲B後再變爲C的編輯距離和。
BK建樹
首先我們隨便找一個單詞作爲根(比如GAME)。以後插入一個單詞時首先計算單詞與根的Levenshtein距離:如果這個距離值是該節點處頭一次出現,建立一個新的兒子節點;否則沿着對應的邊遞歸下去。例如,我們插入單詞FAME,它與GAME的距離爲1,於是新建一個兒子,連一條標號爲1的邊;下一次插入GAIN,算得它與GAME的距離爲2,於是放在編號爲2的邊下。再下次我們插入GATE,它與GAME距離爲1,於是沿着那條編號爲1的邊下去,遞歸地插入到FAME所在子樹;GATE與FAME的距離爲2,於是把GATE放在FAME節點下,邊的編號爲2。
BK查詢
如果我們需要返回與錯誤單詞距離不超過n的單詞,這個錯誤單詞與樹根所對應的單詞距離爲d,那麼接下來我們只需要遞歸地考慮編號在d-n到d+n範圍內的邊所連接的子樹。由於n通常很小,因此每次與某個節點進行比較時都可以排除很多子樹。
實現:
整體實現需要有構建樹與查詢樹兩塊功能,查詢時需要返回編輯距離與節點的字符串
1.首先實現編輯距離計算方法
def calculate_edit_distance(word1, word2):
len1 = len(word1)
len2 = len(word2)
dp = np.zeros((len1 + 1, len2 + 1))
for i in range(len1 + 1):
dp[i][0] = i
for j in range(len2 + 1):
dp[0][j] = j
for i in range(1, len1 + 1):
for j in range(1, len2 + 1):
delta = 0 if word1[i - 1] == word2[j - 1] else 1
dp[i][j] = min(dp[i - 1][j - 1] + delta, min(dp[i - 1][j] + 1, dp[i][j - 1] + 1))
return int(dp[len1][len2])
2.實現結果返回類
class ResultNode:
def __init__(self, data, distance):
self.data = data
self.distance = distance
3.實現節點類
class TreeNode:
def __init__(self, data):
self.data = data
self.child_node_dict = {}
def put(self, chars):
distance = ed.calculate_edit_distance(chars, self.data)
if distance == 0:
return
keys = self.child_node_dict.keys()
if distance in keys:
self.child_node_dict[distance].put(chars)
else:
self.child_node_dict[distance] = TreeNode(chars)
def query(self, target_char, n):
results = []
keys = self.child_node_dict.keys()
distance = ed.calculate_edit_distance(target_char, self.data)
if distance <= n:
results.append(ResultNode(self.data, distance))
if distance != 0:
for query_distance in range(max(distance - n, 1), distance + n + 1):
if query_distance not in keys:
continue
value_node = self.child_node_dict[query_distance]
results += value_node.query(target_char, n)
return results
def get_all_data(self):
results = []
keys = self.child_node_dict.keys()
values = self.child_node_dict.values()
results += [node.data for node in values]
for key in keys:
value_node = self.child_node_dict[key]
results += value_node.get_all_data()
return results
4.實現樹類
class BKTree:
def __init__(self, root_chars):
self.root_node = TreeNode(root_chars)
def put(self, chars):
self.root_node.put(chars)
def query(self, target_char, n):
if self.root_node is None:
return ResultNode(target_char, 0)
else:
queries = self.root_node.query(target_char, n)
if len(queries) == 0:
return ResultNode(target_char, 0)
else:
queries.sort(key=lambda x: x.distance, reverse=False)
return queries[0]
def get_all_data(self):
if self.root_node is None:
return []
else:
return self.root_node.get_all_data()
5.實現樹的保存和恢復
import pickle
import os
import random
from model import BKTree
from utils import read_dict
bk_tree_path = 'bk_tree.pkl'
def dump_bk_tree(bk_tree):
with open(bk_tree_path, 'wb') as f:
pickle.dump(bk_tree, f)
def load_bk_tree():
if os.path.exists(bk_tree_path):
print('load build tree')
with open(bk_tree_path, 'rb') as f:
return pickle.load(f)
else:
char_list = read_dict('dict_en.txt')
randint = random.randint(0, len(char_list) - 1)
bk_tree = BKTree(char_list[randint])
print('start build tree')
for index, item in enumerate(char_list):
print('build tree:' + str(index) + '/' + str(len(char_list)))
bk_tree.put(item)
dump_bk_tree(bk_tree)
return bk_tree
6.調用測試
from load_tree import load_bk_tree
from datetime import datetime
bk_tree = load_bk_tree()
query_word = 'lavishnessa'
be = datetime.now()
query = bk_tree.query(query_word, 3)
delta_time = datetime.now() - be
print("spent:" + str(delta_time))
print(query.data)
print(query.distance)
直接使用python實現效率非常慢,可以考慮使用cython加速計算編輯距離的邏輯,可以達到近15倍的加速效率。
7.Cython加速 calculate edit distance
from libc.stdlib cimport malloc, free
def calculate_edit_distance(word1, word2):
len1 = len(word1)
len2 = len(word2)
cdef int** dp = <int**> malloc((len1 + 1) * sizeof(int*))
for i in range(len1 + 1):
dp[i] = <int*> malloc((len2 + 1) * sizeof(int))
for i in range(len1 + 1):
dp[i][0] = i
for j in range(len2 + 1):
dp[0][j] = j
for i in range(1, len1 + 1):
for j in range(1, len2 + 1):
delta = 0 if word1[i - 1] == word2[j - 1] else 1
dp[i][j] = min(dp[i - 1][j - 1] + delta, min(dp[i - 1][j] + 1, dp[i][j - 1] + 1))
cdef result = dp[len1][len2]
for i in range(len1 + 1):
free(dp[i])
free(dp)
return result
cython編寫和python實現幾乎一致,此處只對於動態對象進行了內存管理