检索优化技巧
概述
检索质量直接影响RAG系统的最终效果。通过混合检索、重排序、查询扩展等技术,可以显著提升检索准确率和召回率,为生成环节提供更优质的上下文。
检索质量评估
关键指标
| 指标 | 定义 | 计算方式 |
|---|---|---|
| 召回率 | 相关文档被检索出的比例 | 检索出的相关文档数 / 总相关文档数 |
| 精确率 | 检索结果中相关文档的比例 | 检索出的相关文档数 / 检索出的总文档数 |
| MRR | 平均倒数排名 | 相关文档排名倒数的平均值 |
| NDCG | 归一化折损累计增益 | 考虑排序位置的相关性评分 |
评估代码示例
python
def evaluate_retrieval(queries, ground_truth, retriever, k=5):
    """Compute mean recall@k and precision@k for a retriever.

    queries: list of query strings.
    ground_truth: parallel list of relevant-document id collections, one per query.
    retriever: object exposing get_relevant_documents(query, k=...).
    k: number of documents retrieved per query.
    """
    recall_scores = []
    precision_scores = []
    for query, relevant_docs in zip(queries, ground_truth):
        retrieved = retriever.get_relevant_documents(query, k=k)
        # Assumes every retrieved document carries an "id" entry in its metadata.
        retrieved_ids = [doc.metadata["id"] for doc in retrieved]
        relevant_retrieved = len(set(retrieved_ids) & set(relevant_docs))
        # Guard against queries with no relevant documents (avoids division by zero).
        recall = relevant_retrieved / len(relevant_docs) if relevant_docs else 0
        # Precision@k divides by k, not len(retrieved): if fewer than k documents
        # come back this understates precision — standard for @k metrics.
        precision = relevant_retrieved / k
        recall_scores.append(recall)
        precision_scores.append(precision)
    return {
        "recall": sum(recall_scores) / len(recall_scores),
        "precision": sum(precision_scores) / len(precision_scores)
    }
混合检索
原理说明
混合检索结合密集检索(向量检索)和稀疏检索(关键词检索),兼顾语义理解和精确匹配。
查询 → 密集检索 ──┐
├─→ 结果融合 → 最终结果
查询 → 稀疏检索 ──┘
实现方式
python
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import Chroma
from langchain.retrievers import BM25Retriever

# Dense (vector) retriever over the embedded documents.
# NOTE(review): `documents` and `embeddings` must be defined earlier — confirm.
vectorstore = Chroma.from_documents(documents, embeddings)
dense_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
# Sparse keyword (BM25) retriever; its k is set as an attribute.
bm25_retriever = BM25Retriever.from_documents(documents)
bm25_retriever.k = 5
# Fuse dense and sparse results with equal weight.
ensemble_retriever = EnsembleRetriever(
    retrievers=[dense_retriever, bm25_retriever],
    weights=[0.5, 0.5]
)
results = ensemble_retriever.get_relevant_documents("查询内容")
权重调优
python
def optimize_weights(queries, ground_truth, retrievers):
    """Grid-search the dense/sparse fusion weights that maximise mean recall.

    Returns the best [w1, w2] pair found; the two weights always sum to 1.0.
    NOTE(review): assumes numpy is imported as `np` and that
    evaluate_retrieval / EnsembleRetriever are in scope — confirm at module level.
    """
    best_score = 0
    best_weights = None
    # Sweep w1 over 0.1 .. 0.9 in steps of 0.1; w2 is the complement.
    for w1 in np.arange(0.1, 1.0, 0.1):
        w2 = 1.0 - w1
        ensemble = EnsembleRetriever(
            retrievers=retrievers,
            weights=[w1, w2]
        )
        score = evaluate_retrieval(queries, ground_truth, ensemble)
        # Selection criterion is recall only; precision is ignored here.
        if score["recall"] > best_score:
            best_score = score["recall"]
            best_weights = [w1, w2]
    return best_weights
