记忆系统设计
概述
记忆系统是Agent实现持续交互和知识积累的核心组件。通过记忆系统,Agent可以记住历史对话、用户偏好、任务状态等信息,从而提供更连贯、个性化的服务。一个设计良好的记忆系统需要平衡存储效率、检索速度和相关性。
记忆类型
1. 短期记忆(Short-term Memory)
存储当前会话的上下文信息,通常以对话历史的形式存在。
python
from typing import List, Dict
from dataclasses import dataclass
from datetime import datetime
@dataclass
class Message:
    """A single chat message with an automatic creation timestamp."""

    role: str
    content: str
    # None means "not supplied"; __post_init__ fills in the current time.
    timestamp: datetime = None

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = datetime.now()
class ShortTermMemory:
    """Rolling buffer of recent conversation messages (session context)."""

    def __init__(self, max_messages: int = 20):
        # Quoted forward reference: Message is defined just above in the file.
        self.messages: List["Message"] = []
        self.max_messages = max_messages

    def add(self, role: str, content: str):
        """Append a message, evicting the oldest once max_messages is exceeded."""
        self.messages.append(Message(role=role, content=content))
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages:]

    def get_context(self, max_tokens: int = 4000) -> List[Dict]:
        """Return the newest messages (oldest-first) that fit a token budget."""
        context = []
        total_tokens = 0
        # Walk from newest to oldest so the most recent turns survive the cut.
        for message in reversed(self.messages):
            tokens = self.estimate_tokens(message.content)
            if total_tokens + tokens > max_tokens:
                break
            context.insert(0, {
                "role": message.role,
                "content": message.content
            })
            total_tokens += tokens
        return context

    def clear(self):
        """Drop the whole conversation history."""
        self.messages = []

    def estimate_tokens(self, text: str) -> int:
        # Rough heuristic: ~4 characters per token.
        return len(text) // 4
2. 长期记忆(Long-term Memory)
持久化存储重要信息,支持跨会话检索。
python
import json
from pathlib import Path
from typing import Any, Optional
class LongTermMemory:
    """Key-value memory persisted as one JSON file per entry plus an index file."""

    def __init__(self, storage_path: str = "./memory"):
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.index_file = self.storage_path / "index.json"
        self.index = self._load_index()

    def _load_index(self) -> Dict:
        """Load the on-disk index, or start empty if it does not exist yet."""
        if self.index_file.exists():
            with open(self.index_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return {}

    def _save_index(self):
        with open(self.index_file, 'w', encoding='utf-8') as f:
            json.dump(self.index, f, ensure_ascii=False, indent=2)

    def store(self, key: str, value: Any, metadata: Dict = None):
        """Write a memory item to its own JSON file and register it in the index."""
        memory_item = {
            "value": value,
            "metadata": metadata or {},
            "timestamp": datetime.now().isoformat()
        }
        file_path = self.storage_path / f"{key}.json"
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(memory_item, f, ensure_ascii=False, indent=2)
        # Normalize to {} so search() never hands out None metadata
        # (the original stored the raw argument, which could be None).
        self.index[key] = {
            "file": str(file_path),
            "metadata": metadata or {}
        }
        self._save_index()

    def retrieve(self, key: str) -> Optional[Any]:
        """Return the stored value, or None if the key or its file is missing."""
        if key not in self.index:
            return None
        file_path = Path(self.index[key]["file"])
        if not file_path.exists():
            return None
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data["value"]

    def search(self, query: str) -> List[Dict]:
        """Case-insensitive substring match over stored keys."""
        results = []
        for key, info in self.index.items():
            if query.lower() in key.lower():
                results.append({
                    "key": key,
                    "value": self.retrieve(key),
                    "metadata": info.get("metadata", {})
                })
        return results

    def delete(self, key: str):
        """Remove the entry's file and index record; no-op for unknown keys."""
        if key in self.index:
            file_path = Path(self.index[key]["file"])
            if file_path.exists():
                file_path.unlink()
            del self.index[key]
            self._save_index()
3. 向量记忆(Vector Memory)
基于向量相似度的语义记忆检索。
python
import numpy as np
from typing import List, Tuple
from dataclasses import dataclass
@dataclass
class VectorMemoryItem:
    """One embedded memory entry: text, its embedding vector, and metadata."""

    id: str
    content: str
    embedding: np.ndarray
    metadata: Dict
class VectorMemory:
    """In-memory semantic store searched by cosine similarity."""

    def __init__(self, embedding_model, dimension: int = 1536):
        # embedding_model must expose .embed(text) -> np.ndarray — TODO confirm.
        self.embedding_model = embedding_model
        self.dimension = dimension
        # Quoted forward reference: VectorMemoryItem is defined just above.
        self.items: List["VectorMemoryItem"] = []
        self.embeddings_matrix = None  # (n_items, dim) cache, rebuilt on change

    def add(self, id: str, content: str, metadata: Dict = None):
        """Embed content and append it to the store."""
        embedding = self.embedding_model.embed(content)
        self.items.append(VectorMemoryItem(
            id=id,
            content=content,
            embedding=embedding,
            metadata=metadata or {}
        ))
        self._update_embeddings_matrix()

    def _update_embeddings_matrix(self):
        # Reset to None when empty so a stale matrix is never reused
        # (the original left the old matrix behind after the last delete).
        if self.items:
            self.embeddings_matrix = np.vstack([
                item.embedding for item in self.items
            ])
        else:
            self.embeddings_matrix = None

    def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float, Dict]]:
        """Return up to top_k (content, similarity, metadata) tuples, best first."""
        if not self.items:
            return []
        query_embedding = self.embedding_model.embed(query)
        similarities = self._cosine_similarity(query_embedding, self.embeddings_matrix)
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [
            (self.items[idx].content, float(similarities[idx]), self.items[idx].metadata)
            for idx in top_indices
        ]

    def _cosine_similarity(self, query: np.ndarray, matrix: np.ndarray) -> np.ndarray:
        """Cosine similarity of query against every row of matrix."""
        query_norm = query / np.linalg.norm(query)
        matrix_norm = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
        return np.dot(matrix_norm, query_norm)

    def delete(self, id: str):
        """Remove every item with the given id and rebuild the cache."""
        self.items = [item for item in self.items if item.id != id]
        self._update_embeddings_matrix()
