Skip to content

记忆系统设计

概述

记忆系统是Agent实现持续交互和知识积累的核心组件。通过记忆系统，Agent可以记住历史对话、用户偏好、任务状态等信息，从而提供更连贯、更个性化的服务。一个设计良好的记忆系统需要在存储效率、检索速度和结果相关性之间取得平衡。

记忆类型

1. 短期记忆(Short-term Memory)

存储当前会话的上下文信息,通常以对话历史的形式存在。

python
from typing import List, Dict
from dataclasses import dataclass
from datetime import datetime

@dataclass
class Message:
    """A single conversation turn: who spoke, what was said, and when.

    When ``timestamp`` is omitted (or explicitly ``None``) it is filled in
    with the current local time at construction.
    """
    role: str
    content: str
    timestamp: datetime = None  # replaced with datetime.now() in __post_init__ when None

    def __post_init__(self):
        # Stamp lazily so every instance gets its own creation time; a
        # class-level ``datetime.now()`` default would be evaluated only once.
        if self.timestamp is not None:
            return
        self.timestamp = datetime.now()

class ShortTermMemory:
    """Bounded, chronologically ordered buffer of the current session's messages.

    Args:
        max_messages: maximum number of messages retained; the oldest are
            discarded first.
    """

    def __init__(self, max_messages: int = 20):
        self.messages: List[Message] = []
        self.max_messages = max_messages

    def add(self, role: str, content: str):
        """Append a message and drop the oldest ones beyond ``max_messages``."""
        self.messages.append(Message(role=role, content=content))

        # Compute the overflow explicitly: the slice ``[-max_messages:]`` turns
        # into ``[-0:]`` (the whole list) when max_messages == 0.
        excess = len(self.messages) - max(self.max_messages, 0)
        if excess > 0:
            del self.messages[:excess]

    def get_context(self, max_tokens: int = 4000) -> List[Dict]:
        """Return the most recent messages that fit within ``max_tokens``.

        Messages are taken newest-first until the next one would exceed the
        budget. They are returned oldest-first as ``{"role", "content"}``
        dicts. Selection stops at the first message that does not fit, so the
        result is always a contiguous suffix of the history.
        """
        selected = []
        total_tokens = 0

        for message in reversed(self.messages):
            tokens = self.estimate_tokens(message.content)
            if total_tokens + tokens > max_tokens:
                break
            selected.append({"role": message.role, "content": message.content})
            total_tokens += tokens

        # Collected newest-first; one reverse() avoids the O(n^2) cost of
        # repeated insert(0, ...).
        selected.reverse()
        return selected

    def clear(self):
        """Forget all messages."""
        self.messages = []

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate how many tokens ``text`` occupies.

        ASCII characters are counted at ~4 per token, rounded up, so any
        non-empty ASCII text costs at least one token. Non-ASCII characters
        (e.g. Chinese) are counted as one token each, because the 4-chars rule
        badly underestimates CJK text. This is a heuristic, not a tokenizer.
        """
        ascii_chars = sum(1 for ch in text if ord(ch) < 128)
        other_chars = len(text) - ascii_chars
        return other_chars + -(-ascii_chars // 4)  # ceil(ascii_chars / 4)

2. 长期记忆(Long-term Memory)

持久化存储重要信息，支持跨会话检索。

python
import json
from pathlib import Path
from typing import Any, Optional

class LongTermMemory:
    """File-backed key/value memory that persists across sessions.

    Each entry is written to ``<storage_path>/<key>.json``. An ``index.json``
    file in the same directory maps each key to its file and metadata.
    """

    def __init__(self, storage_path: str = "./memory"):
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.index_file = self.storage_path / "index.json"
        self.index = self._load_index()

    def _load_index(self) -> Dict:
        """Load the key index from disk, or return an empty one if it is absent."""
        if self.index_file.exists():
            with open(self.index_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return {}

    def _save_index(self):
        """Write the in-memory index back to ``index.json``."""
        with open(self.index_file, 'w', encoding='utf-8') as f:
            json.dump(self.index, f, ensure_ascii=False, indent=2)

    def _key_path(self, key: str) -> Path:
        """Return the entry file for ``key``.

        Raises:
            ValueError: if the key (used verbatim as a file name) would land
                outside ``storage_path`` (separators, ``..``) or on the index
                file itself (key ``"index"``). The latter would be overwritten
                by ``_save_index`` right after being written.
        """
        file_path = self.storage_path / f"{key}.json"
        resolved = file_path.resolve()
        if (resolved.parent != self.storage_path.resolve()
                or resolved == self.index_file.resolve()):
            raise ValueError(f"invalid memory key: {key!r}")
        return file_path

    def store(self, key: str, value: Any, metadata: Dict = None):
        """Persist ``value`` under ``key``, overwriting any previous entry.

        Raises:
            ValueError: if ``key`` cannot be mapped to a file inside the store.
        """
        file_path = self._key_path(key)
        # Normalize once so the entry file and the index agree (never None).
        metadata = metadata or {}
        memory_item = {
            "value": value,
            "metadata": metadata,
            "timestamp": datetime.now().isoformat()
        }

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(memory_item, f, ensure_ascii=False, indent=2)

        self.index[key] = {
            "file": str(file_path),
            "metadata": metadata
        }
        self._save_index()

    def retrieve(self, key: str) -> Optional[Any]:
        """Return the stored value for ``key``, or None if unknown or its file is missing."""
        if key not in self.index:
            return None

        file_path = Path(self.index[key]["file"])
        if not file_path.exists():
            return None

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        return data["value"]

    def search(self, query: str) -> List[Dict]:
        """Return entries whose key contains ``query`` (case-insensitive)."""
        needle = query.lower()
        results = []

        for key, info in self.index.items():
            if needle in key.lower():
                results.append({
                    "key": key,
                    "value": self.retrieve(key),
                    # Index entries written by older versions may hold None.
                    "metadata": info.get("metadata") or {}
                })

        return results

    def delete(self, key: str):
        """Remove ``key`` and its entry file; unknown keys are ignored."""
        if key in self.index:
            file_path = Path(self.index[key]["file"])
            if file_path.exists():
                file_path.unlink()

            del self.index[key]
            self._save_index()

3. 向量记忆(Vector Memory)

基于向量相似度的语义记忆检索。

python
import numpy as np
from typing import List, Tuple
from dataclasses import dataclass

@dataclass
class VectorMemoryItem:
    """One entry of VectorMemory: the stored text, its embedding and metadata."""
    id: str  # caller-supplied identifier; VectorMemory.delete removes items by it
    content: str  # the original text that was embedded
    embedding: np.ndarray  # whatever embedding_model.embed(content) returned
    metadata: Dict  # free-form metadata; VectorMemory.add passes {} when none is given

class VectorMemory:
    """In-memory semantic store ranked by cosine similarity.

    Args:
        embedding_model: object exposing ``embed(text)``. Its results are
            stacked row-wise, so presumably a 1-D vector of the same length
            for every text (TODO confirm against the model wrapper).
        dimension: expected embedding size; stored only, not validated here.
    """

    def __init__(self, embedding_model, dimension: int = 1536):
        self.embedding_model = embedding_model
        self.dimension = dimension
        self.items: List[VectorMemoryItem] = []
        # Row i holds items[i].embedding; None while the store is empty.
        self.embeddings_matrix = None

    def add(self, id: str, content: str, metadata: Dict = None):
        """Embed ``content`` and append it as a new item (ids are not deduplicated)."""
        embedding = self.embedding_model.embed(content)

        item = VectorMemoryItem(
            id=id,
            content=content,
            embedding=embedding,
            metadata=metadata or {}
        )

        self.items.append(item)
        self._update_embeddings_matrix()

    def _update_embeddings_matrix(self):
        """Rebuild the stacked embedding matrix from ``items``.

        Reset it to None when no items remain, so an emptied store does not
        keep a stale matrix of deleted embeddings.
        """
        if self.items:
            self.embeddings_matrix = np.vstack([item.embedding for item in self.items])
        else:
            self.embeddings_matrix = None

    def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float, Dict]]:
        """Return up to ``top_k`` ``(content, similarity, metadata)`` tuples, most similar first."""
        if not self.items or top_k <= 0:
            return []

        query_embedding = self.embedding_model.embed(query)
        similarities = self._cosine_similarity(query_embedding, self.embeddings_matrix)

        top_indices = np.argsort(similarities)[::-1][:top_k]

        return [
            (self.items[idx].content, float(similarities[idx]), self.items[idx].metadata)
            for idx in top_indices
        ]

    def _cosine_similarity(self, query: np.ndarray, matrix: np.ndarray) -> np.ndarray:
        """Cosine similarity between ``query`` and every row of ``matrix``.

        Zero-length vectors have no direction, so their similarity is defined
        as 0. Dividing by a zero norm would give NaN, and argsort places NaN
        last, so after the descending reversal it would rank first.
        """
        query = np.asarray(query, dtype=float)
        matrix = np.asarray(matrix, dtype=float)

        denominators = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
        dots = matrix @ query
        zero = denominators == 0
        return np.where(zero, 0.0, dots / np.where(zero, 1.0, denominators))

    def delete(self, id: str):
        """Remove every item whose id equals ``id``."""
        self.items = [item for item in self.items if item.id != id]
        self._update_embeddings_matrix()

记忆架构设计

分层记忆架构

python
class HierarchicalMemory:
    """Facade that routes content to working/episodic/semantic/procedural stores."""

    def __init__(self, llm, embedding_model):
        self.working_memory = WorkingMemory()
        self.episodic_memory = EpisodicMemory(llm)
        self.semantic_memory = SemanticMemory(embedding_model)
        self.procedural_memory = ProceduralMemory()

    def remember(self, content: str, memory_type: str = "auto",
                 steps: Optional[List[str]] = None):
        """Store ``content`` in the memory selected by ``memory_type``.

        Args:
            content: text to remember. For ``"procedural"`` it is the
                procedure name.
            memory_type: ``"working"``, ``"episodic"``, ``"semantic"``,
                ``"procedural"``, or ``"auto"`` (decided by _classify_memory).
            steps: steps of a procedural memory; defaults to ``[content]``.

        Raises:
            ValueError: if the (possibly auto-classified) type is unknown,
                instead of silently discarding the content.
        """
        if memory_type == "auto":
            memory_type = self._classify_memory(content)

        if memory_type == "procedural":
            # ProceduralMemory.add takes (name, steps), not a single string.
            self.procedural_memory.add(content, steps if steps is not None else [content])
            return

        targets = {
            "working": self.working_memory,
            "episodic": self.episodic_memory,
            "semantic": self.semantic_memory,
        }
        if memory_type not in targets:
            raise ValueError(f"unknown memory type: {memory_type!r}")
        targets[memory_type].add(content)

    def recall(self, query: str, memory_types: List[str] = None):
        """Search the selected memories and merge the hits by descending relevance.

        ``memory_types`` defaults to working, episodic and semantic memory.
        """
        memory_types = memory_types or ["working", "episodic", "semantic"]

        results = []

        if "working" in memory_types:
            results.extend(self.working_memory.search(query))

        if "episodic" in memory_types:
            results.extend(self.episodic_memory.search(query))

        if "semantic" in memory_types:
            results.extend(self.semantic_memory.search(query))

        return self._merge_results(results)

    def _classify_memory(self, content: str) -> str:
        """Choose a memory type for ``content`` when ``remember`` gets ``"auto"``.

        Placeholder that always returns ``"working"``. Returning None here
        would make ``remember`` reject the content. Override with a real
        (e.g. LLM-based) classifier.
        """
        return "working"

    def _merge_results(self, results: List) -> List:
        """Sort result dicts by their ``"relevance"`` key, highest first."""
        return sorted(results, key=lambda x: x["relevance"], reverse=True)

class WorkingMemory:
    """Small FIFO buffer of the most recent items (capacity 7 by default).

    ``search`` ignores its query and returns every buffered item, oldest
    first, each with a fixed relevance of 1.0.
    """

    def __init__(self, capacity: int = 7):
        self.capacity = capacity
        self.items = []

    def add(self, content: str):
        """Buffer ``content``, evicting the oldest entry once capacity is exceeded."""
        entry = {"content": content, "timestamp": datetime.now()}
        self.items.append(entry)
        overflowing = len(self.items) > self.capacity
        if overflowing:
            del self.items[0]

    def search(self, query: str) -> List[Dict]:
        """Return all buffered items as ``{"content", "relevance": 1.0}`` dicts."""
        hits = []
        for entry in self.items:
            hits.append({"content": entry["content"], "relevance": 1.0})
        return hits

class EpisodicMemory:
    """Stores raw episodes plus an LLM-generated summary used for matching."""

    def __init__(self, llm):
        self.llm = llm
        self.episodes = []

    def add(self, content: str):
        """Summarize ``content`` with the LLM and record it as an episode."""
        summary = self.llm.generate(f"总结以下内容:{content}")

        self.episodes.append({
            "content": content,
            "summary": summary,
            "timestamp": datetime.now()
        })

    def search(self, query: str, threshold: float = 0.5) -> List[Dict]:
        """Return episodes whose summary relevance to ``query`` exceeds ``threshold``.

        Results are ``{"content", "relevance"}`` dicts in insertion order.
        """
        results = []

        for episode in self.episodes:
            relevance = self._calculate_relevance(query, episode["summary"])
            if relevance > threshold:
                results.append({
                    "content": episode["content"],
                    "relevance": relevance
                })

        return results

    def _calculate_relevance(self, query: str, summary: str) -> float:
        """Score in [0, 1]: share of the query's distinct characters present in ``summary``.

        A lightweight lexical placeholder. A stub returning None would make
        ``search`` raise TypeError on the threshold comparison. It works
        character by character so it also handles Chinese text without a word
        segmenter. Swap in an embedding- or LLM-based scorer for real use.
        """
        query_chars = {ch for ch in query.lower() if not ch.isspace()}
        if not query_chars:
            return 0.0
        summary_chars = set((summary or "").lower())
        return len(query_chars & summary_chars) / len(query_chars)

class SemanticMemory:
    """Fact memory backed by a VectorMemory for similarity search."""

    def __init__(self, embedding_model):
        self.vector_memory = VectorMemory(embedding_model)

    def add(self, content: str):
        """Index ``content`` under the id ``str(hash(content))``.

        NOTE: Python salts str hashes per process (PYTHONHASHSEED), so these
        ids are not stable across interpreter runs.
        """
        item_id = str(hash(content))
        self.vector_memory.add(id=item_id, content=content)

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` nearest contents as ``{"content", "relevance"}`` dicts."""
        hits = []
        for text, similarity, _metadata in self.vector_memory.search(query, top_k):
            hits.append({"content": text, "relevance": similarity})
        return hits