重排序
为什么需要重排序
向量检索基于语义相似度,但可能遗漏精确匹配的关键信息。重排序使用更精细的模型对初步检索结果重新排序,提升相关性。
Cross-Encoder重排序
python
from sentence_transformers import CrossEncoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
def rerank(query, documents, top_k=5):
    """Re-score candidates with the module-level cross-encoder and keep the best top_k."""
    candidate_pairs = [[query, doc.page_content] for doc in documents]
    relevance = cross_encoder.predict(candidate_pairs)
    # Highest relevance first; sorted() is stable, matching the original ordering.
    ranked = sorted(zip(documents, relevance), key=lambda pair: pair[1], reverse=True)
    return [document for document, _ in ranked[:top_k]]
initial_results = retriever.get_relevant_documents("查询内容", k=20)
reranked_results = rerank("查询内容", initial_results, top_k=5)
使用Cohere重排序
python
import cohere

# NOTE(review): replace the placeholder with a real API key (e.g. from an env var).
co = cohere.Client("your-api-key")

def cohere_rerank(query, documents, top_n=5):
    """Rerank `documents` against `query` with Cohere's multilingual rerank model."""
    docs_text = [doc.page_content for doc in documents]
    results = co.rerank(
        query=query,
        documents=docs_text,
        top_n=top_n,
        model="rerank-multilingual-v2.0"
    )
    # Map the reranked indices back onto the original Document objects.
    reranked_docs = [documents[r.index] for r in results.results]
    return reranked_docs
LangChain重排序器
python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker

# Over-fetch 20 candidates so the reranker has a pool to choose from.
base_retriever = vectorstore.as_retriever(search_kwargs={"k": 20})
compressor = CrossEncoderReranker(
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
    top_n=5
)
# Wrap base retrieval with reranking-based compression.
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)
results = compression_retriever.get_relevant_documents("查询内容")
查询扩展
多查询扩展
python
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
llm = OpenAI(temperature=0)
expand_prompt = PromptTemplate.from_template("""
原始查询:{query}
请生成3个语义相关但表述不同的查询,用于检索相关信息。
以JSON数组格式返回,例如:["查询1", "查询2", "查询3"]
""")
def expand_query(query):
    """Ask the LLM for paraphrases of `query`; returns [query, *paraphrases].

    The prompt explicitly requests a JSON array, so the response is parsed
    with json.loads. The original used eval(), which would execute arbitrary
    Python coming back from the model — a code-injection hole.
    Raises json.JSONDecodeError if the model returns malformed JSON.
    """
    import json  # local import keeps this snippet self-contained

    response = llm(expand_prompt.format(query=query))
    queries = json.loads(response)
    return [query] + queries
def multi_query_retrieve(query, retriever, k=3):
    """Retrieve for the original query plus its LLM paraphrases, then deduplicate.

    Returns up to 2*k unique documents; relies on the module-level expand_query.
    """
    expanded_queries = expand_query(query)
    all_docs = []
    for q in expanded_queries:
        docs = retriever.get_relevant_documents(q, k=k)
        all_docs.extend(docs)
    # Deduplicate by metadata "id"; a later duplicate replaces an earlier one.
    # NOTE(review): assumes every document carries metadata["id"] — confirm.
    unique_docs = list({doc.metadata["id"]: doc for doc in all_docs}.values())
    return unique_docs[:k*2]
HyDE(假设文档嵌入)
python
hyde_prompt = PromptTemplate.from_template("""
请根据以下问题,生成一段可能包含答案的文档片段:
问题:{query}
文档:
""")
def hyde_retrieve(query, vectorstore, k=5):
    """HyDE: search with an LLM-generated hypothetical answer instead of the raw query.

    Relies on the module-level `llm` and `hyde_prompt`. The hypothetical
    document's embedding typically lands closer to real answer passages
    than the short question itself would.
    """
    hypothetical_doc = llm(hyde_prompt.format(query=query))
    results = vectorstore.similarity_search(
        hypothetical_doc,
        k=k
    )
    return results
查询改写
python
rewrite_prompt = PromptTemplate.from_template("""
原始查询可能表述不清晰,请将其改写为更清晰、更具体的查询。
原始查询:{query}
改写后的查询:
""")
def rewrite_query(query):
    """Ask the LLM to restate a vague query more precisely; returns the stripped rewrite."""
    prompt_text = rewrite_prompt.format(query=query)
    rewritten = llm(prompt_text)
    return rewritten.strip()