记忆架构设计
分层记忆架构
python
class HierarchicalMemory:
    """Routes stores/recalls across working, episodic, semantic and procedural layers."""

    def __init__(self, llm, embedding_model):
        self.working_memory = WorkingMemory()
        self.episodic_memory = EpisodicMemory(llm)
        self.semantic_memory = SemanticMemory(embedding_model)
        self.procedural_memory = ProceduralMemory()

    def remember(self, content: str, memory_type: str = "auto"):
        """Store content in the given layer; "auto" picks one via _classify_memory."""
        if memory_type == "auto":
            memory_type = self._classify_memory(content)
        if memory_type == "working":
            self.working_memory.add(content)
        elif memory_type == "episodic":
            self.episodic_memory.add(content)
        elif memory_type == "semantic":
            self.semantic_memory.add(content)
        elif memory_type == "procedural":
            # NOTE(review): ProceduralMemory.add expects (name, steps); passing a
            # bare string matches the original code but likely needs adapting.
            self.procedural_memory.add(content)

    def recall(self, query: str, memory_types: List[str] = None):
        """Query the selected layers and merge their results by relevance."""
        memory_types = memory_types or ["working", "episodic", "semantic"]
        results = []
        if "working" in memory_types:
            results.extend(self.working_memory.search(query))
        if "episodic" in memory_types:
            results.extend(self.episodic_memory.search(query))
        if "semantic" in memory_types:
            results.extend(self.semantic_memory.search(query))
        return self._merge_results(results)

    def _classify_memory(self, content: str) -> str:
        # Placeholder classifier.  Defaulting to "semantic" keeps "auto" writes
        # from being silently dropped (the original returned None, so remember()
        # matched no branch and discarded the content).  TODO: real classifier.
        return "semantic"

    def _merge_results(self, results: List) -> List:
        """Sort merged results by their 'relevance' score, best first."""
        return sorted(results, key=lambda x: x["relevance"], reverse=True)
