Sparse稀疏檢索介紹
在處理大規模文本數據時,我們經常會遇到一些挑戰,比如如何有效地表示和檢索文檔,當前主要有兩個主要方法,傳統的文本BM25檢索,以及將文檔映射到向量空間的向量檢索。
BM25效果是有上限的,但是文本檢索在一些場景仍具備較好的魯棒性和可解釋性,因此不可或缺,那麼在NN模型一統天下的今天,是否能用NN模型來增強文本檢索呢,答案是有的,也就是我們今天要說的sparse 稀疏檢索。
傳統的BM25文本檢索其實就是典型的sparse稀疏檢索,在BM25檢索算法中,向量維度爲整個詞表,但是其中大部分爲0,只有出現的關鍵詞或子詞(tokens)有值,其餘的值都設爲零。這種表示方法不僅節省了存儲空間,而且提高了檢索效率。
向量的形式, 大概類似:
{
'19828': 0.2085,
'3508': 0.2374,
'7919': 0.2544,
'43': 0.0897,
'6': 0.0967,
'79299': 0.3079
}
key是term的編號,value是NN模型計算出來的權重。
稀疏向量與傳統方法的比較
當前流行的sparse檢索,大概是通過transformer模型,爲doc中的term計算weight,這樣與傳統的BM25等基於頻率的方法相比,sparse向量可以利用神經網絡的力量,提高了檢索的準確性和效率。BM25雖然能夠計算文檔的相關性,但它無法理解詞語的含義或上下文的重要性。而稀疏向量則能夠通過神經網絡捕捉到這些細微的差別。
稀疏向量的優勢
- 計算效率:稀疏向量在處理包含零元素的操作時,通常比密集向量更高效。
- 信息密度:稀疏向量專注於關鍵特徵,而不是捕捉所有細微的關係,這使得它們在文本搜索等應用中更爲高效。
- 領域適應性:稀疏向量在處理專業術語或罕見關鍵詞時表現出色,例如在醫療領域,許多專業術語不會出現在通用詞彙表中,稀疏向量能夠更好地捕捉這些術語的細微差別。
稀疏向量舉例
SPLADE 是一款開源的transformer模型,提供sparse向量生成,下面是效果對比,可以看到sparse介於BM25和dense之間,比BM25效果好。
Model | MRR@10 (MS MARCO Dev) | Type |
---|---|---|
BM25 | 0.184 | Sparse |
TCT-ColBERT | 0.359 | Dense |
doc2query-T5 link | 0.277 | Sparse |
SPLADE | 0.322 | Sparse |
SPLADE-max | 0.340 | Sparse |
SPLADE-doc | 0.322 | Sparse |
DistilSPLADE-max | 0.368 | Sparse |
Sparse稀疏檢索實踐
模型介紹
國內的開源模型中,BAAI的BGE-M3提供sparse向量向量生成能力,我們用這個來進行實踐。
BGE是通過RetroMAE的預訓練方式訓練的類似bert的預訓練模型。
常規的Bert預訓練採用了將輸入文本隨機Mask再輸出完整文本這種自監督式的任務,RetroMAE採用一種巧妙的方式提高了Embedding的表徵能力,具體操作是:將低掩碼率的的文本A輸入到Encoder種得到Embedding向量,將該Embedding向量與高掩碼率的文本A輸入到淺層的Decoder向量中,輸出完整文本。這種預訓練方式迫使Encoder生成強大的Embedding向量,在表徵模型中提升效果顯著。
向量生成
-
先安裝
!pip install -U FlagEmbedding
-
然後引入模型
from FlagEmbedding import BGEM3FlagModel
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
編寫一個函數用於計算embedding:
def embed_with_progress(model, docs, batch_size):
batch_count = int(len(docs) / batch_size) + 1
print("start embedding docs", batch_count)
query_embeddings = []
for i in tqdm(range(batch_count), desc="Embedding...", unit="batch"):
start = i * batch_size
end = min(len(docs), (i + 1) * batch_size)
if end <= start:
break
output = model.encode(docs[start:end], return_dense=False, return_sparse=True, return_colbert_vecs=False)
query_embeddings.extend(output['lexical_weights'])
return query_embeddings
然後分別計算query和doc的:
query_embeddings = embed_with_progress(model, test_sets.queries, batch_size)
doc_embeddings = embed_with_progress(model, test_sets.docs, batch_size)
然後是計算query和doc的分數,model.compute_lexical_matching_score
(交集的權重相乘,然後累加),注意下面的代碼是query和每個doc都計算了,計算量會比較大,在工程實踐中需要用類似向量索引的方案(當前qdrant、milvus等都提供sparse檢索支持)
# 檢索topk
recall_results = []
import numpy as np
for i in tqdm(range(len(test_sets.query_ids)), desc="recall...", unit="query"):
query_embeding = query_embeddings[i]
query_id = test_sets.query_ids[i]
if query_id not in test_sets.relevant_docs:
continue
socres = [model.compute_lexical_matching_score(query_embeding, doc_embedding) for doc_embedding in doc_embeddings]
topk_doc_ids = [test_sets.doc_ids[i] for i in np.argsort(socres)[-20:][::-1]]
recall_results.append(json.dumps({"query": test_sets.queries[i], "topk_doc_ids": topk_doc_ids, "marked_doc_ids": list(test_sets.relevant_docs[query_id].keys())}))
# recall_results 寫入到文件
with open("recall_results.txt", "w", encoding="utf-8") as f:
f.write("\n".join(recall_results))
最後,基於測試集,我們可以計算召回率:
import json
# 讀取 JSON line 文件
topk_doc_ids_list = []
marked_doc_ids_list = []
with open("recall_results.txt", "r") as file:
for line in file:
data = json.loads(line)
topk_doc_ids_list.append(data["topk_doc_ids"])
marked_doc_ids_list.append(data["marked_doc_ids"])
# 計算 recall@k
def recall_at_k(k):
recalls = []
for topk_doc_ids, marked_doc_ids in zip(topk_doc_ids_list, marked_doc_ids_list):
# 提取前 k 個召回結果
topk = set(topk_doc_ids[:k])
# 計算交集
intersection = topk.intersection(set(marked_doc_ids))
# 計算 recall
recall = len(intersection) / min(len(marked_doc_ids), k)
recalls.append(recall)
# 計算平均 recall
average_recall = sum(recalls) / len(recalls)
return average_recall
# 計算 recall@5, 10, 20
recall_at_5 = recall_at_k(5)
recall_at_10 = recall_at_k(10)
recall_at_20 = recall_at_k(20)
print("Recall@5:", recall_at_5)
print("Recall@10:", recall_at_10)
print("Recall@20:", recall_at_20)
在測試集中,測試結果:
Recall@5: 0.7350086355785777
Recall@10: 0.8035261945883735
Recall@20: 0.8926130345462158
在這個測試集上,比BM25測試出來的結果要更好,但是僅憑這個尚不能否定BM25,需要綜合看各自的覆蓋度,綜合考慮成本與效果。
參考
- Sparse Vectors in Qdrant: Pure Vector-based Hybrid Search https://qdrant.tech/articles/sparse-vectors/
- BGE(BAAI General Embedding)解讀 https://zhuanlan.zhihu.com/p/690856333