RAG Best Practices
Overview
Building a high-quality RAG system means getting every stage right: document preprocessing, index construction, retrieval optimization, and generation enhancement. This chapter summarizes best practices drawn from real-world projects to help you build production-grade RAG applications.
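The code sketches in this chapter call a few helper functions that are never defined in the text: get_embedding, cosine_similarity, llm, load_document, and split_documents. One possible set of minimal stand-ins, assuming OpenAI models and the same classic LangChain APIs used elsewhere in this chapter, is shown below; swap in whatever embedding model, LLM, loaders, and splitter you actually use.
```python
import numpy as np
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Assumed helpers used throughout this chapter; these are illustrative defaults only.
_embeddings = OpenAIEmbeddings()
llm = OpenAI(temperature=0)

def get_embedding(text: str) -> list:
    # Embed a single piece of text
    return _embeddings.embed_query(text)

def cosine_similarity(vec1, vec2) -> float:
    v1, v2 = np.array(vec1), np.array(vec2)
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-10))

def load_document(file_path: str):
    # Load a single file into LangChain Documents (plain-text files in this sketch)
    return TextLoader(file_path, encoding="utf-8").load()

def split_documents(documents, chunk_size=500, chunk_overlap=50):
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(documents)
```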
Document Preprocessing Best Practices
Document Quality Checks
```python
import re
from typing import List, Dict

class DocumentQualityChecker:
    def __init__(self):
        self.min_length = 50
        self.max_length = 100000

    def check(self, documents: List) -> Dict:
        issues = []
        for i, doc in enumerate(documents):
            content = doc.page_content
            if len(content) < self.min_length:
                issues.append({
                    "doc_index": i,
                    "issue": "content_too_short",
                    "length": len(content)
                })
            if len(content) > self.max_length:
                issues.append({
                    "doc_index": i,
                    "issue": "content_too_long",
                    "length": len(content)
                })
            if self._is_garbled(content):
                issues.append({
                    "doc_index": i,
                    "issue": "garbled_text"
                })
            if self._has_encoding_issues(content):
                issues.append({
                    "doc_index": i,
                    "issue": "encoding_error"
                })
        return {
            "total_docs": len(documents),
            "issues": issues,
            # Clamp at 0 since one document can raise multiple issues
            "quality_score": max(0.0, 1 - len(issues) / max(len(documents), 1))
        }

    def _is_garbled(self, text):
        # Heuristic for Chinese corpora: long text with almost no CJK characters is suspicious
        chinese_ratio = len(re.findall(r'[\u4e00-\u9fff]', text)) / max(len(text), 1)
        return chinese_ratio < 0.1 and len(text) > 100

    def _has_encoding_issues(self, text):
        # Control characters usually indicate a decoding problem
        return bool(re.search(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', text))

checker = DocumentQualityChecker()
quality_report = checker.check(documents)
print(f"文档质量评分: {quality_report['quality_score']:.2%}")
```
Document Deduplication
```python
import hashlib

def deduplicate_documents(documents, threshold=0.95):
    unique_docs = []
    seen_hashes = set()
    for doc in documents:
        # Exact-duplicate check via content hash
        content_hash = hashlib.md5(doc.page_content.encode()).hexdigest()
        if content_hash in seen_hashes:
            continue
        # Near-duplicate check via pairwise Jaccard similarity (O(n^2); fine for small corpora)
        is_duplicate = False
        for unique_doc in unique_docs:
            similarity = calculate_similarity(doc.page_content, unique_doc.page_content)
            if similarity > threshold:
                is_duplicate = True
                break
        if not is_duplicate:
            unique_docs.append(doc)
            seen_hashes.add(content_hash)
    return unique_docs

def calculate_similarity(text1, text2):
    # Jaccard similarity over whitespace-separated tokens; Chinese text should be
    # word-segmented first (e.g. with jieba), since it has no natural spaces
    words1 = set(text1.split())
    words2 = set(text2.split())
    intersection = words1 & words2
    union = words1 | words2
    return len(intersection) / len(union) if union else 0

deduped_docs = deduplicate_documents(documents)
print(f"去重前: {len(documents)}, 去重后: {len(deduped_docs)}")
```
Incremental Update Strategy
```python
import os
import json
import hashlib
from datetime import datetime

class IncrementalUpdater:
    def __init__(self, index_path, file_tracker_path):
        self.index_path = index_path
        self.file_tracker_path = file_tracker_path
        self.file_tracker = self._load_tracker()

    def _load_tracker(self):
        if os.path.exists(self.file_tracker_path):
            with open(self.file_tracker_path, 'r') as f:
                return json.load(f)
        return {}

    def _save_tracker(self):
        with open(self.file_tracker_path, 'w') as f:
            json.dump(self.file_tracker, f, indent=2)

    def get_file_hash(self, file_path):
        with open(file_path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()

    def detect_changes(self, file_paths):
        changes = {
            "new": [],
            "modified": [],
            "deleted": []
        }
        current_files = set(file_paths)
        tracked_files = set(self.file_tracker.keys())
        for file_path in file_paths:
            current_hash = self.get_file_hash(file_path)
            if file_path not in self.file_tracker:
                changes["new"].append(file_path)
            elif self.file_tracker[file_path]["hash"] != current_hash:
                changes["modified"].append(file_path)
        changes["deleted"] = list(tracked_files - current_files)
        return changes

    def update_index(self, vectorstore, file_paths):
        changes = self.detect_changes(file_paths)
        # Note: delete-by-metadata-filter support varies between vector stores;
        # adapt these calls to the store you actually use
        for file_path in changes["deleted"]:
            vectorstore.delete(filter={"source": file_path})
            del self.file_tracker[file_path]
        for file_path in changes["modified"]:
            vectorstore.delete(filter={"source": file_path})
            self._add_file(vectorstore, file_path)
        for file_path in changes["new"]:
            self._add_file(vectorstore, file_path)
        self._save_tracker()
        return changes

    def _add_file(self, vectorstore, file_path):
        documents = load_document(file_path)
        chunks = split_documents(documents)
        for chunk in chunks:
            chunk.metadata["source"] = file_path
        vectorstore.add_documents(chunks)
        self.file_tracker[file_path] = {
            "hash": self.get_file_hash(file_path),
            "indexed_at": datetime.now().isoformat(),
            "chunk_count": len(chunks)
        }

updater = IncrementalUpdater("./index", "./file_tracker.json")
changes = updater.update_index(vectorstore, file_list)
print(f"新增: {len(changes['new'])}, 修改: {len(changes['modified'])}, 删除: {len(changes['deleted'])}")
```
Index Optimization Best Practices
Hierarchical Indexing
```python
class HierarchicalIndex:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
        self.summary_index = {}

    def build_hierarchical_index(self, documents):
        # First level: one LLM-generated summary (and its embedding) per document
        for doc in documents:
            doc_id = doc.metadata.get("doc_id")
            summary = self._generate_summary(doc.page_content)
            self.summary_index[doc_id] = {
                "summary": summary,
                "summary_embedding": get_embedding(summary),
                "chunk_count": len(doc.page_content) // 1000
            }

    def retrieve(self, query, k=5):
        # Stage 1: rank documents by summary similarity
        query_embedding = get_embedding(query)
        doc_scores = []
        for doc_id, info in self.summary_index.items():
            similarity = cosine_similarity(query_embedding, info["summary_embedding"])
            doc_scores.append((doc_id, similarity))
        doc_scores.sort(key=lambda x: x[1], reverse=True)
        top_docs = [doc_id for doc_id, _ in doc_scores[:k*2]]
        # Stage 2: chunk-level search restricted to the top documents
        chunks = self.vectorstore.similarity_search(
            query,
            k=k,
            filter={"doc_id": {"$in": top_docs}}
        )
        return chunks

    def _generate_summary(self, text):
        prompt = f"请用一句话总结以下内容:\n\n{text[:1000]}"
        return llm(prompt)

hierarchical_index = HierarchicalIndex(vectorstore)
hierarchical_index.build_hierarchical_index(documents)
```
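Once the summary index is built, queries go through the two-stage retrieve method; a minimal usage sketch (the query string is illustrative):
```python
# Stage 1 ranks documents by summary similarity; stage 2 searches chunks within them
top_chunks = hierarchical_index.retrieve("RAG系统如何做增量索引更新?", k=5)
```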
Multi-Field Indexing
```python
from langchain.vectorstores import Chroma
from langchain.schema import Document

class MultiFieldIndex:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore

    def add_document(self, doc):
        # Embed title, content, and keywords as separate fields
        title_embedding = get_embedding(doc.metadata.get("title", ""))
        content_embedding = get_embedding(doc.page_content)
        keywords_embedding = get_embedding(" ".join(doc.metadata.get("keywords", [])))
        # Caveat: most vector stores (including Chroma) only accept scalar metadata values,
        # so extra field embeddings usually need a separate index or store
        self.vectorstore.add_texts(
            texts=[doc.page_content],
            metadatas=[{
                **doc.metadata,
                "title_embedding": title_embedding,
                "keywords_embedding": keywords_embedding
            }]
        )

    def search(self, query, weights={"content": 0.5, "title": 0.3, "keywords": 0.2}):
        # Only the content field is searched here; fusing title/keyword scores with the
        # given weights is left as an extension
        query_embedding = get_embedding(query)
        content_results = self.vectorstore.similarity_search_by_vector(
            query_embedding,
            k=10
        )
        return content_results
```
Index Performance Monitoring
```python
import time
from collections import defaultdict

class IndexMonitor:
    def __init__(self):
        self.metrics = defaultdict(list)

    def track_query(self, query_type, duration, result_count):
        self.metrics[f"{query_type}_duration"].append(duration)
        self.metrics[f"{query_type}_results"].append(result_count)

    def get_stats(self):
        stats = {}
        for key, values in self.metrics.items():
            if "duration" in key:
                stats[key] = {
                    "avg": sum(values) / len(values),
                    "max": max(values),
                    "min": min(values)
                }
            else:
                stats[key] = {
                    "avg": sum(values) / len(values),
                    "total": sum(values)
                }
        return stats

monitor = IndexMonitor()

def monitored_search(vectorstore, query, k=5):
    start_time = time.time()
    results = vectorstore.similarity_search(query, k=k)
    duration = time.time() - start_time
    monitor.track_query("similarity_search", duration, len(results))
    return results
```
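After routing queries through monitored_search, the aggregated statistics can be read back at any time (the query string here is illustrative):
```python
results = monitored_search(vectorstore, "RAG召回率怎么提升?", k=5)
print(monitor.get_stats())
```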
Recall Improvement Strategies
Multi-Path Retrieval
```python
from collections import defaultdict

class MultiPathRetriever:
    def __init__(self, vectorstore, bm25_index):
        self.vectorstore = vectorstore
        self.bm25_index = bm25_index

    def retrieve(self, query, k=10):
        # Recall candidates along three paths, then fuse the ranked lists
        vector_results = self._vector_search(query, k=k*2)
        bm25_results = self._bm25_search(query, k=k*2)
        keyword_results = self._keyword_search(query, k=k*2)
        merged = self._merge_results(
            [vector_results, bm25_results, keyword_results],
            weights=[0.5, 0.3, 0.2]
        )
        return merged[:k]

    def _vector_search(self, query, k):
        return self.vectorstore.similarity_search(query, k=k)

    def _bm25_search(self, query, k):
        # Assumes a BM25 index object exposing get_top_k(query, k)
        return self.bm25_index.get_top_k(query, k)

    def _keyword_search(self, query, k):
        # _extract_keywords and the metadata-filter search API are assumed to exist
        keywords = self._extract_keywords(query)
        return self.vectorstore.search(
            filter={"keywords": {"$in": keywords}},
            k=k
        )

    def _merge_results(self, result_lists, weights):
        # Weighted reciprocal-rank fusion: score = weight / (rank + 1)
        doc_scores = defaultdict(float)
        for results, weight in zip(result_lists, weights):
            for rank, doc in enumerate(results):
                doc_id = doc.metadata.get("id")
                score = weight / (rank + 1)
                doc_scores[doc_id] += score
        sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
        # _get_doc_by_id (not shown) maps a document id back to the Document object
        return [self._get_doc_by_id(doc_id) for doc_id, _ in sorted_docs]
```
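A usage sketch. SimpleBM25Index is a hypothetical thin wrapper over rank_bm25's BM25Okapi exposing the get_top_k interface assumed above; chunks stands for the list of split Document chunks, and calling retrieve also requires filling in the _extract_keywords and _get_doc_by_id helpers noted in the comments:
```python
from rank_bm25 import BM25Okapi

class SimpleBM25Index:
    # Hypothetical wrapper exposing get_top_k(query, k) over a list of chunks
    def __init__(self, chunks):
        self.chunks = chunks
        # Whitespace tokenization; segment Chinese text (e.g. with jieba) first
        self.bm25 = BM25Okapi([c.page_content.split() for c in chunks])

    def get_top_k(self, query, k):
        scores = self.bm25.get_scores(query.split())
        top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        return [self.chunks[i] for i in top]

bm25_index = SimpleBM25Index(chunks)
retriever = MultiPathRetriever(vectorstore, bm25_index)
top_docs = retriever.retrieve("如何提升RAG系统的召回率?", k=10)
```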
Query Understanding and Rewriting
```python
import json
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

class QueryUnderstanding:
    def __init__(self):
        self.llm = OpenAI(temperature=0)

    def understand(self, query):
        intent = self._detect_intent(query)
        entities = self._extract_entities(query)
        expanded = self._expand_query(query)
        return {
            "original": query,
            "intent": intent,
            "entities": entities,
            "expanded_queries": expanded
        }

    def _detect_intent(self, query):
        prompt = PromptTemplate.from_template("""
分析以下查询的意图,返回最匹配的类型:
- 定义查询:询问概念定义
- 操作查询:询问如何操作
- 对比查询:询问对比分析
- 列举查询:询问列举项目
- 其他

查询:{query}
意图:
""")
        return self.llm(prompt.format(query=query)).strip()

    def _extract_entities(self, query):
        prompt = PromptTemplate.from_template("""
从以下查询中提取关键实体(名词、术语、专有名词):

查询:{query}
实体(JSON数组格式):
""")
        response = self.llm(prompt.format(query=query))
        # Parse the model output as JSON; eval on LLM output is unsafe
        return json.loads(response.strip())

    def _expand_query(self, query):
        prompt = PromptTemplate.from_template("""
将以下查询改写为3个不同的表述,保持语义一致:

查询:{query}
改写(JSON数组格式):
""")
        response = self.llm(prompt.format(query=query))
        return json.loads(response.strip())

query_understanding = QueryUnderstanding()
understood = query_understanding.understand("如何使用Python处理PDF文件?")
```
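The rewritten queries are typically fed into multi-query retrieval: run each variant against the index and merge the results. A minimal sketch, deduplicating by page content:
```python
def multi_query_retrieve(vectorstore, understood, k=5):
    # Retrieve with the original query plus each rewritten variant, then deduplicate
    queries = [understood["original"]] + understood["expanded_queries"]
    seen, merged = set(), []
    for q in queries:
        for doc in vectorstore.similarity_search(q, k=k):
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                merged.append(doc)
    return merged

docs = multi_query_retrieve(vectorstore, understood, k=5)
```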
Generation Enhancement Techniques
Context Organization
```python
from langchain.schema import Document

class ContextOrganizer:
    def __init__(self, max_tokens=4000):
        self.max_tokens = max_tokens

    def organize(self, query, documents):
        documents = self._deduplicate(documents)
        documents = self._sort_by_relevance(query, documents)
        documents = self._truncate(documents)
        context = self._format_context(documents)
        return context

    def _deduplicate(self, documents):
        seen = set()
        unique = []
        for doc in documents:
            doc_id = doc.metadata.get("id", hash(doc.page_content))
            if doc_id not in seen:
                seen.add(doc_id)
                unique.append(doc)
        return unique

    def _sort_by_relevance(self, query, documents):
        query_embedding = get_embedding(query)
        scored_docs = []
        for doc in documents:
            doc_embedding = get_embedding(doc.page_content[:500])
            similarity = cosine_similarity(query_embedding, doc_embedding)
            scored_docs.append((doc, similarity))
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in scored_docs]

    def _truncate(self, documents):
        # Rough token estimate: ~4 characters per token
        total_tokens = 0
        truncated = []
        for doc in documents:
            doc_tokens = len(doc.page_content) // 4
            if total_tokens + doc_tokens > self.max_tokens:
                remaining = self.max_tokens - total_tokens
                if remaining > 100:
                    truncated_doc = Document(
                        page_content=doc.page_content[:remaining*4],
                        metadata=doc.metadata
                    )
                    truncated.append(truncated_doc)
                break
            truncated.append(doc)
            total_tokens += doc_tokens
        return truncated

    def _format_context(self, documents):
        context_parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            context_parts.append(f"[文档{i}] 来源:{source}\n{doc.page_content}\n")
        return "\n".join(context_parts)

organizer = ContextOrganizer(max_tokens=4000)
context = organizer.organize(query, retrieved_docs)
```
Citation Tracking
```python
class CitationTracker:
    def __init__(self):
        self.citations = []

    def track(self, documents):
        self.citations = []
        for i, doc in enumerate(documents, 1):
            citation = {
                "id": i,
                "source": doc.metadata.get("source", "未知"),
                "page": doc.metadata.get("page", "N/A"),
                "chunk_id": doc.metadata.get("chunk_id", "N/A"),
                "content_preview": doc.page_content[:100] + "..."
            }
            self.citations.append(citation)
        return self.citations

    def format_citations(self):
        formatted = "\n\n参考文献:\n"
        for citation in self.citations:
            formatted += f"[{citation['id']}] {citation['source']}"
            if citation['page'] != "N/A":
                formatted += f", 第{citation['page']}页"
            formatted += "\n"
        return formatted

    def add_to_response(self, response):
        return response + self.format_citations()

citation_tracker = CitationTracker()
citations = citation_tracker.track(retrieved_docs)
final_response = citation_tracker.add_to_response(llm_response)
```
Error Handling and Degradation
Handling Retrieval Failures
```python
from langchain.schema import Document

class RobustRetriever:
    def __init__(self, primary_retriever, fallback_retriever=None):
        self.primary_retriever = primary_retriever
        self.fallback_retriever = fallback_retriever

    def retrieve(self, query, k=5):
        try:
            results = self.primary_retriever.get_relevant_documents(query, k=k)
            if not results:
                return self._handle_empty_results(query, k)
            return results
        except Exception as e:
            return self._handle_error(query, k, str(e))

    def _handle_empty_results(self, query, k):
        # Empty primary results: fall back to the secondary retriever if available
        if self.fallback_retriever:
            return self.fallback_retriever.get_relevant_documents(query, k=k)
        return [Document(
            page_content="抱歉,未找到相关信息。请尝试更换查询方式。",
            metadata={"type": "fallback"}
        )]

    def _handle_error(self, query, k, error_message):
        print(f"检索错误: {error_message}")
        if self.fallback_retriever:
            try:
                return self.fallback_retriever.get_relevant_documents(query, k=k)
            except Exception:
                pass
        return [Document(
            page_content="检索服务暂时不可用,请稍后重试。",
            metadata={"type": "error", "error": error_message}
        )]

robust_retriever = RobustRetriever(
    primary_retriever=vector_retriever,
    fallback_retriever=bm25_retriever
)
```
Response Quality Checks
```python
class ResponseQualityChecker:
    def __init__(self):
        self.min_length = 20
        self.max_length = 2000

    def check(self, response, context):
        issues = []
        if len(response) < self.min_length:
            issues.append("response_too_short")
        if len(response) > self.max_length:
            issues.append("response_too_long")
        if not self._has_context_support(response, context):
            issues.append("no_context_support")
        if self._has_hallucination_indicators(response):
            issues.append("potential_hallucination")
        return {
            "is_valid": len(issues) == 0,
            "issues": issues
        }

    def _has_context_support(self, response, context):
        # Crude lexical-overlap check; whitespace tokenization works poorly for Chinese,
        # so segment first or use an embedding-based check in practice
        context_keywords = set(context.split())
        response_keywords = set(response.split())
        overlap = context_keywords & response_keywords
        return len(overlap) > 5

    def _has_hallucination_indicators(self, response):
        # Phrases that suggest the model is guessing rather than citing the context
        indicators = [
            "我不确定",
            "可能",
            "猜测",
            "应该"
        ]
        return any(indicator in response for indicator in indicators)

quality_checker = ResponseQualityChecker()
quality = quality_checker.check(response, context)
```
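A common pattern is to regenerate, or fall back to a safe answer, when the check fails. A minimal retry loop sketch, assuming a generate(query, context) callable for the generation step:
```python
def generate_with_quality_check(query, context, generate, max_retries=2):
    # generate(query, context) -> str is an assumed generation function
    for _ in range(max_retries + 1):
        response = generate(query, context)
        if quality_checker.check(response, context)["is_valid"]:
            return response
    return "抱歉,暂时无法基于现有资料给出可靠回答。"
```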
Monitoring and Logging
A Complete Monitoring Setup
```python
import logging
from datetime import datetime
import json

class RAGMonitor:
    def __init__(self, log_file="rag_monitor.log"):
        self.logger = self._setup_logger(log_file)
        self.metrics = {
            "queries": 0,
            "total_latency": 0,
            "errors": 0,
            "avg_retrieval_time": 0,
            "avg_generation_time": 0
        }

    def _setup_logger(self, log_file):
        logger = logging.getLogger("RAGMonitor")
        logger.setLevel(logging.INFO)
        handler = logging.FileHandler(log_file, encoding='utf-8')
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s'
        ))
        logger.addHandler(handler)
        return logger

    def log_query(self, query, retrieval_time, generation_time, result_count):
        self.metrics["queries"] += 1
        self.metrics["total_latency"] += retrieval_time + generation_time
        self.logger.info(json.dumps({
            "event": "query",
            "query": query[:100],
            "retrieval_time": retrieval_time,
            "generation_time": generation_time,
            "result_count": result_count,
            "timestamp": datetime.now().isoformat()
        }))

    def log_error(self, query, error_type, error_message):
        self.metrics["errors"] += 1
        self.logger.error(json.dumps({
            "event": "error",
            "query": query[:100],
            "error_type": error_type,
            "error_message": error_message,
            "timestamp": datetime.now().isoformat()
        }))

    def get_stats(self):
        if self.metrics["queries"] == 0:
            return self.metrics
        return {
            **self.metrics,
            "avg_latency": self.metrics["total_latency"] / self.metrics["queries"],
            "error_rate": self.metrics["errors"] / self.metrics["queries"]
        }

monitor = RAGMonitor()

def monitored_rag_query(query, retriever, generator):
    import time
    try:
        start_time = time.time()
        docs = retriever.get_relevant_documents(query)
        retrieval_time = time.time() - start_time
        start_time = time.time()
        response = generator.generate(query, docs)
        generation_time = time.time() - start_time
        monitor.log_query(query, retrieval_time, generation_time, len(docs))
        return response
    except Exception as e:
        monitor.log_error(query, type(e).__name__, str(e))
        raise
```
Summary
Building a production-grade RAG system requires care at every stage: document preprocessing safeguards data quality, index optimization improves retrieval efficiency, recall-enhancement strategies keep results comprehensive, and generation enhancement raises answer quality. Solid error handling and monitoring are what keep the system running reliably.
The RAG stack is evolving quickly, so keep tracking new techniques and best practices and iterate on your system continuously. Apply these techniques flexibly to your actual business scenario; that is how you build a RAG application that delivers real value.