class WorkingMemory:
    """Small FIFO buffer modelling limited attention span (default capacity 7)."""

    def __init__(self, capacity: int = 7):
        self.capacity = capacity
        self.items = []

    def add(self, content: str):
        """Append an item; evict the oldest when over capacity."""
        self.items.append({
            "content": content,
            "timestamp": datetime.now()
        })
        if len(self.items) > self.capacity:
            self.items.pop(0)  # drop the oldest entry

    def search(self, query: str) -> List[Dict]:
        # Everything currently in working memory is treated as fully relevant;
        # the query is intentionally ignored.
        return [
            {"content": item["content"], "relevance": 1.0}
            for item in self.items
        ]
class EpisodicMemory:
    """Stores experiences with an LLM-generated summary used for retrieval."""

    def __init__(self, llm):
        self.llm = llm
        self.episodes = []

    def add(self, content: str):
        """Summarize the content via the LLM and store both forms."""
        summary = self.llm.generate(f"总结以下内容:{content}")
        self.episodes.append({
            "content": content,
            "summary": summary,
            "timestamp": datetime.now()
        })

    def search(self, query: str) -> List[Dict]:
        """Return episodes whose summary relevance to query exceeds 0.5."""
        results = []
        for episode in self.episodes:
            relevance = self._calculate_relevance(query, episode["summary"])
            if relevance > 0.5:
                results.append({
                    "content": episode["content"],
                    "relevance": relevance
                })
        return results

    def _calculate_relevance(self, query: str, summary: str) -> float:
        # Simple token-overlap (Jaccard) score in [0, 1].  The original was an
        # empty placeholder returning None, which made the `relevance > 0.5`
        # comparison in search() raise a TypeError.
        query_tokens = set(query.lower().split())
        summary_tokens = set(summary.lower().split())
        if not query_tokens or not summary_tokens:
            return 0.0
        return len(query_tokens & summary_tokens) / len(query_tokens | summary_tokens)
class SemanticMemory:
    """Fact store backed by VectorMemory similarity search."""

    def __init__(self, embedding_model):
        self.vector_memory = VectorMemory(embedding_model)

    def add(self, content: str):
        """Store a fact, keyed by a hash of its text."""
        # NOTE(review): str(hash(content)) is not stable across interpreter runs
        # (string hash randomization) — fine in-process, unsuitable for persistence.
        self.vector_memory.add(
            id=str(hash(content)),
            content=content
        )

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return top_k semantically similar facts as {content, relevance} dicts."""
        hits = self.vector_memory.search(query, top_k)
        return [
            {"content": content, "relevance": score}
            for content, score, _ in hits
        ]
class ProceduralMemory:
    """Named step-by-step procedures with usage counting."""

    def __init__(self):
        self.procedures = {}

    def add(self, name: str, steps: List[str]):
        """Register (or overwrite) a procedure with a zeroed usage counter."""
        self.procedures[name] = {
            "steps": steps,
            "usage_count": 0
        }

    def get(self, name: str) -> Optional[List[str]]:
        """Return the steps and bump the usage counter; None for unknown names."""
        if name in self.procedures:
            self.procedures[name]["usage_count"] += 1
            return self.procedures[name]["steps"]
        return None

    def get_most_used(self, top_k: int = 5) -> List[str]:
        """Names of the top_k most frequently retrieved procedures."""
        ranked = sorted(
            self.procedures.items(),
            key=lambda x: x[1]["usage_count"],
            reverse=True
        )
        return [name for name, _ in ranked[:top_k]]
记忆管理策略
1. 滑动窗口策略
python
class SlidingWindowMemory:
    """Keeps only the most recent window_size items."""

    def __init__(self, window_size: int = 10):
        self.window_size = window_size
        self.memory = []

    def add(self, item: Dict):
        """Append an item, evicting the oldest past the window size."""
        self.memory.append(item)
        if len(self.memory) > self.window_size:
            # O(n) pop; kept as a list so get_recent can slice.
            self.memory.pop(0)

    def get_recent(self, n: int = None) -> List[Dict]:
        """Return the last n items (defaults to the whole window), oldest first."""
        n = n or self.window_size
        return self.memory[-n:]
2. 摘要压缩策略
python
class SummarizingMemory:
    """Keeps recent items verbatim and compresses overflow into an LLM summary."""

    def __init__(self, llm, max_items: int = 20):
        self.llm = llm
        self.max_items = max_items
        self.items = []
        self.summary = ""

    def add(self, content: str):
        """Append content; compress the older half once max_items is exceeded."""
        self.items.append(content)
        if len(self.items) > self.max_items:
            self._summarize_old_items()

    def _summarize_old_items(self):
        """Fold the older half of items into the running summary via the LLM."""
        old_items = self.items[:self.max_items // 2]
        new_summary = self.llm.generate(
            f"总结以下内容:\n" + "\n".join(old_items)
        )
        if self.summary:
            # Merge the previous summary with the new one instead of discarding it.
            self.summary = self.llm.generate(
                f"合并两个摘要:\n摘要1: {self.summary}\n摘要2: {new_summary}"
            )
        else:
            self.summary = new_summary
        self.items = self.items[self.max_items // 2:]

    def get_context(self) -> str:
        """Return the summary (if any) followed by the verbatim recent items."""
        context_parts = []
        if self.summary:
            context_parts.append(f"历史摘要:{self.summary}")
        if self.items:
            context_parts.append("最近对话:\n" + "\n".join(self.items))
        return "\n\n".join(context_parts)
3. 重要性评分策略
python
class ImportanceBasedMemory:
    """Scores items by LLM-judged importance and prunes the least important."""

    def __init__(self, llm, max_items: int = 100):
        self.llm = llm
        self.max_items = max_items
        self.items = []

    def add(self, content: str):
        """Store content with its importance score; prune when over capacity."""
        importance = self._calculate_importance(content)
        self.items.append({
            "content": content,
            "importance": importance,
            "timestamp": datetime.now(),
            "access_count": 0
        })
        if len(self.items) > self.max_items:
            self._prune_low_importance()

    def _calculate_importance(self, content: str) -> float:
        """Ask the LLM for a 0-1 importance score."""
        prompt = f"""