original_query = "那个东西怎么用"
rewritten_query = rewrite_query(original_query)
print(f"改写后: {rewritten_query}")
上下文压缩
LLM压缩
python
from langchain.retrievers.document_compressors import LLMChainExtractor

# temperature=0 for deterministic extraction.
llm = OpenAI(temperature=0)
compressor = LLMChainExtractor.from_llm(llm)
# Extract only the query-relevant spans from each retrieved document.
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=retriever
)
compressed_docs = compression_retriever.get_relevant_documents("查询内容")
嵌入过滤
python
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
# Drop retrieved documents whose embedding similarity to the query is below 0.76.
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.76
)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter,
    base_retriever=retriever
)
元数据过滤
基础过滤
python
# Metadata-filtered similarity search; the operator syntax ($gte etc.)
# follows the underlying vector store's query language — confirm for Chroma.
results = vectorstore.similarity_search(
    query="查询内容",
    k=5,
    filter={
        "category": "技术文档",
        "year": {"$gte": 2023}
    }
)
自定义过滤函数
python
def filter_by_metadata(documents, filters):
    """Keep documents whose metadata satisfies every condition in `filters`.

    A filter value is either a plain value (equality test) or a dict of
    operators: $gte, $lte, $in. Missing numeric fields default to 0 for
    the range operators; unknown operators are silently ignored (pass).
    """
    filtered = []
    for doc in documents:
        match = True
        for key, value in filters.items():
            if isinstance(value, dict):
                # Range / membership operators.
                if "$gte" in value and doc.metadata.get(key, 0) < value["$gte"]:
                    match = False
                if "$lte" in value and doc.metadata.get(key, 0) > value["$lte"]:
                    match = False
                if "$in" in value and doc.metadata.get(key) not in value["$in"]:
                    match = False
            else:
                # Plain value: exact equality on the metadata field.
                if doc.metadata.get(key) != value:
                    match = False
        if match:
            filtered.append(doc)
    return filtered
自适应检索
动态调整K值
python
def adaptive_retrieve(query, retriever, min_k=3, max_k=10):
    """Fetch max_k candidates, then keep only those scoring near the best hit.

    Keeps documents scoring at least 70% of the top score, but never fewer
    than min_k. Assumes results arrive best-first and expose a numeric
    metadata "score" (defaulting to 1.0 when absent) — TODO confirm.
    """
    initial_docs = retriever.get_relevant_documents(query, k=max_k)
    if len(initial_docs) < max_k:
        # Fewer candidates than requested: nothing to trim.
        return initial_docs
    scores = [doc.metadata.get("score", 1.0) for doc in initial_docs]
    # NOTE(review): a relative threshold of top*0.7 misbehaves for
    # non-positive score scales — verify the retriever's score range.
    score_threshold = scores[0] * 0.7
    filtered_docs = [
        doc for doc, score in zip(initial_docs, scores)
        if score >= score_threshold
    ]
    if len(filtered_docs) < min_k:
        filtered_docs = initial_docs[:min_k]
    return filtered_docs
查询复杂度判断
python
def analyze_query_complexity(query):
    """Classify `query` as "complex" or "simple" from cheap lexical features.

    A query is complex when it is long (more than 20 whitespace-separated
    tokens) or contains an interrogative keyword; otherwise simple. Note
    that whitespace tokenisation undercounts length for unsegmented
    Chinese text.
    """
    words = query.split()
    complexity = {
        "length": len(words),
        # Recognise both the ASCII "?" and the full-width "？" used in
        # Chinese text (the original missed the full-width form).
        "has_question": "?" in query or "？" in query,
        "has_keywords": any(kw in query.lower() for kw in ["如何", "什么是", "为什么"])
    }
    # has_question is computed for inspection but, as in the original,
    # does not influence the classification below.
    if complexity["length"] > 20 or complexity["has_keywords"]:
        return "complex"
    else:
        return "simple"
def smart_retrieve(query, retriever):
    """Route retrieval by query complexity: multi-query expansion for complex
    queries, a cheap direct lookup otherwise."""
    complexity = analyze_query_complexity(query)
    if complexity == "complex":
        return multi_query_retrieve(query, retriever, k=5)
    else:
        return retriever.get_relevant_documents(query, k=3)
