Skip to content

检索优化技巧

概述

检索质量直接影响RAG系统的最终效果。通过混合检索、重排序、查询扩展等技术,可以显著提升检索准确率和召回率,为生成环节提供更优质的上下文。

检索质量评估

关键指标

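| 指标 | 定义 | 计算方式 |
| --- | --- | --- |
| 召回率 | 相关文档被检索出的比例 | 检索出的相关文档数 / 总相关文档数 |
| 精确率 | 检索结果中相关文档的比例 | 检索出的相关文档数 / 检索出的总文档数 |
| MRR | 平均倒数排名 | 每个查询中第一个相关文档排名的倒数,在所有查询上取平均:$\mathrm{MRR}=\frac{1}{\lvert Q\rvert}\sum_{i=1}^{\lvert Q\rvert}\frac{1}{\mathrm{rank}_i}$ |
| NDCG | 归一化折损累计增益 | 按排序位置对相关性评分做对数折损后求和,再用理想排序下的值归一化:$\mathrm{DCG@}k=\sum_{i=1}^{k}\frac{rel_i}{\log_2(i+1)}$,$\mathrm{NDCG@}k=\mathrm{DCG@}k/\mathrm{IDCG@}k$ |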

评估代码示例

python
def evaluate_retrieval(queries, ground_truth, retriever, k=5):
    """Compute mean recall and precision of ``retriever`` over a labelled query set.

    Args:
        queries: Query strings.
        ground_truth: For each query (same order), a collection of the ids of
            the relevant documents.
        retriever: Object with ``get_relevant_documents(query, k=...)`` that
            returns documents identified by ``metadata["id"]``.
        k: Number of documents requested per query.

    Returns:
        ``{"recall": float, "precision": float}``, each averaged over queries.
        Both are 0.0 when ``queries`` is empty.

    Precision follows the metrics table: relevant retrieved / total retrieved.
    Dividing by ``k`` instead would penalise a retriever that legitimately
    returns fewer than ``k`` documents.
    """
    recall_scores = []
    precision_scores = []

    for query, relevant_docs in zip(queries, ground_truth):
        retrieved = retriever.get_relevant_documents(query, k=k)
        # Compare as sets so duplicate ids on either side are not double-counted.
        retrieved_ids = {doc.metadata["id"] for doc in retrieved}
        relevant_ids = set(relevant_docs)

        hits = len(retrieved_ids & relevant_ids)

        recall_scores.append(hits / len(relevant_ids) if relevant_ids else 0.0)
        precision_scores.append(hits / len(retrieved_ids) if retrieved_ids else 0.0)

    # No queries: report zeros instead of dividing by zero.
    if not recall_scores:
        return {"recall": 0.0, "precision": 0.0}

    return {
        "recall": sum(recall_scores) / len(recall_scores),
        "precision": sum(precision_scores) / len(precision_scores),
    }

混合检索

原理说明

混合检索结合密集检索(向量检索)和稀疏检索(关键词检索),兼顾语义理解和精确匹配。

查询 → 密集检索 ──┐
                 ├─→ 结果融合 → 最终结果
查询 → 稀疏检索 ──┘

实现方式

python
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import Chroma
from langchain.retrievers import BM25Retriever

# Dense (embedding-based) retriever over a Chroma index, top-5 per query.
# NOTE(review): `documents` and `embeddings` are not defined in this snippet;
# they are assumed to be prepared earlier (chunked Documents + embedding model).
vectorstore = Chroma.from_documents(documents, embeddings)
dense_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

# Sparse keyword retriever (BM25) over the same documents, also top-5.
bm25_retriever = BM25Retriever.from_documents(documents)
bm25_retriever.k = 5

# Merge both result lists; equal weights give each retriever the same influence.
# (LangChain documents the merge as weighted Reciprocal Rank Fusion — confirm
# for the installed version.)
ensemble_retriever = EnsembleRetriever(
    retrievers=[dense_retriever, bm25_retriever],
    weights=[0.5, 0.5]
)

results = ensemble_retriever.get_relevant_documents("查询内容")

权重调优

python
def optimize_weights(queries, ground_truth, retrievers, steps=10):
    """Grid-search the weights of a two-retriever EnsembleRetriever by recall.

    Candidate weights are w1 = 1/steps, 2/steps, ..., (steps-1)/steps with
    w2 = 1 - w1. The default ``steps=10`` gives the grid 0.1 ... 0.9.

    Args:
        queries: Evaluation queries.
        ground_truth: Relevant document ids for each query (same order).
        retrievers: Exactly two retrievers, in the order the weights apply.
        steps: Grid resolution; must be >= 2 to produce any candidate.

    Returns:
        The ``[w1, w2]`` pair with the highest mean recall. Ties keep the
        smallest w1. Returns ``None`` only when the grid is empty (steps < 2).
    """
    # Start below any achievable score, so a grid where every candidate scores
    # 0.0 still yields weights instead of None.
    best_score = float("-inf")
    best_weights = None

    # Build the grid from integers. np.arange with a float step accumulates
    # rounding error (e.g. 0.30000000000000004), and its length is not
    # guaranteed. This form also avoids `np`, which this snippet never imports.
    for i in range(1, steps):
        w1 = round(i / steps, 10)
        w2 = round(1.0 - w1, 10)
        ensemble = EnsembleRetriever(
            retrievers=retrievers,
            weights=[w1, w2]
        )

        score = evaluate_retrieval(queries, ground_truth, ensemble)

        if score["recall"] > best_score:
            best_score = score["recall"]
            best_weights = [w1, w2]

    return best_weights

重排序

为什么需要重排序

向量检索基于语义相似度,速度快,但查询和文档是分别编码后再比较向量的,相关性判断较粗,可能遗漏精确匹配的关键信息。重排序使用更精细的模型(如Cross-Encoder,将查询与文档拼接后联合打分)对初步检索结果重新排序,从而提升排在前面的结果的相关性。

Cross-Encoder重排序

python
from sentence_transformers import CrossEncoder

# Module-level Cross-Encoder used by `rerank` below: it scores (query, passage)
# pairs. Created once at module load so every rerank call reuses the model.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

def rerank(query, documents, top_k=5):
    """Reorder ``documents`` by Cross-Encoder relevance to ``query``.

    Args:
        query: The query text.
        documents: Candidate documents exposing ``page_content``.
        top_k: Number of highest-scoring documents to return.

    Returns:
        Up to ``top_k`` documents, best first. An empty candidate list returns
        ``[]`` without invoking the model; some sentence-transformers versions
        fail in ``CrossEncoder.predict`` when given zero pairs.
    """
    if not documents:
        return []

    pairs = [[query, doc.page_content] for doc in documents]
    scores = cross_encoder.predict(pairs)

    # sorted() is stable, so equal scores keep their retrieval order.
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:top_k]]

# Two-stage retrieval: over-fetch 20 candidates with the cheap retriever, then
# keep the 5 that the cross-encoder scores highest.
# NOTE(review): `retriever` is not defined in this snippet; assumed to be one of
# the retrievers built earlier.
initial_results = retriever.get_relevant_documents("查询内容", k=20)
reranked_results = rerank("查询内容", initial_results, top_k=5)

使用Cohere重排序

python
import cohere

# Cohere API client. "your-api-key" is a placeholder; in real deployments load
# the key from configuration or the environment instead of hard-coding it.
co = cohere.Client("your-api-key")

def cohere_rerank(query, documents, top_n=5):
    """Rerank LangChain-style documents with Cohere's hosted rerank endpoint.

    Args:
        query: The user query text.
        documents: Sequence of objects exposing ``page_content``.
        top_n: How many of the best-scoring documents to keep.

    Returns:
        The original document objects, in the order Cohere returns them.
    """
    response = co.rerank(
        query=query,
        documents=[d.page_content for d in documents],
        top_n=top_n,
        model="rerank-multilingual-v2.0"
    )

    ranked = []
    for hit in response.results:
        # `index` points back into the text list we sent; map it to the source doc.
        ranked.append(documents[hit.index])
    return ranked

LangChain重排序器

python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker

# Over-fetch 20 candidates from the vector store; the reranker trims them to 5.
base_retriever = vectorstore.as_retriever(search_kwargs={"k": 20})

# CrossEncoderReranker expects a cross-encoder model object via `model=`, not a
# `model_name=` string. HuggingFaceCrossEncoder wraps the sentence-transformers
# model by name.
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

cross_encoder_model = HuggingFaceCrossEncoder(
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"
)
compressor = CrossEncoderReranker(
    model=cross_encoder_model,
    top_n=5
)

# Retrieve with base_retriever, then keep only the compressor's top-ranked docs.
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)

results = compression_retriever.get_relevant_documents("查询内容")

查询扩展

多查询扩展

python
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

# Shared LLM for the query-expansion / HyDE / rewrite helpers that follow;
# temperature=0 reduces randomness in its output.
llm = OpenAI(temperature=0)

# Prompt asking the LLM for 3 differently-worded queries with related meaning,
# returned as a JSON array of strings.
expand_prompt = PromptTemplate.from_template("""
原始查询:{query}

请生成3个语义相关但表述不同的查询,用于检索相关信息。
以JSON数组格式返回,例如:["查询1", "查询2", "查询3"]
""")

def expand_query(query):
    """Return ``query`` followed by LLM-generated paraphrases of it.

    The LLM is asked (via ``expand_prompt``) for a JSON array of alternative
    queries. The reply is untrusted model output, so it is parsed with
    ``json.loads``, never ``eval``: ``eval`` would execute arbitrary Python the
    model emits. If no JSON array can be extracted, only the original query is
    returned so retrieval can still proceed.

    Args:
        query: The user's original query.

    Returns:
        A list whose first element is ``query``, followed by any distinct,
        non-empty string paraphrases.
    """
    import json

    response = llm(expand_prompt.format(query=query))

    # Models often wrap the array in prose or code fences; parse only the
    # outermost [...] span.
    start, end = response.find("["), response.rfind("]")
    if start == -1 or end <= start:
        return [query]
    try:
        parsed = json.loads(response[start:end + 1])
    except json.JSONDecodeError:
        return [query]

    queries = [query]
    for candidate in parsed:
        if isinstance(candidate, str):
            candidate = candidate.strip()
            # Skip blanks and repeats so each variant is retrieved only once.
            if candidate and candidate not in queries:
                queries.append(candidate)
    return queries

def multi_query_retrieve(query, retriever, k=3):
    """Retrieve with the original query plus LLM paraphrases and merge the hits.

    Args:
        query: The user's original query.
        retriever: Object with ``get_relevant_documents(query, k=...)``.
        k: Documents fetched per query variant.

    Returns:
        Up to ``k * 2`` unique documents, in first-seen order.
    """
    expanded_queries = expand_query(query)

    unique_docs = {}
    for q in expanded_queries:
        for doc in retriever.get_relevant_documents(q, k=k):
            # Not every loader sets metadata["id"]; fall back to the text so
            # deduplication never raises KeyError. Tagged tuples keep an id
            # from colliding with some document's text.
            if "id" in doc.metadata:
                key = ("id", doc.metadata["id"])
            else:
                key = ("content", doc.page_content)
            unique_docs.setdefault(key, doc)

    return list(unique_docs.values())[:k * 2]

HyDE(假设文档嵌入)

python
# HyDE prompt: asks the LLM to write a passage that might contain the answer.
# That generated passage, not the raw question, is then used as the search text.
hyde_prompt = PromptTemplate.from_template("""
请根据以下问题,生成一段可能包含答案的文档片段:

问题:{query}

文档:
""")

def hyde_retrieve(query, vectorstore, k=5):
    """Hypothetical Document Embeddings (HyDE) retrieval.

    Rather than searching with the short question itself, have the LLM draft a
    plausible answer passage and search the vector store with that text. The
    idea behind HyDE is that an answer-shaped passage lies closer to real answer
    passages than the question does.

    Args:
        query: The user's question.
        vectorstore: Store exposing ``similarity_search(text, k=...)``.
        k: Number of documents to return.

    Returns:
        The ``k`` documents most similar to the generated passage.
    """
    prompt_text = hyde_prompt.format(query=query)
    draft_passage = llm(prompt_text)
    return vectorstore.similarity_search(draft_passage, k=k)

查询改写

python
# Prompt asking the LLM to restate an unclear query more clearly and
# specifically; the completion after the final label is the rewritten query.
rewrite_prompt = PromptTemplate.from_template("""
原始查询可能表述不清晰,请将其改写为更清晰、更具体的查询。

原始查询:{query}

改写后的查询:
""")

def rewrite_query(query):
    """Ask the LLM to restate a vague query more clearly and specifically.

    Returns the model's reply with leading/trailing whitespace removed.
    """
    completion = llm(rewrite_prompt.format(query=query))
    cleaned = completion.strip()
    return cleaned

# Example: a vague, context-free query ("how do I use that thing") is a typical
# candidate for rewriting before retrieval.
original_query = "那个东西怎么用"
rewritten_query = rewrite_query(original_query)
print(f"改写后: {rewritten_query}")

上下文压缩

LLM压缩

python
from langchain.retrievers.document_compressors import LLMChainExtractor

# Rebinds the module-level `llm` with the same settings as the earlier snippet.
llm = OpenAI(temperature=0)
# LLMChainExtractor uses the LLM to keep only the query-relevant parts of each
# retrieved document. (Per LangChain docs this is one LLM call per document —
# confirm cost for your version.)
compressor = LLMChainExtractor.from_llm(llm)

# Wrap a base retriever so its results pass through the compressor.
# NOTE(review): `retriever` is not defined in this snippet; assumed to be a
# retriever built earlier (e.g. `vectorstore.as_retriever()`).
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=retriever
)

compressed_docs = compression_retriever.get_relevant_documents("查询内容")

嵌入过滤

python
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.embeddings import OpenAIEmbeddings

# NOTE: rebinds the module-level `embeddings` name also used when building the
# Chroma index earlier.
embeddings = OpenAIEmbeddings()
# Drop retrieved documents whose embedding similarity to the query falls below
# `similarity_threshold`. It is configured with an embedding model only (no
# LLM), so it is a cheaper alternative to LLMChainExtractor.
# NOTE(review): 0.76 is model-dependent — tune it on an evaluation set.
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.76
)

# NOTE(review): `retriever` is assumed to be defined by an earlier snippet.
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter,
    base_retriever=retriever
)

元数据过滤

基础过滤

python
# Vector search restricted by document metadata.
# Chroma's `where` filter accepts exactly one top-level key, so several
# conditions must be combined explicitly with "$and". Passing two sibling keys
# ({"category": ..., "year": ...}) is rejected by Chroma.
results = vectorstore.similarity_search(
    query="查询内容",
    k=5,
    filter={
        "$and": [
            {"category": "技术文档"},
            {"year": {"$gte": 2023}},
        ]
    }
)

自定义过滤函数

python
# Ordering operators: a document that lacks the key can never satisfy them.
_ORDERING_OPERATORS = {
    "$gt": lambda actual, bound: actual > bound,
    "$gte": lambda actual, bound: actual >= bound,
    "$lt": lambda actual, bound: actual < bound,
    "$lte": lambda actual, bound: actual <= bound,
}

# Equality/membership operators: a missing key is compared as None, the same
# way the plain `{"key": value}` equality form treats it.
_MEMBERSHIP_OPERATORS = {
    "$eq": lambda actual, operand: actual == operand,
    "$ne": lambda actual, operand: actual != operand,
    "$in": lambda actual, operand: actual in operand,
    "$nin": lambda actual, operand: actual not in operand,
}

_MISSING = object()


def _validate_filters(filters):
    """Raise ValueError if ``filters`` uses an operator this module does not implement."""
    supported = _ORDERING_OPERATORS.keys() | _MEMBERSHIP_OPERATORS.keys()
    for key, condition in filters.items():
        if isinstance(condition, dict):
            unknown = set(condition) - supported
            if unknown:
                raise ValueError(
                    f"Unsupported filter operator(s) for {key!r}: {sorted(unknown)}"
                )


def _condition_holds(metadata, key, condition):
    """Return True if ``metadata`` satisfies the single filter entry ``key: condition``."""
    if not isinstance(condition, dict):
        return metadata.get(key) == condition

    actual = metadata.get(key, _MISSING)
    for op, operand in condition.items():
        if op in _ORDERING_OPERATORS:
            if actual is _MISSING:
                return False
            try:
                if not _ORDERING_OPERATORS[op](actual, operand):
                    return False
            except TypeError:
                # Incomparable types (e.g. str vs int) cannot satisfy a range test.
                return False
        else:
            value = None if actual is _MISSING else actual
            if not _MEMBERSHIP_OPERATORS[op](value, operand):
                return False
    return True


def filter_by_metadata(documents, filters):
    """Keep the documents whose ``metadata`` satisfies every condition in ``filters``.

    Each filter entry is either ``key: value`` (equality) or ``key: {op: operand}``
    where op is one of ``$gt``, ``$gte``, ``$lt``, ``$lte``, ``$eq``, ``$ne``,
    ``$in``, ``$nin``.

    A document missing ``key`` fails every ordering operator. For equality and
    membership operators the missing value is treated as ``None``.

    Args:
        documents: Objects exposing a ``metadata`` dict.
        filters: Mapping of metadata key to condition.

    Returns:
        The matching documents, in their original order.

    Raises:
        ValueError: If ``filters`` uses an unsupported operator. Silently
            ignoring it would match every document.
    """
    _validate_filters(filters)
    return [
        doc for doc in documents
        if all(_condition_holds(doc.metadata, key, cond) for key, cond in filters.items())
    ]

自适应检索

动态调整K值

python
def adaptive_retrieve(query, retriever, min_k=3, max_k=10, ratio=0.7):
    """Retrieve up to ``max_k`` documents and drop those scoring far below the best.

    A document is kept when its ``metadata["score"]`` is at least ``ratio``
    times the best score among the candidates. Documents without a score count
    as 1.0. At least ``min_k`` documents are returned when that many were
    retrieved.

    NOTE(review): assumes higher score = more relevant and scores >= 0.
    Distance-style scores (lower is better) would need the comparison inverted.

    Args:
        query: Query text.
        retriever: Object with ``get_relevant_documents(query, k=...)``.
        min_k: Minimum number of documents to keep after pruning.
        max_k: Number of candidates to fetch.
        ratio: Fraction of the best score a document must reach to be kept.

    Returns:
        The kept documents, in the retriever's order.
    """
    initial_docs = retriever.get_relevant_documents(query, k=max_k)

    # Fewer hits than requested (including none): nothing to prune, return as-is.
    if not initial_docs or len(initial_docs) < max_k:
        return initial_docs

    scores = [doc.metadata.get("score", 1.0) for doc in initial_docs]

    # Use the maximum rather than scores[0], so the cut-off does not depend on
    # the retriever returning results sorted by score.
    score_threshold = max(scores) * ratio

    filtered_docs = [
        doc for doc, score in zip(initial_docs, scores)
        if score >= score_threshold
    ]

    if len(filtered_docs) < min_k:
        filtered_docs = initial_docs[:min_k]

    return filtered_docs

查询复杂度判断

python
def analyze_query_complexity(query, max_simple_length=20):
    """Classify a query as ``"complex"`` or ``"simple"`` with cheap heuristics.

    A query is complex when it has more than ``max_simple_length`` tokens, or
    when it contains one of the Chinese question words 如何 / 什么是 / 为什么.

    Tokens are counted so that Chinese text, which has no spaces, is measured
    sensibly. Each CJK character counts as one token, and every other
    whitespace-free run (English words, numbers, punctuation) counts as one
    token. Plain ``str.split()`` would count a whole Chinese sentence as a
    single word, so the length rule could never fire.

    Args:
        query: The query text.
        max_simple_length: Largest token count still considered simple.

    Returns:
        ``"complex"`` or ``"simple"``.
    """
    import re

    token_count = len(re.findall(r"[\u4e00-\u9fff]|[^\s\u4e00-\u9fff]+", query))
    has_keywords = any(kw in query.lower() for kw in ["如何", "什么是", "为什么"])

    if token_count > max_simple_length or has_keywords:
        return "complex"
    return "simple"

def smart_retrieve(query, retriever):
    """Route a query to multi-query retrieval when it looks complex.

    Queries classified as complex by ``analyze_query_complexity`` go through
    ``multi_query_retrieve`` with k=5. Anything else gets one direct retrieval
    of 3 documents.
    """
    if analyze_query_complexity(query) != "complex":
        return retriever.get_relevant_documents(query, k=3)
    return multi_query_retrieve(query, retriever, k=5)

检索结果缓存

语义缓存

python
from langchain.cache import InMemoryCache
from langchain.embeddings import OpenAIEmbeddings
import numpy as np

class SemanticCache:
    """Cache that reuses results for queries whose embeddings are near-duplicates.

    A lookup hits when the cosine similarity between the new query's embedding
    and a cached query's embedding is strictly greater than ``threshold``.
    When several entries qualify, the most similar one wins.

    Args:
        threshold: Cosine-similarity cut-off for a hit.
        embeddings: Object with ``embed_query(text)`` returning a float vector.
            Defaults to ``OpenAIEmbeddings()``; injectable to share a client or
            for testing.
        max_size: Optional cap on the number of entries; the oldest entry is
            evicted first. ``None`` (the default) means unbounded.
    """

    def __init__(self, threshold=0.95, embeddings=None, max_size=None):
        self.cache = {}  # query text -> (embedding as np.ndarray, result)
        self.embeddings = embeddings if embeddings is not None else OpenAIEmbeddings()
        self.threshold = threshold
        self.max_size = max_size
        # Remember the last embedding computed, so the usual get(q) -> set(q)
        # miss path costs one embedding call instead of two.
        self._last_query = None
        self._last_embedding = None

    def _embed(self, query):
        """Return ``query``'s embedding as a float ndarray, reusing the last one if possible."""
        if self._last_embedding is not None and query == self._last_query:
            return self._last_embedding
        embedding = np.asarray(self.embeddings.embed_query(query), dtype=float)
        self._last_query = query
        self._last_embedding = embedding
        return embedding

    def get(self, query):
        """Return the cached result of the most similar query, or None on a miss."""
        query_embedding = self._embed(query)
        query_norm = np.linalg.norm(query_embedding)
        if query_norm == 0:
            # Cosine similarity is undefined for a zero vector.
            return None

        best_result = None
        best_similarity = self.threshold
        for cached_embedding, result in self.cache.values():
            cached_norm = np.linalg.norm(cached_embedding)
            if cached_norm == 0:
                continue
            similarity = float(
                np.dot(query_embedding, cached_embedding) / (query_norm * cached_norm)
            )
            if similarity > best_similarity:
                best_similarity = similarity
                best_result = result

        return best_result

    def set(self, query, result):
        """Store ``result`` for ``query``, evicting the oldest entries beyond ``max_size``."""
        embedding = self._embed(query)
        # Pop first so re-inserting moves the key to the end (newest entry).
        self.cache.pop(query, None)
        self.cache[query] = (embedding, result)
        if self.max_size is not None:
            while len(self.cache) > self.max_size:
                del self.cache[next(iter(self.cache))]

# Module-level semantic cache used by `cached_retrieve`. It is built with the
# default threshold, which is high, so only very similar queries hit.
semantic_cache = SemanticCache()

def cached_retrieve(query, retriever):
    """Return retrieval results for ``query``, reusing semantically similar hits.

    First looks the query up in the module-level ``semantic_cache``. On a miss
    it queries ``retriever`` and stores the result for later lookups.
    """
    cached_result = semantic_cache.get(query)
    # Compare against None: an empty result list is a legitimate cached answer
    # and must not trigger a fresh retrieval.
    if cached_result is not None:
        return cached_result

    result = retriever.get_relevant_documents(query)
    semantic_cache.set(query, result)
    return result

完整优化流程

python
class OptimizedRetriever:
    """Retrieval pipeline: semantic cache -> (multi-query) vector search -> rerank.

    Args:
        vectorstore: Store exposing ``similarity_search(query, k=...)``.
        use_rerank: Rerank candidates with a Cross-Encoder before truncating.
        use_multi_query: Expand the query with ``expand_query`` and merge the
            hits of every variant.
    """

    def __init__(self, vectorstore, use_rerank=True, use_multi_query=True):
        self.vectorstore = vectorstore
        self.use_rerank = use_rerank
        self.use_multi_query = use_multi_query

        if use_rerank:
            self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        self.cache = SemanticCache()

    def retrieve(self, query, k=5):
        """Return up to ``k`` documents for ``query``.

        Cache entries store ``(k, docs)``, so a result computed for a smaller
        ``k`` is never served for a larger one. A hit computed with k' >= k is
        truncated to ``k``.
        """
        cached = self.cache.get(query)
        # `is not None`: an empty cached result is still a valid hit.
        if cached is not None:
            cached_k, cached_docs = cached
            if cached_k >= k:
                return cached_docs[:k]

        candidates = self._collect_candidates(query, k)

        if self.use_rerank:
            docs = self.rerank(query, candidates, top_k=k)
        else:
            # The candidate pool (k*2 per query variant, or k*3 overall) is
            # larger than k; keep the search order and trim to the requested size.
            docs = candidates[:k]

        self.cache.set(query, (k, docs))
        return docs

    def _collect_candidates(self, query, k):
        """Gather the deduplicated candidate pool that reranking/trimming selects from."""
        if not self.use_multi_query:
            return self.vectorstore.similarity_search(query, k=k * 3)

        unique_docs = {}
        for variant in expand_query(query):
            for doc in self.vectorstore.similarity_search(variant, k=k * 2):
                unique_docs.setdefault(self._doc_key(doc), doc)
        return list(unique_docs.values())

    @staticmethod
    def _doc_key(doc):
        """Deduplication key: metadata id when present, otherwise the document text."""
        if "id" in doc.metadata:
            return ("id", doc.metadata["id"])
        return ("content", doc.page_content)

    def rerank(self, query, documents, top_k=5):
        """Return the ``top_k`` documents with the highest Cross-Encoder scores."""
        # Skip the model entirely when there is nothing to score.
        if not documents:
            return []

        pairs = [[query, doc.page_content] for doc in documents]
        scores = self.cross_encoder.predict(pairs)

        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked[:top_k]]

# Usage: build the full pipeline over the vector store created earlier
# (reranking and multi-query expansion are on by default), then fetch 5 docs.
optimized_retriever = OptimizedRetriever(vectorstore)
results = optimized_retriever.retrieve("查询内容", k=5)

小结

检索优化是提升RAG系统效果的关键环节。混合检索结合语义和关键词匹配,重排序提升结果相关性,查询扩展提高召回率。实际应用中应根据场景特点组合使用这些技术,并通过评估指标持续优化。

下一章将介绍RAG最佳实践,涵盖文档预处理、索引优化和召回率提升等完整方案。