评估以下内容的重要性(0-1分):
内容:{content}
考虑因素:
1. 是否包含关键信息
2. 是否是用户偏好
3. 是否是重要决策
4. 是否需要长期记忆
只返回分数。
"""
        score = self.llm.generate(prompt)
        # NOTE(review): float() raises ValueError if the model adds extra text.
        return float(score.strip())

    def _prune_low_importance(self):
        # Keep only the top 80% of capacity, ranked by importance.
        self.items.sort(key=lambda x: x["importance"], reverse=True)
        self.items = self.items[:int(self.max_items * 0.8)]

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Rank items by a blend of importance, relevance and access frequency."""
        # NOTE(review): this bumps access_count for every stored item on every
        # search, not only for returned ones — behavior kept from the original.
        for item in self.items:
            item["access_count"] += 1
        scored_items = []
        for item in self.items:
            relevance = self._calculate_relevance(query, item["content"])
            score = (
                item["importance"] * 0.5 +
                relevance * 0.3 +
                min(item["access_count"] / 10, 0.2)
            )
            scored_items.append((item, score))
        scored_items.sort(key=lambda x: x[1], reverse=True)
        return [item for item, _ in scored_items[:top_k]]

    def _calculate_relevance(self, query: str, content: str) -> float:
        # Token-overlap (Jaccard) score in [0, 1].  The original empty
        # placeholder returned None and made the arithmetic in search()
        # raise a TypeError.
        query_tokens = set(query.lower().split())
        content_tokens = set(content.lower().split())
        if not query_tokens or not content_tokens:
            return 0.0
        return len(query_tokens & content_tokens) / len(query_tokens | content_tokens)
