Sparse稀疏檢索介紹與實踐

Sparse稀疏檢索介紹

在處理大規模文本數據時,我們經常會遇到一些挑戰,比如如何有效地表示和檢索文檔,當前主要有兩個主要方法,傳統的文本BM25檢索,以及將文檔映射到向量空間的向量檢索。

BM25效果是有上限的,但是文本檢索在一些場景仍具備較好的魯棒性和可解釋性,因此不可或缺,那麼在NN模型一統天下的今天,是否能用NN模型來增強文本檢索呢,答案是有的,也就是我們今天要說的sparse 稀疏檢索。

傳統的BM25文本檢索其實就是典型的sparse稀疏檢索,在BM25檢索算法中,向量維度爲整個詞表,但是其中大部分爲0,只有出現的關鍵詞或子詞(tokens)有值,其餘的值都設爲零。這種表示方法不僅節省了存儲空間,而且提高了檢索效率。

向量的形式, 大概類似:

{
   '19828': 0.2085,
   '3508': 0.2374,
   '7919': 0.2544,
   '43': 0.0897,
   '6': 0.0967,
   '79299': 0.3079
}

key是term的編號,value是NN模型計算出來的權重。

稀疏向量與傳統方法的比較

當前流行的sparse檢索,大概是通過transformer模型,爲doc中的term計算weight,這樣與傳統的BM25等基於頻率的方法相比,sparse向量可以利用神經網絡的力量,提高了檢索的準確性和效率。BM25雖然能夠計算文檔的相關性,但它無法理解詞語的含義或上下文的重要性。而稀疏向量則能夠通過神經網絡捕捉到這些細微的差別。

稀疏向量的優勢

  1. 計算效率:稀疏向量在處理包含零元素的操作時,通常比密集向量更高效。
  2. 信息密度:稀疏向量專注於關鍵特徵,而不是捕捉所有細微的關係,這使得它們在文本搜索等應用中更爲高效。
  3. 領域適應性:稀疏向量在處理專業術語或罕見關鍵詞時表現出色,例如在醫療領域,許多專業術語不會出現在通用詞彙表中,稀疏向量能夠更好地捕捉這些術語的細微差別

稀疏向量舉例

SPLADE 是一款開源的transformer模型,提供sparse向量生成,下面是效果對比,可以看到sparse介於BM25和dense之間,比BM25效果好。

Model MRR@10 (MS MARCO Dev) Type
BM25 0.184 Sparse
TCT-ColBERT 0.359 Dense
doc2query-T5 link 0.277 Sparse
SPLADE 0.322 Sparse
SPLADE-max 0.340 Sparse
SPLADE-doc 0.322 Sparse
DistilSPLADE-max 0.368 Sparse

Sparse稀疏檢索實踐

模型介紹

國內的開源模型中,BAAI的BGE-M3提供sparse向量向量生成能力,我們用這個來進行實踐。

BGE是通過RetroMAE的預訓練方式訓練的類似bert的預訓練模型。

常規的Bert預訓練採用了將輸入文本隨機Mask再輸出完整文本這種自監督式的任務,RetroMAE採用一種巧妙的方式提高了Embedding的表徵能力,具體操作是:將低掩碼率的的文本A輸入到Encoder種得到Embedding向量,將該Embedding向量與高掩碼率的文本A輸入到淺層的Decoder向量中,輸出完整文本。這種預訓練方式迫使Encoder生成強大的Embedding向量,在表徵模型中提升效果顯著。

image.png

向量生成

  • 先安裝

    !pip install -U FlagEmbedding

  • 然後引入模型

from FlagEmbedding import BGEM3FlagModel
model = BGEM3FlagModel('BAAI/bge-m3',  use_fp16=True)

編寫一個函數用於計算embedding:

def embed_with_progress(model, docs, batch_size):
    batch_count = int(len(docs) / batch_size) + 1
    print("start embedding docs", batch_count)
    query_embeddings = []
    for i in tqdm(range(batch_count), desc="Embedding...", unit="batch"):
        start = i * batch_size
        end = min(len(docs), (i + 1) * batch_size)
        if end <= start:
            break
        output = model.encode(docs[start:end], return_dense=False, return_sparse=True, return_colbert_vecs=False)
        query_embeddings.extend(output['lexical_weights'])

    return query_embeddings

然後分別計算query和doc的:

query_embeddings = embed_with_progress(model, test_sets.queries, batch_size)
doc_embeddings = embed_with_progress(model, test_sets.docs, batch_size)

然後是計算query和doc的分數,model.compute_lexical_matching_score(交集的權重相乘,然後累加),注意下面的代碼是query和每個doc都計算了,計算量會比較大,在工程實踐中需要用類似向量索引的方案(當前qdrant、milvus等都提供sparse檢索支持)

# 檢索topk
recall_results = []
import numpy as np
for i in tqdm(range(len(test_sets.query_ids)), desc="recall...", unit="query"):
    query_embeding = query_embeddings[i]
    query_id = test_sets.query_ids[i]
    if query_id not in test_sets.relevant_docs:
        continue
    socres = [model.compute_lexical_matching_score(query_embeding, doc_embedding) for doc_embedding in doc_embeddings]
    topk_doc_ids = [test_sets.doc_ids[i] for i in np.argsort(socres)[-20:][::-1]]
    recall_results.append(json.dumps({"query": test_sets.queries[i], "topk_doc_ids": topk_doc_ids, "marked_doc_ids": list(test_sets.relevant_docs[query_id].keys())}))

# recall_results 寫入到文件

with open("recall_results.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(recall_results))

最後,基於測試集,我們可以計算召回率:

import json

# 讀取 JSON line 文件
topk_doc_ids_list = []
marked_doc_ids_list = []

with open("recall_results.txt", "r") as file:
    for line in file:
        data = json.loads(line)
        topk_doc_ids_list.append(data["topk_doc_ids"])
        marked_doc_ids_list.append(data["marked_doc_ids"])


# 計算 recall@k
def recall_at_k(k):
    recalls = []
    for topk_doc_ids, marked_doc_ids in zip(topk_doc_ids_list, marked_doc_ids_list):
        # 提取前 k 個召回結果
        topk = set(topk_doc_ids[:k])
        # 計算交集
        intersection = topk.intersection(set(marked_doc_ids))
        # 計算 recall
        recall = len(intersection) / min(len(marked_doc_ids), k)
        recalls.append(recall)
    # 計算平均 recall
    average_recall = sum(recalls) / len(recalls)
    return average_recall

# 計算 recall@5, 10, 20
recall_at_5 = recall_at_k(5)
recall_at_10 = recall_at_k(10)
recall_at_20 = recall_at_k(20)

print("Recall@5:", recall_at_5)
print("Recall@10:", recall_at_10)
print("Recall@20:", recall_at_20)

在測試集中,測試結果:

Recall@5: 0.7350086355785777 
Recall@10: 0.8035261945883735 
Recall@20: 0.8926130345462158

在這個測試集上,比BM25測試出來的結果要更好,但是僅憑這個尚不能否定BM25,需要綜合看各自的覆蓋度,綜合考慮成本與效果。

參考

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章