RAG Best Practices

Overview

Building a high-quality RAG system requires getting every stage right, from document preprocessing and index construction to retrieval optimization and generation enhancement. This chapter summarizes best practices from real-world projects to help you build production-grade RAG applications.

Document Preprocessing Best Practices

Document Quality Checks

python
import re
from typing import List, Dict

class DocumentQualityChecker:
    def __init__(self):
        self.min_length = 50
        self.max_length = 100000
    
    def check(self, documents: List) -> Dict:
        issues = []
        
        for i, doc in enumerate(documents):
            content = doc.page_content
            
            if len(content) < self.min_length:
                issues.append({
                    "doc_index": i,
                    "issue": "content_too_short",
                    "length": len(content)
                })
            
            if len(content) > self.max_length:
                issues.append({
                    "doc_index": i,
                    "issue": "content_too_long",
                    "length": len(content)
                })
            
            if self._is_garbled(content):
                issues.append({
                    "doc_index": i,
                    "issue": "garbled_text"
                })
            
            if self._has_encoding_issues(content):
                issues.append({
                    "doc_index": i,
                    "issue": "encoding_error"
                })
        
        return {
            "total_docs": len(documents),
            "issues": issues,
            "quality_score": 1 - len(issues) / max(len(documents), 1)
        }
    
    def _is_garbled(self, text):
        # Heuristic for Chinese corpora: long text with almost no CJK characters is likely garbled.
        chinese_ratio = len(re.findall(r'[\u4e00-\u9fff]', text)) / max(len(text), 1)
        return chinese_ratio < 0.1 and len(text) > 100
    
    def _has_encoding_issues(self, text):
        # Control characters (other than tab/newline) usually indicate encoding problems.
        return bool(re.search(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', text))

checker = DocumentQualityChecker()
quality_report = checker.check(documents)
print(f"文档质量评分: {quality_report['quality_score']:.2%}")

Document Deduplication

python
import hashlib

def deduplicate_documents(documents, threshold=0.95):
    unique_docs = []
    seen_hashes = set()
    
    for doc in documents:
        content_hash = hashlib.md5(doc.page_content.encode()).hexdigest()
        
        if content_hash in seen_hashes:
            continue
        
        is_duplicate = False
        for unique_doc in unique_docs:
            similarity = calculate_similarity(doc.page_content, unique_doc.page_content)
            if similarity > threshold:
                is_duplicate = True
                break
        
        if not is_duplicate:
            unique_docs.append(doc)
            seen_hashes.add(content_hash)
    
    return unique_docs

def calculate_similarity(text1, text2):
    # Jaccard similarity over whitespace-tokenized word sets
    # (for Chinese text, use a tokenizer such as jieba instead of split()).
    words1 = set(text1.split())
    words2 = set(text2.split())
    
    intersection = words1 & words2
    union = words1 | words2
    
    return len(intersection) / len(union) if union else 0

deduped_docs = deduplicate_documents(documents)
print(f"去重前: {len(documents)}, 去重后: {len(deduped_docs)}")

Incremental Update Strategy

python
import hashlib
import json
import os
from datetime import datetime

class IncrementalUpdater:
    def __init__(self, index_path, file_tracker_path):
        self.index_path = index_path
        self.file_tracker_path = file_tracker_path
        self.file_tracker = self._load_tracker()
    
    def _load_tracker(self):
        if os.path.exists(self.file_tracker_path):
            with open(self.file_tracker_path, 'r') as f:
                return json.load(f)
        return {}
    
    def _save_tracker(self):
        with open(self.file_tracker_path, 'w') as f:
            json.dump(self.file_tracker, f, indent=2)
    
    def get_file_hash(self, file_path):
        with open(file_path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()
    
    def detect_changes(self, file_paths):
        changes = {
            "new": [],
            "modified": [],
            "deleted": []
        }
        
        current_files = set(file_paths)
        tracked_files = set(self.file_tracker.keys())
        
        for file_path in file_paths:
            current_hash = self.get_file_hash(file_path)
            
            if file_path not in self.file_tracker:
                changes["new"].append(file_path)
            elif self.file_tracker[file_path]["hash"] != current_hash:
                changes["modified"].append(file_path)
        
        changes["deleted"] = list(tracked_files - current_files)
        
        return changes
    
    def update_index(self, vectorstore, file_paths):
        changes = self.detect_changes(file_paths)
        
        for file_path in changes["deleted"]:
            vectorstore.delete(filter={"source": file_path})
            del self.file_tracker[file_path]
        
        for file_path in changes["modified"]:
            vectorstore.delete(filter={"source": file_path})
            self._add_file(vectorstore, file_path)
        
        for file_path in changes["new"]:
            self._add_file(vectorstore, file_path)
        
        self._save_tracker()
        
        return changes
    
    def _add_file(self, vectorstore, file_path):
        documents = load_document(file_path)
        chunks = split_documents(documents)
        
        for chunk in chunks:
            chunk.metadata["source"] = file_path
        
        vectorstore.add_documents(chunks)
        
        self.file_tracker[file_path] = {
            "hash": self.get_file_hash(file_path),
            "indexed_at": datetime.now().isoformat(),
            "chunk_count": len(chunks)
        }

updater = IncrementalUpdater("./index", "./file_tracker.json")
changes = updater.update_index(vectorstore, file_list)
print(f"新增: {len(changes['new'])}, 修改: {len(changes['modified'])}, 删除: {len(changes['deleted'])}")

Index Optimization Best Practices

Hierarchical Indexing

python
class HierarchicalIndex:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
        self.summary_index = {}
    
    def build_hierarchical_index(self, documents):
        for doc in documents:
            doc_id = doc.metadata.get("doc_id")
            summary = self._generate_summary(doc.page_content)
            
            self.summary_index[doc_id] = {
                "summary": summary,
                "summary_embedding": get_embedding(summary),
                "chunk_count": len(doc.page_content) // 1000
            }
    
    def retrieve(self, query, k=5):
        query_embedding = get_embedding(query)
        
        doc_scores = []
        for doc_id, info in self.summary_index.items():
            similarity = cosine_similarity(query_embedding, info["summary_embedding"])
            doc_scores.append((doc_id, similarity))
        
        doc_scores.sort(key=lambda x: x[1], reverse=True)
        top_docs = [doc_id for doc_id, _ in doc_scores[:k*2]]
        
        chunks = self.vectorstore.similarity_search(
            query,
            k=k,
            filter={"doc_id": {"$in": top_docs}}
        )
        
        return chunks
    
    def _generate_summary(self, text):
        # llm is an assumed LLM callable (e.g. an OpenAI instance from LangChain).
        prompt = f"请用一句话总结以下内容:\n\n{text[:1000]}"
        return llm(prompt)

hierarchical_index = HierarchicalIndex(vectorstore)
hierarchical_index.build_hierarchical_index(documents)
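
`get_embedding`, `cosine_similarity`, and `llm` are used as assumed helpers here. A minimal sketch of the first two, assuming OpenAI embeddings via LangChain and NumPy (`llm` would likewise be any LLM callable, such as `OpenAI()`):

python
# Illustrative sketch of the embedding helpers assumed above.
import numpy as np
from langchain.embeddings import OpenAIEmbeddings

_embedder = OpenAIEmbeddings()

def get_embedding(text):
    # Embed a single piece of text into a vector.
    return _embedder.embed_query(text)

def cosine_similarity(vec1, vec2):
    # Cosine similarity between two embedding vectors.
    v1, v2 = np.array(vec1), np.array(vec2)
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))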

Multi-Field Indexing

python
from langchain.vectorstores import Chroma
from langchain.schema import Document

class MultiFieldIndex:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
    
    def add_document(self, doc):
        # The vector store embeds page_content itself; extra per-field embeddings are
        # kept in metadata here. Note that many stores (including Chroma) only accept
        # scalar metadata values, so these may need to live in a separate lookup.
        title_embedding = get_embedding(doc.metadata.get("title", ""))
        keywords_embedding = get_embedding(" ".join(doc.metadata.get("keywords", [])))
        
        self.vectorstore.add_texts(
            texts=[doc.page_content],
            metadatas=[{
                **doc.metadata,
                "title_embedding": title_embedding,
                "keywords_embedding": keywords_embedding
            }]
        )
    
    def search(self, query, weights={"content": 0.5, "title": 0.3, "keywords": 0.2}):
        # Baseline search over the content field only; the field weights are applied
        # in the fusion sketch shown after this class.
        query_embedding = get_embedding(query)
        
        content_results = self.vectorstore.similarity_search_by_vector(
            query_embedding,
            k=10
        )
        
        return content_results
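
The `search` method above only queries the content field. One way to actually apply the field weights, assuming the `get_embedding` and `cosine_similarity` helpers sketched earlier, is to over-fetch candidates by content similarity and re-score them against their title and keyword text:

python
# Sketch of weighted multi-field re-scoring (helper names are assumptions, not chapter APIs).
def weighted_multifield_search(index, query, k=5, weights=None):
    weights = weights or {"content": 0.5, "title": 0.3, "keywords": 0.2}
    query_embedding = get_embedding(query)
    # Over-fetch candidates on content similarity, then re-score per field.
    candidates = index.vectorstore.similarity_search_by_vector(query_embedding, k=k * 3)
    
    scored = []
    for doc in candidates:
        content_score = cosine_similarity(query_embedding, get_embedding(doc.page_content))
        
        title = doc.metadata.get("title", "")
        title_score = cosine_similarity(query_embedding, get_embedding(title)) if title else 0.0
        
        keywords = " ".join(doc.metadata.get("keywords", []))
        keyword_score = cosine_similarity(query_embedding, get_embedding(keywords)) if keywords else 0.0
        
        score = (weights["content"] * content_score
                 + weights["title"] * title_score
                 + weights["keywords"] * keyword_score)
        scored.append((doc, score))
    
    scored.sort(key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in scored[:k]]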

Index Performance Monitoring

python
import time
from collections import defaultdict

class IndexMonitor:
    def __init__(self):
        self.metrics = defaultdict(list)
    
    def track_query(self, query_type, duration, result_count):
        self.metrics[f"{query_type}_duration"].append(duration)
        self.metrics[f"{query_type}_results"].append(result_count)
    
    def get_stats(self):
        stats = {}
        
        for key, values in self.metrics.items():
            if "duration" in key:
                stats[key] = {
                    "avg": sum(values) / len(values),
                    "max": max(values),
                    "min": min(values)
                }
            else:
                stats[key] = {
                    "avg": sum(values) / len(values),
                    "total": sum(values)
                }
        
        return stats

monitor = IndexMonitor()

def monitored_search(vectorstore, query, k=5):
    start_time = time.time()
    results = vectorstore.similarity_search(query, k=k)
    duration = time.time() - start_time
    
    monitor.track_query("similarity_search", duration, len(results))
    
    return results

Recall Improvement Strategies

Multi-Path Retrieval

python
from collections import defaultdict

class MultiPathRetriever:
    def __init__(self, vectorstore, bm25_index):
        self.vectorstore = vectorstore
        self.bm25_index = bm25_index
    
    def retrieve(self, query, k=10):
        vector_results = self._vector_search(query, k=k*2)
        bm25_results = self._bm25_search(query, k=k*2)
        keyword_results = self._keyword_search(query, k=k*2)
        
        merged = self._merge_results(
            [vector_results, bm25_results, keyword_results],
            weights=[0.5, 0.3, 0.2]
        )
        
        return merged[:k]
    
    def _vector_search(self, query, k):
        return self.vectorstore.similarity_search(query, k=k)
    
    def _bm25_search(self, query, k):
        return self.bm25_index.get_top_k(query, k)
    
    def _keyword_search(self, query, k):
        # _extract_keywords is an assumed helper (e.g. jieba keyword extraction or an LLM call).
        keywords = self._extract_keywords(query)
        return self.vectorstore.search(
            filter={"keywords": {"$in": keywords}},
            k=k
        )
    
    def _merge_results(self, result_lists, weights):
        # Weighted reciprocal-rank fusion: higher-ranked documents from each path score higher.
        doc_scores = defaultdict(float)
        
        for results, weight in zip(result_lists, weights):
            for rank, doc in enumerate(results):
                doc_id = doc.metadata.get("id")
                score = weight / (rank + 1)
                doc_scores[doc_id] += score
        
        sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
        
        # _get_doc_by_id is an assumed lookup from document id back to the Document object.
        return [self._get_doc_by_id(doc_id) for doc_id, _ in sorted_docs]
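
The `bm25_index` passed to `MultiPathRetriever` is expected to expose a `get_top_k(query, k)` method. A minimal wrapper is sketched below, assuming the third-party `rank_bm25` package and plain whitespace tokenization (a Chinese tokenizer such as jieba would be more appropriate for Chinese corpora):

python
# Hypothetical wrapper matching the get_top_k interface assumed above.
from rank_bm25 import BM25Okapi

class BM25Index:
    def __init__(self, documents):
        self.documents = documents
        # Tokenize each document; replace .split() with a proper tokenizer as needed.
        self.corpus_tokens = [doc.page_content.split() for doc in documents]
        self.bm25 = BM25Okapi(self.corpus_tokens)
    
    def get_top_k(self, query, k):
        # Return the k documents with the highest BM25 scores for the query.
        return self.bm25.get_top_n(query.split(), self.documents, n=k)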

Query Understanding and Rewriting

python
import json

from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

class QueryUnderstanding:
    def __init__(self):
        self.llm = OpenAI(temperature=0)
    
    def understand(self, query):
        intent = self._detect_intent(query)
        entities = self._extract_entities(query)
        expanded = self._expand_query(query)
        
        return {
            "original": query,
            "intent": intent,
            "entities": entities,
            "expanded_queries": expanded
        }
    
    def _detect_intent(self, query):
        prompt = PromptTemplate.from_template("""
分析以下查询的意图,返回最匹配的类型:
- 定义查询:询问概念定义
- 操作查询:询问如何操作
- 对比查询:询问对比分析
- 列举查询:询问列举项目
- 其他

查询:{query}

意图:
""")
        return self.llm(prompt.format(query=query)).strip()
    
    def _extract_entities(self, query):
        prompt = PromptTemplate.from_template("""
从以下查询中提取关键实体(名词、术语、专有名词):

查询:{query}

实体(JSON数组格式):
""")
        response = self.llm(prompt.format(query=query))
        # Parse the JSON array rather than calling eval() on model output.
        return json.loads(response.strip())
    
    def _expand_query(self, query):
        prompt = PromptTemplate.from_template("""
将以下查询改写为3个不同的表述,保持语义一致:

查询:{query}

改写(JSON数组格式):
""")
        response = self.llm(prompt.format(query=query))
        # Parse the JSON array rather than calling eval() on model output.
        return json.loads(response.strip())

query_understanding = QueryUnderstanding()
understood = query_understanding.understand("如何使用Python处理PDF文件?")
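
One common way to use the expanded queries is multi-query retrieval: run the original query plus each rewrite through the retriever and merge the de-duplicated results, which usually improves recall. A simple sketch:

python
# Sketch of multi-query retrieval over the understood query (retriever is any LangChain retriever).
def multi_query_retrieve(retriever, understood, k=5):
    queries = [understood["original"]] + understood["expanded_queries"]
    seen = set()
    merged = []
    
    for q in queries:
        for doc in retriever.get_relevant_documents(q):
            # De-duplicate across the different query rewrites.
            key = doc.metadata.get("id", hash(doc.page_content))
            if key not in seen:
                seen.add(key)
                merged.append(doc)
    
    return merged[:k]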

Generation Enhancement Techniques

Context Organization

python
from langchain.schema import Document

class ContextOrganizer:
    def __init__(self, max_tokens=4000):
        self.max_tokens = max_tokens
    
    def organize(self, query, documents):
        documents = self._deduplicate(documents)
        documents = self._sort_by_relevance(query, documents)
        documents = self._truncate(documents)
        
        context = self._format_context(documents)
        
        return context
    
    def _deduplicate(self, documents):
        seen = set()
        unique = []
        
        for doc in documents:
            doc_id = doc.metadata.get("id", hash(doc.page_content))
            if doc_id not in seen:
                seen.add(doc_id)
                unique.append(doc)
        
        return unique
    
    def _sort_by_relevance(self, query, documents):
        query_embedding = get_embedding(query)
        
        scored_docs = []
        for doc in documents:
            doc_embedding = get_embedding(doc.page_content[:500])
            similarity = cosine_similarity(query_embedding, doc_embedding)
            scored_docs.append((doc, similarity))
        
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        
        return [doc for doc, _ in scored_docs]
    
    def _truncate(self, documents):
        total_tokens = 0
        truncated = []
        
        for doc in documents:
            doc_tokens = len(doc.page_content) // 4  # rough character-to-token estimate
            
            if total_tokens + doc_tokens > self.max_tokens:
                remaining = self.max_tokens - total_tokens
                if remaining > 100:
                    truncated_doc = Document(
                        page_content=doc.page_content[:remaining*4],
                        metadata=doc.metadata
                    )
                    truncated.append(truncated_doc)
                break
            
            truncated.append(doc)
            total_tokens += doc_tokens
        
        return truncated
    
    def _format_context(self, documents):
        context_parts = []
        
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            context_parts.append(f"[文档{i}] 来源:{source}\n{doc.page_content}\n")
        
        return "\n".join(context_parts)

organizer = ContextOrganizer(max_tokens=4000)
context = organizer.organize(query, retrieved_docs)
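
The organized context is then dropped into the generation prompt. A minimal sketch of that final step, assuming `llm` is an LLM callable as in the earlier examples:

python
# Illustrative final prompt assembly over the organized context.
from langchain.prompts import PromptTemplate

rag_prompt = PromptTemplate.from_template("""请仅根据以下参考资料回答问题,如果资料中没有答案,请直接说明。

参考资料:
{context}

问题:{question}

回答:""")

answer = llm(rag_prompt.format(context=context, question=query))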

Citation Tracking

python
class CitationTracker:
    def __init__(self):
        self.citations = []
    
    def track(self, documents):
        self.citations = []
        
        for i, doc in enumerate(documents, 1):
            citation = {
                "id": i,
                "source": doc.metadata.get("source", "未知"),
                "page": doc.metadata.get("page", "N/A"),
                "chunk_id": doc.metadata.get("chunk_id", "N/A"),
                "content_preview": doc.page_content[:100] + "..."
            }
            self.citations.append(citation)
        
        return self.citations
    
    def format_citations(self):
        formatted = "\n\n参考文献:\n"
        
        for citation in self.citations:
            formatted += f"[{citation['id']}] {citation['source']}"
            if citation['page'] != "N/A":
                formatted += f", 第{citation['page']}页"
            formatted += "\n"
        
        return formatted
    
    def add_to_response(self, response):
        return response + self.format_citations()

citation_tracker = CitationTracker()
citations = citation_tracker.track(retrieved_docs)
final_response = citation_tracker.add_to_response(llm_response)

Error Handling and Graceful Degradation

Handling Retrieval Failures

python
from langchain.schema import Document

class RobustRetriever:
    def __init__(self, primary_retriever, fallback_retriever=None):
        self.primary_retriever = primary_retriever
        self.fallback_retriever = fallback_retriever
    
    def retrieve(self, query, k=5):
        try:
            results = self.primary_retriever.get_relevant_documents(query, k=k)
            
            if not results:
                return self._handle_empty_results(query, k)
            
            return results
        
        except Exception as e:
            return self._handle_error(query, k, str(e))
    
    def _handle_empty_results(self, query, k):
        if self.fallback_retriever:
            return self.fallback_retriever.get_relevant_documents(query, k=k)
        
        return [Document(
            page_content="抱歉,未找到相关信息。请尝试更换查询方式。",
            metadata={"type": "fallback"}
        )]
    
    def _handle_error(self, query, k, error_message):
        print(f"检索错误: {error_message}")
        
        if self.fallback_retriever:
            try:
                return self.fallback_retriever.get_relevant_documents(query, k=k)
            except Exception:
                pass
        
        return [Document(
            page_content="检索服务暂时不可用,请稍后重试。",
            metadata={"type": "error", "error": error_message}
        )]

robust_retriever = RobustRetriever(
    primary_retriever=vector_retriever,
    fallback_retriever=bm25_retriever
)

Response Quality Checks

python
class ResponseQualityChecker:
    def __init__(self):
        self.min_length = 20
        self.max_length = 2000
    
    def check(self, response, context):
        issues = []
        
        if len(response) < self.min_length:
            issues.append("response_too_short")
        
        if len(response) > self.max_length:
            issues.append("response_too_long")
        
        if not self._has_context_support(response, context):
            issues.append("no_context_support")
        
        if self._has_hallucination_indicators(response):
            issues.append("potential_hallucination")
        
        return {
            "is_valid": len(issues) == 0,
            "issues": issues
        }
    
    def _has_context_support(self, response, context):
        context_keywords = set(context.split())
        response_keywords = set(response.split())
        
        overlap = context_keywords & response_keywords
        
        return len(overlap) > 5
    
    def _has_hallucination_indicators(self, response):
        # Simple heuristic: hedging phrases often accompany unsupported claims.
        indicators = [
            "我不确定",
            "可能",
            "猜测",
            "应该"
        ]
        
        return any(indicator in response for indicator in indicators)

quality_checker = ResponseQualityChecker()
quality = quality_checker.check(response, context)
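
The checker is most useful when wired into the generation loop, so a flagged response is regenerated before being returned. A rough sketch, assuming a hypothetical `generate_response(query, context)` helper:

python
# Sketch of regenerate-on-failure; generate_response is an assumed generation helper.
def generate_with_quality_check(query, context, max_retries=2):
    response = ""
    for attempt in range(max_retries + 1):
        response = generate_response(query, context)
        result = quality_checker.check(response, context)
        if result["is_valid"]:
            return response
    # Retries exhausted; return the last response (quality issues can be logged here if needed).
    return response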

Monitoring and Logging

A Complete Monitoring System

python
import logging
from datetime import datetime
import json

class RAGMonitor:
    def __init__(self, log_file="rag_monitor.log"):
        self.logger = self._setup_logger(log_file)
        self.metrics = {
            "queries": 0,
            "total_latency": 0,
            "errors": 0,
            "avg_retrieval_time": 0,
            "avg_generation_time": 0
        }
    
    def _setup_logger(self, log_file):
        logger = logging.getLogger("RAGMonitor")
        logger.setLevel(logging.INFO)
        
        handler = logging.FileHandler(log_file, encoding='utf-8')
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s'
        ))
        logger.addHandler(handler)
        
        return logger
    
    def log_query(self, query, retrieval_time, generation_time, result_count):
        self.metrics["queries"] += 1
        self.metrics["total_latency"] += retrieval_time + generation_time
        
        self.logger.info(json.dumps({
            "event": "query",
            "query": query[:100],
            "retrieval_time": retrieval_time,
            "generation_time": generation_time,
            "result_count": result_count,
            "timestamp": datetime.now().isoformat()
        }))
    
    def log_error(self, query, error_type, error_message):
        self.metrics["errors"] += 1
        
        self.logger.error(json.dumps({
            "event": "error",
            "query": query[:100],
            "error_type": error_type,
            "error_message": error_message,
            "timestamp": datetime.now().isoformat()
        }))
    
    def get_stats(self):
        if self.metrics["queries"] == 0:
            return self.metrics
        
        return {
            **self.metrics,
            "avg_latency": self.metrics["total_latency"] / self.metrics["queries"],
            "error_rate": self.metrics["errors"] / self.metrics["queries"]
        }

monitor = RAGMonitor()

def monitored_rag_query(query, retriever, generator):
    import time
    
    try:
        start_time = time.time()
        docs = retriever.get_relevant_documents(query)
        retrieval_time = time.time() - start_time
        
        start_time = time.time()
        response = generator.generate(query, docs)
        generation_time = time.time() - start_time
        
        monitor.log_query(query, retrieval_time, generation_time, len(docs))
        
        return response
    
    except Exception as e:
        monitor.log_error(query, type(e).__name__, str(e))
        raise

Summary

Building a production-grade RAG system means refining every stage: document preprocessing safeguards data quality, index optimization speeds up retrieval, recall enhancement keeps results comprehensive, and generation enhancement improves answer quality. On top of that, robust error handling and monitoring are what keep the system running reliably.

The RAG stack is evolving rapidly, so keep tracking new techniques and best practices and keep iterating on your system. Applying these techniques flexibly, grounded in your actual business scenario, is what produces genuinely valuable RAG applications.