4. 时间衰减策略
python
from datetime import datetime, timedelta
class TimeDecayMemory:
    """Memory whose item weights decay exponentially with age in whole days."""

    def __init__(self, decay_rate: float = 0.1, half_life: int = 7):
        self.decay_rate = decay_rate
        # NOTE(review): half_life is stored but never read; decay is governed
        # solely by decay_rate.  Kept for interface compatibility — confirm intent.
        self.half_life = half_life
        self.items = []

    def add(self, content: str, initial_weight: float = 1.0):
        """Store content with its starting weight and creation time."""
        self.items.append({
            "content": content,
            "timestamp": datetime.now(),
            "initial_weight": initial_weight
        })

    def get_weight(self, item: Dict) -> float:
        """Current weight: initial_weight * (1 - decay_rate) ** age_in_days."""
        # timedelta.days truncates, so items under 24h old have not decayed yet.
        age_days = (datetime.now() - item["timestamp"]).days
        decay = (1 - self.decay_rate) ** age_days
        return item["initial_weight"] * decay

    def get_active_items(self, threshold: float = 0.3) -> List[Dict]:
        """Items still above threshold, heaviest first, with current_weight attached."""
        active_items = []
        for item in self.items:
            weight = self.get_weight(item)
            if weight >= threshold:
                active_items.append({
                    **item,
                    "current_weight": weight
                })
        return sorted(active_items, key=lambda x: x["current_weight"], reverse=True)

    def cleanup(self, threshold: float = 0.1):
        """Permanently drop items whose weight has fallen below threshold."""
        self.items = [
            item for item in self.items
            if self.get_weight(item) >= threshold
        ]
记忆检索优化
1. 混合检索
python
class HybridRetrieval:
    """Combines vector-similarity search with an inverted keyword index."""

    def __init__(self, embedding_model, llm):
        self.vector_memory = VectorMemory(embedding_model)
        self.keyword_index = {}  # keyword -> list of item ids
        self.llm = llm

    def add(self, id: str, content: str, metadata: Dict = None):
        """Index content both by embedding and by LLM-extracted keywords."""
        self.vector_memory.add(id, content, metadata)
        for keyword in self._extract_keywords(content):
            self.keyword_index.setdefault(keyword, []).append(id)

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Over-fetch from the vector store, then boost keyword matches."""
        vector_results = self.vector_memory.search(query, top_k * 2)
        keyword_results = set()
        for keyword in self._extract_keywords(query):
            if keyword in self.keyword_index:
                keyword_results.update(self.keyword_index[keyword])
        return self._merge_results(vector_results, keyword_results, top_k)

    def _extract_keywords(self, text: str) -> List[str]:
        """LLM-based keyword extraction (comma-separated response)."""
        prompt = f"提取以下文本的关键词(用逗号分隔):{text}"
        keywords_str = self.llm.generate(prompt)
        return [k.strip() for k in keywords_str.split(",")]

    def _merge_results(self, vector_results, keyword_ids, top_k):
        """Apply a 1.2x score boost to keyword-matched items, sort, truncate."""
        results = []
        for content, score, metadata in vector_results:
            # keyword_ids holds string ids.  The original compared the raw int
            # hash(content) against them, so the boost could never fire.  This
            # fix assumes ids were derived as str(hash(content)) — confirm with
            # how callers assign ids.
            boost = 1.2 if str(hash(content)) in keyword_ids else 1.0
            results.append({
                "content": content,
                "score": score * boost,
                "metadata": metadata
            })
        return sorted(results, key=lambda x: x["score"], reverse=True)[:top_k]
2. 重排序
python
class Reranker:
    """LLM-scored reranking of retrieval results."""

    def __init__(self, llm):
        self.llm = llm

    def rerank(self, query: str, results: List[Dict], top_k: int = 5) -> List[Dict]:
        """Return the top_k results by LLM relevance; short lists pass through."""
        if len(results) <= top_k:
            return results
        scores = [
            (result, self._calculate_relevance(query, result["content"]))
            for result in results
        ]
        scores.sort(key=lambda x: x[1], reverse=True)
        return [result for result, _ in scores[:top_k]]

    def _calculate_relevance(self, query: str, content: str) -> float:
        """Ask the LLM for a 0-1 relevance score of content to query."""
        prompt = f"""
评估查询和内容的相关性(0-1分):
查询:{query}
内容:{content}
只返回分数。
"""
        score = self.llm.generate(prompt)
        # NOTE(review): float() raises ValueError if the model adds extra text.
        return float(score.strip())
3. 查询扩展
python
class QueryExpander:
    """Expands a query into related search terms via the LLM."""

    def __init__(self, llm):
        self.llm = llm

    def expand(self, query: str, num_expansions: int = 3) -> List[str]:
        """Return the original query plus up to num_expansions LLM-generated terms."""
        prompt = f"""