检索结果缓存
语义缓存
python
from langchain.cache import InMemoryCache
from langchain.embeddings import OpenAIEmbeddings
import numpy as np
class SemanticCache:
    """Embedding-similarity cache: a query whose cosine similarity to a
    previously cached query exceeds `threshold` reuses that cached result.

    Lookup is a linear scan over all cached entries, so this suits small
    to medium cache sizes; the cache grows without bound.
    """

    def __init__(self, threshold=0.95):
        self.cache = {}  # query text -> (embedding, cached result)
        self.embeddings = OpenAIEmbeddings()
        self.threshold = threshold

    def get(self, query):
        """Return the cached result of the first sufficiently-similar query, or None."""
        query_embedding = self.embeddings.embed_query(query)
        # Hoisted out of the loop: the query's norm is loop-invariant
        # (the original recomputed it for every cached entry).
        query_norm = np.linalg.norm(query_embedding)
        for cached_query, (cached_embedding, result) in self.cache.items():
            similarity = np.dot(query_embedding, cached_embedding) / (
                query_norm * np.linalg.norm(cached_embedding)
            )
            if similarity > self.threshold:
                return result
        return None

    def set(self, query, result):
        """Store `result` under `query`, keyed by its embedding for fuzzy lookups."""
        query_embedding = self.embeddings.embed_query(query)
        self.cache[query] = (query_embedding, result)
# Module-level cache shared by all cached_retrieve calls.
semantic_cache = SemanticCache()

def cached_retrieve(query, retriever):
    """Retrieve documents, reusing results cached for semantically-similar queries."""
    cached_result = semantic_cache.get(query)
    # NOTE(review): truthiness means a cached EMPTY result list is treated
    # as a miss and re-retrieved — `is not None` would be the correct test.
    if cached_result:
        return cached_result
    result = retriever.get_relevant_documents(query)
    semantic_cache.set(query, result)
    return result
完整优化流程
python
class OptimizedRetriever:
    """Retrieval pipeline combining semantic caching, optional multi-query
    expansion, and optional cross-encoder reranking."""

    def __init__(self, vectorstore, use_rerank=True, use_multi_query=True):
        self.vectorstore = vectorstore
        self.use_rerank = use_rerank
        self.use_multi_query = use_multi_query
        if use_rerank:
            # Loaded once up front: CrossEncoder construction is expensive.
            self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        self.cache = SemanticCache()

    def retrieve(self, query, k=5):
        """Return the top-k documents for `query`, consulting the semantic cache first."""
        cached = self.cache.get(query)
        # `is not None`: a legitimately-empty cached result must still count
        # as a hit (the original truthiness test re-ran retrieval for it).
        if cached is not None:
            return cached
        if self.use_multi_query:
            # Retrieve for the original query plus LLM paraphrases, then
            # dedupe by document id (assumes metadata carries "id").
            queries = expand_query(query)
            all_docs = []
            for q in queries:
                all_docs.extend(self.vectorstore.similarity_search(q, k=k * 2))
            unique_docs = list({doc.metadata["id"]: doc for doc in all_docs}.values())
        else:
            # Over-fetch so a subsequent rerank has candidates to choose from.
            unique_docs = self.vectorstore.similarity_search(query, k=k * 3)
        if self.use_rerank:
            unique_docs = self.rerank(query, unique_docs, top_k=k)
        else:
            # Honour the k contract even without a reranking pass
            # (the original returned the whole over-fetched pool).
            unique_docs = unique_docs[:k]
        self.cache.set(query, unique_docs)
        return unique_docs

    def rerank(self, query, documents, top_k=5):
        """Order `documents` by cross-encoder relevance to `query`; keep the best top_k."""
        pairs = [[query, doc.page_content] for doc in documents]
        scores = self.cross_encoder.predict(pairs)
        scored_docs = list(zip(documents, scores))
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in scored_docs[:top_k]]
optimized_retriever = OptimizedRetriever(vectorstore)
results = optimized_retriever.retrieve("查询内容", k=5)
小结
检索优化是提升RAG系统效果的关键环节。混合检索结合语义和关键词匹配,重排序提升结果相关性,查询扩展提高召回率。实际应用中应根据场景特点组合使用这些技术,并通过评估指标持续优化。
下一章将介绍RAG最佳实践,涵盖文档预处理、索引优化和召回率提升等完整方案。