class ProceduralMemory:
    """Named step-by-step procedures, each with a usage counter."""

    def __init__(self):
        self.procedures = {}

    def add(self, name: str, steps: List[str]):
        """Register (or replace) procedure ``name``; its usage count restarts at 0."""
        self.procedures[name] = {"steps": steps, "usage_count": 0}

    def get(self, name: str) -> Optional[List[str]]:
        """Return the steps of ``name`` and count the access, or None if unknown."""
        entry = self.procedures.get(name)
        if entry is None:
            return None
        entry["usage_count"] += 1
        return entry["steps"]

    def get_most_used(self, top_k: int = 5) -> List[str]:
        """Return up to ``top_k`` procedure names, most used first.

        Ties keep insertion order (sorting with reverse=True stays stable).
        """
        ranked = sorted(
            self.procedures,
            key=lambda proc_name: self.procedures[proc_name]["usage_count"],
            reverse=True,
        )
        return ranked[:top_k]

记忆管理策略

1. 滑动窗口策略

python
class SlidingWindowMemory:
    """Keeps only the most recent ``window_size`` items."""

    def __init__(self, window_size: int = 10):
        self.window_size = window_size
        self.memory = []

    def add(self, item: Dict):
        """Append ``item``, evicting the oldest one when the window overflows."""
        self.memory.append(item)

        if len(self.memory) > self.window_size:
            self.memory.pop(0)

    def get_recent(self, n: int = None) -> List[Dict]:
        """Return the ``n`` most recent items, oldest first.

        ``n=None`` means the whole window; ``n <= 0`` returns an empty list.
        Both cases need explicit handling: ``n or window_size`` would treat 0
        as missing, and ``memory[-0:]`` is the entire list.
        """
        if n is None:
            n = self.window_size
        if n <= 0:
            return []
        return self.memory[-n:]