为以下查询生成{num_expansions}个相关的搜索词:
查询:{query}
每行一个搜索词。
"""
        expansions = self.llm.generate(prompt).strip().split("\n")
        return [query] + [e.strip() for e in expansions[:num_expansions]]

    def search_with_expansion(self, query: str, memory, top_k: int = 5) -> List[Dict]:
        """Search memory with every expanded query, dedupe, truncate to top_k."""
        all_results = []
        for q in self.expand(query):
            all_results.extend(memory.search(q, top_k))
        return self._deduplicate(all_results)[:top_k]

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Keep only the first occurrence of each distinct content string."""
        seen = set()
        unique = []
        for result in results:
            content_hash = hash(result["content"])
            if content_hash not in seen:
                seen.add(content_hash)
                unique.append(result)
        return unique
完整记忆系统实现
python
class AgentMemorySystem:
    """Facade wiring the individual memory components into one agent-facing API."""

    def __init__(self, llm, embedding_model):
        self.short_term = ShortTermMemory(max_messages=20)
        self.long_term = LongTermMemory()
        self.vector_memory = VectorMemory(embedding_model)
        self.summarizer = SummarizingMemory(llm)
        self.importance_memory = ImportanceBasedMemory(llm)
        self.retriever = HybridRetrieval(embedding_model, llm)
        self.reranker = Reranker(llm)
        self.query_expander = QueryExpander(llm)

    def add_message(self, role: str, content: str):
        """Record a message; important content also goes to long-lived stores."""
        self.short_term.add(role, content)
        self.summarizer.add(f"{role}: {content}")
        if self._is_important(content):
            self.importance_memory.add(content)
            # NOTE(review): str(hash(...)) ids are not stable across runs.
            self.vector_memory.add(
                id=str(hash(content)),
                content=content
            )

    def recall(self, query: str, top_k: int = 5) -> List[Dict]:
        """Gather candidates from vector + importance memory, dedupe, rerank."""
        results = []
        for q in self.query_expander.expand(query):
            results.extend(
                {"content": c, "score": s, "source": "vector"}
                for c, s, _ in self.vector_memory.search(q, top_k)
            )
        results.extend(
            {"content": r["content"], "score": r["importance"], "source": "importance"}
            for r in self.importance_memory.search(query, top_k)
        )
        results = self._deduplicate(results)
        return self.reranker.rerank(query, results, top_k)

    def get_context_for_query(self, query: str, max_tokens: int = 2000) -> str:
        """Build prompt context: recent dialogue plus memories relevant to query."""
        recent_context = self.short_term.get_context(max_tokens // 2)
        relevant_memories = self.recall(query, top_k=3)
        context_parts = []
        if recent_context:
            context_parts.append("最近对话:\n" + self._format_context(recent_context))
        if relevant_memories:
            context_parts.append(
                "相关记忆:\n" +
                "\n".join(m["content"] for m in relevant_memories)
            )
        return "\n\n".join(context_parts)

    def _is_important(self, content: str) -> bool:
        # Cheap keyword heuristic for what deserves long-term storage.
        keywords = ["重要", "记住", "偏好", "设置", "决定"]
        return any(kw in content for kw in keywords)

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Keep the first result per distinct content string."""
        seen = set()
        unique = []
        for result in results:
            content_hash = hash(result["content"])
            if content_hash not in seen:
                seen.add(content_hash)
                unique.append(result)
        return unique

    def _format_context(self, context: List[Dict]) -> str:
        """Render chat messages as 'role: content' lines."""
        return "\n".join(f"{msg['role']}: {msg['content']}" for msg in context)
小结
记忆系统是Agent实现智能交互的基础,设计良好的记忆系统需要:
- 分层设计 - 短期记忆、长期记忆、向量记忆各司其职
- 智能管理 - 滑动窗口、摘要压缩、重要性评分、时间衰减
- 高效检索 - 混合检索、重排序、查询扩展
- 平衡性能 - 在存储效率和检索速度之间找到平衡
- 持续优化 - 根据使用情况动态调整记忆策略
下一章我们将探讨多Agent协作,学习如何让多个Agent协同工作完成复杂任务。