2. 摘要压缩策略

python
class SummarizingMemory:
    """Recent items kept verbatim, plus a rolling LLM summary of older ones.

    Args:
        llm: object with ``generate(prompt) -> str``.
        max_items: verbatim items allowed before the oldest half is summarized.
    """

    def __init__(self, llm, max_items: int = 20):
        self.llm = llm
        self.max_items = max_items
        self.items = []
        self.summary = ""

    def add(self, content: str):
        """Append ``content``; compress the oldest items once ``max_items`` is exceeded."""
        self.items.append(content)

        if len(self.items) > self.max_items:
            self._summarize_old_items()

    def _summarize_old_items(self):
        """Fold the oldest ``max_items // 2`` items (at least one) into ``self.summary``."""
        # Always remove at least one item. With max_items < 2, max_items // 2
        # is 0, which would call the LLM on empty input on every add() while
        # never shrinking the buffer.
        cut = max(1, self.max_items // 2)
        old_items = self.items[:cut]
        new_summary = self.llm.generate(
            "总结以下内容:\n" + "\n".join(old_items)
        )

        if self.summary:
            self.summary = self.llm.generate(
                f"合并两个摘要:\n摘要1: {self.summary}\n摘要2: {new_summary}"
            )
        else:
            self.summary = new_summary

        self.items = self.items[cut:]

    def get_context(self) -> str:
        """Render the summary (if any) followed by the verbatim recent items."""
        context_parts = []

        if self.summary:
            context_parts.append(f"历史摘要:{self.summary}")

        if self.items:
            context_parts.append("最近对话:\n" + "\n".join(self.items))

        return "\n\n".join(context_parts)

3. 重要性评分策略

python
class ImportanceBasedMemory:
    """Bounded memory that keeps the items an LLM rates as most important.

    Args:
        llm: object with ``generate(prompt) -> str``, used to score each item
            when it is added.
        max_items: capacity. On overflow the store is pruned to the top
            ``int(max_items * 0.8)`` items by importance.
    """

    def __init__(self, llm, max_items: int = 100):
        self.llm = llm
        self.max_items = max_items
        self.items = []

    def add(self, content: str):
        """Score ``content`` and store it; prune when over capacity."""
        importance = self._calculate_importance(content)

        self.items.append({
            "content": content,
            "importance": importance,
            "timestamp": datetime.now(),
            "access_count": 0
        })

        if len(self.items) > self.max_items:
            self._prune_low_importance()

    def _calculate_importance(self, content: str) -> float:
        """Ask the LLM for a 0-1 importance score; 0.5 if the reply has no number."""
        prompt = f"""
        评估以下内容的重要性(0-1分):
        内容:{content}
        
        考虑因素:
        1. 是否包含关键信息
        2. 是否是用户偏好
        3. 是否是重要决策
        4. 是否需要长期记忆
        
        只返回分数。
        """

        score = self.llm.generate(prompt)
        return self._parse_score(score, default=0.5)

    @staticmethod
    def _parse_score(reply, default: float) -> float:
        """Return the first number in ``reply`` clamped to [0, 1], or ``default``.

        LLM replies often wrap the score in text (e.g. "0.8分"), on which a
        bare ``float()`` raises ValueError.
        """
        import re

        match = re.search(r"-?(?:\d+(?:\.\d*)?|\.\d+)", str(reply))
        if match is None:
            return default
        return min(1.0, max(0.0, float(match.group())))

    def _prune_low_importance(self):
        """Keep only the most important 80% of ``max_items`` (drops chronological order)."""
        self.items.sort(key=lambda x: x["importance"], reverse=True)
        self.items = self.items[:int(self.max_items * 0.8)]

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` items ranked by importance, relevance and usage.

        score = 0.5 * importance + 0.3 * relevance + min(access_count / 10, 0.2)

        Only the returned items have their ``access_count`` incremented.
        """
        scored_items = []
        for item in self.items:
            relevance = self._calculate_relevance(query, item["content"])
            score = (
                item["importance"] * 0.5 +
                relevance * 0.3 +
                min(item["access_count"] / 10, 0.2)
            )
            scored_items.append((item, score))

        scored_items.sort(key=lambda x: x[1], reverse=True)
        top = [item for item, _ in scored_items[:top_k]]

        # Count an access only for items actually recalled. Bumping every item
        # on every search made access_count just "searches since insertion".
        for item in top:
            item["access_count"] += 1

        return top

    def _calculate_relevance(self, query: str, content: str) -> float:
        """Score in [0, 1]: share of the query's distinct characters present in ``content``.

        Lexical placeholder (a stub returning None made ``search`` raise
        TypeError); character-level so it works for Chinese text too.
        """
        query_chars = {ch for ch in query.lower() if not ch.isspace()}
        if not query_chars:
            return 0.0
        content_chars = set((content or "").lower())
        return len(query_chars & content_chars) / len(query_chars)

4. 时间衰减策略

python
from datetime import datetime, timedelta

class TimeDecayMemory:
    """Items whose weight decays geometrically with age in whole days.

    weight = initial_weight * (1 - decay_rate) ** age_days

    Args:
        decay_rate: fraction of the weight lost per elapsed day.
        half_life: stored but not used by the weighting (kept for API
            compatibility). NOTE(review): the default decay_rate=0.1 gives an
            effective half-life of about 6.6 days, not 7.
    """

    def __init__(self, decay_rate: float = 0.1, half_life: int = 7):
        self.decay_rate = decay_rate
        self.half_life = half_life
        self.items = []

    def add(self, content: str, initial_weight: float = 1.0):
        """Record ``content`` with ``initial_weight``, timestamped now."""
        self.items.append({
            "content": content,
            "timestamp": datetime.now(),
            "initial_weight": initial_weight
        })

    def get_weight(self, item: Dict, now: Optional[datetime] = None) -> float:
        """Return the decayed weight of ``item`` as of ``now`` (default: datetime.now()).

        Age is counted in whole elapsed days and clamped at 0. Otherwise an
        item timestamped in the future (e.g. after a clock change) would get a
        weight above its initial weight.
        """
        if now is None:
            now = datetime.now()
        age_days = max(0, (now - item["timestamp"]).days)
        return item["initial_weight"] * (1 - self.decay_rate) ** age_days

    def get_active_items(self, threshold: float = 0.3) -> List[Dict]:
        """Return items with weight >= ``threshold``, heaviest first, with ``current_weight`` added."""
        now = datetime.now()  # one reference time for every item in this pass
        active_items = []

        for item in self.items:
            weight = self.get_weight(item, now)
            if weight >= threshold:
                active_items.append({
                    **item,
                    "current_weight": weight
                })

        return sorted(active_items, key=lambda x: x["current_weight"], reverse=True)

    def cleanup(self, threshold: float = 0.1):
        """Drop items whose current weight has fallen below ``threshold``."""
        now = datetime.now()
        self.items = [
            item for item in self.items
            if self.get_weight(item, now) >= threshold
        ]

记忆检索优化

1. 混合检索

python
class HybridRetrieval:
    """Vector search whose hits are boosted when they also match query keywords."""

    def __init__(self, embedding_model, llm):
        self.vector_memory = VectorMemory(embedding_model)
        self.keyword_index = {}
        self.llm = llm
        # id -> content, so keyword hits (ids) can be matched against vector
        # results, which carry content rather than ids.
        self._id_to_content = {}

    def add(self, id: str, content: str, metadata: Dict = None):
        """Index ``content`` both by embedding and by LLM-extracted keywords."""
        self.vector_memory.add(id, content, metadata)
        self._id_to_content[id] = content

        for keyword in self._extract_keywords(content):
            ids = self.keyword_index.setdefault(keyword, [])
            if id not in ids:
                ids.append(id)

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return up to ``top_k`` ``{"content", "score", "metadata"}`` dicts, best first."""
        vector_results = self.vector_memory.search(query, top_k * 2)

        keyword_results = set()
        for keyword in self._extract_keywords(query):
            keyword_results.update(self.keyword_index.get(keyword, ()))

        return self._merge_results(vector_results, keyword_results, top_k)

    def _extract_keywords(self, text: str) -> List[str]:
        """Ask the LLM for comma-separated keywords and return them as a list."""
        prompt = f"提取以下文本的关键词(用逗号分隔):{text}"
        keywords_str = self.llm.generate(prompt)
        # Chinese replies often use full-width "，" or "、"; normalize them and
        # drop empty fragments (e.g. from a trailing comma).
        normalized = keywords_str.replace("，", ",").replace("、", ",")
        return [k.strip() for k in normalized.split(",") if k.strip()]

    def _merge_results(self, vector_results, keyword_ids, top_k):
        """Multiply the score of vector hits whose item matched a keyword by 1.2.

        ``keyword_ids`` are caller-supplied ids, so they are mapped back to
        contents first. Comparing them with ``hash(content)`` never matched.
        """
        keyword_contents = {
            self._id_to_content[i] for i in keyword_ids if i in self._id_to_content
        }

        results = []
        for content, score, metadata in vector_results:
            boost = 1.2 if content in keyword_contents else 1.0
            results.append({
                "content": content,
                "score": score * boost,
                "metadata": metadata
            })

        return sorted(results, key=lambda x: x["score"], reverse=True)[:top_k]

2. 重排序

python
class Reranker:
    """Re-orders retrieval results by an LLM-judged relevance score."""

    def __init__(self, llm):
        self.llm = llm

    def rerank(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` results ordered by LLM relevance to ``query``.

        Every result is scored (one LLM call each), so short lists also come
        back in relevance order rather than their original order. Lists with
        at most one element are returned unchanged.
        """
        if len(results) <= 1:
            return results

        scores = [(result, self._calculate_relevance(query, result["content"]))
                  for result in results]
        scores.sort(key=lambda x: x[1], reverse=True)

        return [result for result, _ in scores[:top_k]]

    def _calculate_relevance(self, query: str, content: str) -> float:
        """Ask the LLM for a 0-1 relevance score; 0.0 if the reply has no number."""
        prompt = f"""
        评估查询和内容的相关性(0-1分):
        查询:{query}
        内容:{content}
        
        只返回分数。
        """

        score = self.llm.generate(prompt)
        return self._parse_score(score, default=0.0)

    @staticmethod
    def _parse_score(reply, default: float) -> float:
        """Return the first number in ``reply`` clamped to [0, 1], or ``default``.

        A bare ``float()`` raises ValueError on replies such as "相关性: 0.9".
        """
        import re

        match = re.search(r"-?(?:\d+(?:\.\d*)?|\.\d+)", str(reply))
        if match is None:
            return default
        return min(1.0, max(0.0, float(match.group())))

3. 查询扩展

python
class QueryExpander:
    """Broadens a query with LLM-suggested related search terms."""

    def __init__(self, llm):
        self.llm = llm

    def expand(self, query: str, num_expansions: int = 3) -> List[str]:
        """Return ``[query]`` followed by up to ``num_expansions`` distinct related terms.

        Blank lines, duplicates and echoes of the original query in the LLM
        reply are skipped, so each returned term costs one useful search.
        """
        prompt = f"""
        为以下查询生成{num_expansions}个相关的搜索词:
        查询:{query}
        
        每行一个搜索词。
        """

        expansions = []
        for line in self.llm.generate(prompt).strip().split("\n"):
            term = line.strip()
            if term and term != query and term not in expansions:
                expansions.append(term)

        return [query] + expansions[:num_expansions]

    def search_with_expansion(self, query: str, memory, top_k: int = 5) -> List[Dict]:
        """Search ``memory`` with every expanded query; return the first ``top_k`` unique results.

        ``memory.search(q, top_k)`` must return dicts with a ``"content"`` key.
        Results keep query order (original query first).
        """
        all_results = []
        for q in self.expand(query):
            all_results.extend(memory.search(q, top_k))

        return self._deduplicate(all_results)[:top_k]

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Drop results whose content was already seen, keeping first occurrences.

        Compares contents directly rather than their ``hash()``, so distinct
        contents with colliding hashes are not merged.
        """
        seen = set()
        unique = []

        for result in results:
            content = result["content"]
            if content not in seen:
                seen.add(content)
                unique.append(result)

        return unique

完整记忆系统实现

python
class AgentMemorySystem:
    """Facade that combines the memory stores and retrieval helpers of an agent.

    Every message is written to the short-term buffer and the summarizer.
    Messages containing one of a few trigger words are also written to the
    importance-ranked store and the vector store. ``long_term`` and
    ``retriever`` are constructed but not used by the methods of this class.
    """

    def __init__(self, llm, embedding_model):
        # Stores (construction order preserved; LongTermMemory creates its directory).
        self.short_term = ShortTermMemory(max_messages=20)
        self.long_term = LongTermMemory()
        self.vector_memory = VectorMemory(embedding_model)
        self.summarizer = SummarizingMemory(llm)
        self.importance_memory = ImportanceBasedMemory(llm)

        # Retrieval helpers.
        self.retriever = HybridRetrieval(embedding_model, llm)
        self.reranker = Reranker(llm)
        self.query_expander = QueryExpander(llm)

    def add_message(self, role: str, content: str):
        """Record one conversation turn; important turns are also indexed for recall."""
        self.short_term.add(role, content)
        self.summarizer.add(f"{role}: {content}")

        if not self._is_important(content):
            return
        self.importance_memory.add(content)
        self.vector_memory.add(id=str(hash(content)), content=content)

    def recall(self, query: str, top_k: int = 5) -> List[Dict]:
        """Gather vector hits for every expanded query plus importance hits, then rerank.

        Each candidate is a ``{"content", "score", "source"}`` dict, where
        ``source`` is ``"vector"`` or ``"importance"``.
        """
        candidates = []

        for expanded in self.query_expander.expand(query):
            for text, similarity, _meta in self.vector_memory.search(expanded, top_k):
                candidates.append({"content": text, "score": similarity, "source": "vector"})

        for hit in self.importance_memory.search(query, top_k):
            candidates.append({"content": hit["content"], "score": hit["importance"], "source": "importance"})

        return self.reranker.rerank(query, self._deduplicate(candidates), top_k)

    def get_context_for_query(self, query: str, max_tokens: int = 2000) -> str:
        """Build a prompt context: recent dialogue (half the token budget) plus top-3 memories.

        The summarizer's rolling summary is not included here.
        """
        recent = self.short_term.get_context(max_tokens // 2)
        memories = self.recall(query, top_k=3)

        sections = []
        if recent:
            sections.append("最近对话:\n" + self._format_context(recent))
        if memories:
            sections.append("相关记忆:\n" + "\n".join(m["content"] for m in memories))

        return "\n\n".join(sections)

    def _is_important(self, content: str) -> bool:
        """True when ``content`` contains any of the fixed trigger keywords."""
        triggers = ("重要", "记住", "偏好", "设置", "决定")
        for trigger in triggers:
            if trigger in content:
                return True
        return False

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Keep the first result for each distinct ``hash(content)``."""
        seen = set()
        unique = []
        for result in results:
            fingerprint = hash(result["content"])
            if fingerprint in seen:
                continue
            seen.add(fingerprint)
            unique.append(result)
        return unique

    def _format_context(self, context: List[Dict]) -> str:
        """Render ``{"role", "content"}`` dicts as ``role: content`` lines."""
        lines = (f"{msg['role']}: {msg['content']}" for msg in context)
        return "\n".join(lines)

小结

记忆系统是Agent实现智能交互的基础，设计良好的记忆系统需要：

  1. 分层设计 - 短期记忆、长期记忆、向量记忆各司其职
  2. 智能管理 - 滑动窗口、摘要压缩、重要性评分、时间衰减
  3. 高效检索 - 混合检索、重排序、查询扩展
  4. 平衡性能 - 在存储效率和检索速度之间找到平衡
  5. 持续优化 - 根据使用情况动态调整记忆策略

下一章我们将探讨多Agent协作，学习如何让多个Agent协同工作完成复杂任务。