缓存策略
概述
缓存是优化AI应用性能和降低成本的重要手段。通过合理使用缓存,可以避免重复计算、减少API调用次数、提升响应速度。本文将介绍多种缓存策略及其实现方法,帮助开发者构建高效的AI应用。
核心内容
缓存类型
1. 精确匹配缓存
适用于完全相同的请求,实现简单高效:
python
from functools import lru_cache
import hashlib
class ExactMatchCache:
    """Exact-match response cache keyed by (prompt, model, params).

    Entries are stored under a SHA-256 digest of the request, so only
    byte-identical requests hit the cache.
    """

    def __init__(self, maxsize=1000):
        self.cache = {}          # cache_key -> cached response
        self.maxsize = maxsize   # max entries before eviction kicks in
        self.access_count = {}   # cache_key -> number of hits (drives eviction)

    def get(self, prompt, model, params):
        """Return the cached response for this exact request, or None on a miss."""
        cache_key = self._generate_key(prompt, model, params)
        if cache_key in self.cache:
            self.access_count[cache_key] += 1
            return self.cache[cache_key]
        return None

    def set(self, prompt, model, params, response):
        """Store a response, evicting one entry first if the cache is full.

        Fix over the original: overwriting an already-cached key no longer
        triggers an unnecessary eviction of an unrelated entry.
        """
        cache_key = self._generate_key(prompt, model, params)
        if cache_key not in self.cache and len(self.cache) >= self.maxsize:
            self._evict_lru()
        self.cache[cache_key] = response
        self.access_count[cache_key] = 1

    def _generate_key(self, prompt, model, params):
        """Build a stable SHA-256 key; params are sorted so dict order is irrelevant."""
        content = f"{prompt}|{model}|{str(sorted(params.items()))}"
        return hashlib.sha256(content.encode()).hexdigest()

    def _evict_lru(self):
        # NOTE(review): despite the name, this evicts the LEAST FREQUENTLY
        # accessed entry (lowest hit count), not the least recently used one.
        lru_key = min(self.access_count, key=self.access_count.get)
        del self.cache[lru_key]
        del self.access_count[lru_key]
2. 语义相似缓存
适用于语义相近的请求,提高缓存命中率:
python
import numpy as np
from sentence_transformers import SentenceTransformer
class SemanticCache:
    """Cache that matches prompts by embedding similarity rather than equality.

    A lookup returns the cached response whose prompt embedding is most
    similar (cosine) to the query, provided it exceeds the threshold.
    """

    def __init__(self, similarity_threshold=0.95, maxsize=500):
        # NOTE: loads the sentence-transformers model eagerly; construction is slow.
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.cache = []  # list of (prompt, embedding, response), insertion order
        self.threshold = similarity_threshold
        self.maxsize = maxsize

    def get(self, prompt):
        """Return the best matching cached response, or None if nothing is
        similar enough.

        Fix over the original: scans all entries and returns the MOST similar
        one above the threshold instead of the first one that happens to pass.
        """
        prompt_embedding = self.model.encode(prompt)
        best_response = None
        best_similarity = self.threshold
        for cached_prompt, embedding, response in self.cache:
            similarity = self._cosine_similarity(prompt_embedding, embedding)
            if similarity > best_similarity:
                best_similarity = similarity
                best_response = response
        return best_response

    def set(self, prompt, response):
        """Add an entry, dropping the oldest one when the cache is full (FIFO)."""
        embedding = self.model.encode(prompt)
        if len(self.cache) >= self.maxsize:
            self.cache.pop(0)
        self.cache.append((prompt, embedding, response))

    def _cosine_similarity(self, a, b):
        # Guard against zero-norm vectors to avoid a division-by-zero NaN.
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return np.dot(a, b) / denom if denom else 0.0
3. 分层缓存
结合内存缓存和持久化存储:
python
import json
import time
from pathlib import Path
class TieredCache:
    """Two-level cache: a small in-memory dict backed by JSON files on disk.

    Values must be JSON-serializable. Entries carry an absolute expiry
    timestamp and are treated as missing once expired.

    Fixes over the original: a memory hit now returns the stored VALUE
    (the original returned the internal entry dict, while a disk hit
    returned the value), memory hits honor expiry, expired disk files are
    deleted, and the helper methods the original called but never defined
    (_promote_to_memory, _evict_from_memory) are implemented.
    """

    def __init__(self, memory_size=100, disk_path="./cache"):
        self.memory_cache = {}        # key -> {"value": ..., "expires_at": ...}
        self.memory_size = memory_size
        self.disk_path = Path(disk_path)
        self.disk_path.mkdir(exist_ok=True)

    def get(self, key):
        """Return the cached value (memory first, then disk), or None."""
        entry = self.memory_cache.get(key)
        if entry is not None and entry["expires_at"] <= time.time():
            del self.memory_cache[key]  # expired: drop from memory
            entry = None
        if entry is None:
            entry = self._load_from_disk(key)
            if entry is not None:
                self._promote_to_memory(key, entry)
        return entry["value"] if entry is not None else None

    def set(self, key, value, ttl=3600):
        """Store value in both tiers with a time-to-live in seconds."""
        cache_entry = {
            "value": value,
            "expires_at": time.time() + ttl
        }
        if len(self.memory_cache) >= self.memory_size:
            self._evict_from_memory()
        self.memory_cache[key] = cache_entry
        self._save_to_disk(key, cache_entry)

    def _promote_to_memory(self, key, entry):
        """Copy a disk entry into the memory tier, evicting first if full."""
        if len(self.memory_cache) >= self.memory_size:
            self._evict_from_memory()
        self.memory_cache[key] = entry

    def _evict_from_memory(self):
        """Evict the entry closest to expiry (cheap heuristic; no LRU tracking)."""
        if self.memory_cache:
            victim = min(self.memory_cache,
                         key=lambda k: self.memory_cache[k]["expires_at"])
            del self.memory_cache[victim]

    def _load_from_disk(self, key):
        """Return the on-disk entry dict if present and unexpired, else None.

        Expired files are deleted so stale entries don't accumulate.
        """
        # NOTE(review): key is used directly as a filename — callers must
        # pass filesystem-safe keys (no separators); confirm with callers.
        file_path = self.disk_path / f"{key}.json"
        if file_path.exists():
            with open(file_path, 'r') as f:
                entry = json.load(f)
            if entry["expires_at"] > time.time():
                return entry
            file_path.unlink()  # expired: clean up the stale file
        return None

    def _save_to_disk(self, key, entry):
        """Persist an entry as <key>.json in the cache directory."""
        file_path = self.disk_path / f"{key}.json"
        with open(file_path, 'w') as f:
            json.dump(entry, f)
缓存策略
1. TTL(Time To Live)策略
python
class TTLCache:
    """In-memory cache whose entries expire after a per-entry time-to-live."""

    def __init__(self, default_ttl=3600):
        self.cache = {}              # key -> {"value", "expires_at", "created_at"}
        self.default_ttl = default_ttl

    def get(self, key):
        """Return the live value for key, or None (expired entries are purged)."""
        entry = self.cache.get(key)
        if entry is None:
            return None
        if entry["expires_at"] > time.time():
            return entry["value"]
        del self.cache[key]  # lazily drop the expired entry
        return None

    def set(self, key, value, ttl=None):
        """Store value; ttl falls back to the instance default when falsy."""
        effective_ttl = ttl or self.default_ttl
        now = time.time()
        self.cache[key] = {
            "value": value,
            "expires_at": now + effective_ttl,
            "created_at": now,
        }

    def cleanup_expired(self):
        """Drop every expired entry in one pass."""
        cutoff = time.time()
        stale = [k for k, v in self.cache.items() if v["expires_at"] < cutoff]
        for key in stale:
            del self.cache[key]
2. LRU(Least Recently Used)策略
python
from collections import OrderedDict
class LRUCache:
    """Bounded cache that evicts the least recently used entry when full."""

    def __init__(self, capacity=1000):
        self.cache = OrderedDict()  # oldest access first, newest last
        self.capacity = capacity

    def get(self, key):
        """Return the value for key (marking it most recent), or None."""
        try:
            value = self.cache[key]
        except KeyError:
            return None
        self.cache.move_to_end(key)  # refresh recency
        return value

    def set(self, key, value):
        """Insert or refresh key, evicting the oldest entry if at capacity."""
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)  # drop least recently used
        self.cache[key] = value

    def get_stats(self):
        """Report current occupancy and configured capacity."""
        return {"size": len(self.cache), "capacity": self.capacity}
3. LFU(Least Frequently Used)策略
python
import heapq
class LFUCache:
    """Bounded cache that evicts the least frequently used entry.

    Uses a min-heap of (frequency, key) pairs with lazy deletion: stale
    heap pairs (whose recorded frequency no longer matches the live count)
    are skipped during eviction. Keys must be mutually comparable, since
    heap ties on frequency compare the keys.
    """

    def __init__(self, capacity=1000):
        self.capacity = capacity
        self.cache = {}       # key -> value
        self.frequency = {}   # key -> current access count
        self.heap = []        # (frequency, key) pairs, possibly stale

    def get(self, key):
        """Return the value for key (bumping its frequency), or None.

        Fix over the original: the updated (frequency, key) pair is pushed
        onto the heap so the key stays evictable; the original never
        re-pushed, so a read key could drain the heap and never be evicted,
        leaving the cache over capacity.
        """
        if key not in self.cache:
            return None
        self.frequency[key] += 1
        heapq.heappush(self.heap, (self.frequency[key], key))
        return self.cache[key]

    def set(self, key, value):
        """Insert or update key, evicting the least frequent entry when full.

        Fix over the original: updating an existing key no longer triggers
        an eviction or resets its frequency to 1.
        """
        if key in self.cache:
            self.cache[key] = value
            self.frequency[key] += 1
            heapq.heappush(self.heap, (self.frequency[key], key))
            return
        if len(self.cache) >= self.capacity:
            self._evict()
        self.cache[key] = value
        self.frequency[key] = 1
        heapq.heappush(self.heap, (1, key))

    def _evict(self):
        # Pop until a pair whose frequency is still current; stale pairs
        # are simply discarded (lazy deletion).
        while self.heap:
            freq, key = heapq.heappop(self.heap)
            if self.frequency.get(key) == freq:
                del self.cache[key]
                del self.frequency[key]
                break
AI特定缓存优化
1. 提示词模板缓存
python
class PromptTemplateCache:
def __init__(self):
self.templates = {}
self.compiled_templates = {}
def register(self, name, template):
self.templates[name] = template
self.compiled_templates[name] = self._compile(template)
def _compile(self, template):
import re
variables = re.findall(r'\{(\w+)\}', template)
return {
"template": template,
"variables": variables
}
def render(self, name, **kwargs):
if name not in self.templates:
raise ValueError(f"Template {name} not found")
return self.templates[name].format(**kwargs)
def get_cache_key(self, name, **kwargs):
template_info = self.compiled_templates[name]
sorted_kwargs = sorted(kwargs.items())
return f"{name}|{sorted_kwargs}"2. 响应缓存装饰器
2. 响应缓存装饰器
python
def cached_response(cache_instance, key_func=None):
    """Decorator factory that memoizes a function's results in cache_instance.

    cache_instance must expose get(key) -> value-or-None and set(key, value).
    key_func, if given, maps (*args, **kwargs) to a cache key; otherwise a
    key is derived from the arguments' repr with kwargs sorted, so keyword
    order does not cause spurious misses (fix over the original).

    Note: a cached value of None is indistinguishable from a miss, so None
    results are recomputed on every call.
    """
    from functools import wraps

    def decorator(func):
        @wraps(func)  # fix: preserve the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            if key_func is not None:
                cache_key = key_func(*args, **kwargs)
            else:
                cache_key = str((args, sorted(kwargs.items())))
            cached = cache_instance.get(cache_key)
            if cached is not None:
                return cached
            result = func(*args, **kwargs)
            cache_instance.set(cache_key, result)
            return result
        return wrapper
    return decorator
3. 部分响应缓存
python
class PartialResponseCache:
    """Accumulates streamed response chunks per request for later reassembly."""

    def __init__(self):
        self.cache = {}  # request_id -> {chunk_index: content}

    def cache_chunk(self, request_id, chunk_index, content):
        """Record one chunk of a streamed response."""
        self.cache.setdefault(request_id, {})[chunk_index] = content

    def get_chunk(self, request_id, chunk_index):
        """Return a single stored chunk, or None if absent."""
        return self.cache.get(request_id, {}).get(chunk_index)

    def get_full_response(self, request_id):
        """Concatenate all chunks for request_id in index order; None if unknown."""
        chunks = self.cache.get(request_id)
        if chunks is None:
            return None
        return "".join(chunks[idx] for idx in sorted(chunks))
缓存失效策略
1. 主动失效
python
class CacheInvalidator:
    """Deletes cache entries when something they depend on changes."""

    def __init__(self, cache):
        self.cache = cache      # backend exposing delete(key)
        self.dependencies = {}  # dependency -> list of dependent cache keys

    def register_dependency(self, cache_key, depends_on):
        """Record that cache_key must be invalidated when depends_on changes."""
        self.dependencies.setdefault(depends_on, []).append(cache_key)

    def invalidate(self, dependency):
        """Delete every cache entry registered against this dependency."""
        if dependency in self.dependencies:
            for cache_key in self.dependencies[dependency]:
                self.cache.delete(cache_key)
            del self.dependencies[dependency]
2. 版本控制失效
python
class VersionedCache:
    """Namespace-versioned cache: bumping a namespace's version orphans all
    of its entries, invalidating them without deleting anything."""

    def __init__(self):
        self.cache = {}     # "<namespace>:<version>:<key>" -> value
        self.versions = {}  # namespace -> current version (default 0)

    def set_version(self, namespace, version):
        """Pin a namespace to an explicit version."""
        self.versions[namespace] = version

    def _full_key(self, key, namespace):
        # Keys embed the namespace's current version.
        return f"{namespace}:{self.versions.get(namespace, 0)}:{key}"

    def get(self, key, namespace):
        """Return the value stored under the namespace's current version."""
        return self.cache.get(self._full_key(key, namespace))

    def set(self, key, value, namespace):
        """Store value under the namespace's current version."""
        self.cache[self._full_key(key, namespace)] = value

    def invalidate_namespace(self, namespace):
        """Invalidate a whole namespace by incrementing its version."""
        self.versions[namespace] = self.versions.get(namespace, 0) + 1
实用技巧
1. 缓存预热
python
def warmup_cache(common_queries, api_func, cache, key_func=None):
    """Pre-populate cache with responses for a list of common queries.

    Each query not already cached is fetched via api_func and stored.

    Args:
        common_queries: iterable of queries to warm.
        api_func: callable producing the response for a query.
        cache: backend exposing get(key) and set(key, value).
        key_func: maps a query to its cache key. Defaults to the
            module-level generate_cache_key the original relied on, which
            must exist in the caller's scope; passing key_func removes
            that hidden dependency (backward-compatible generalization).
    """
    make_key = key_func if key_func is not None else generate_cache_key
    for query in common_queries:
        cache_key = make_key(query)
        if cache.get(cache_key) is None:
            cache.set(cache_key, api_func(query))
2. 缓存监控
python
class CacheMonitor:
    """Wraps a cache's get() to count hits and misses."""

    def __init__(self, cache):
        self.cache = cache
        self.hits = 0
        self.misses = 0

    def get(self, key):
        """Delegate to the wrapped cache, tallying hit or miss.

        Note: a cached None value is counted as a miss, since None is the
        miss sentinel of the wrapped cache.
        """
        result = self.cache.get(key)
        if result is None:
            self.misses += 1
        else:
            self.hits += 1
        return result

    def get_hit_rate(self):
        """Fraction of lookups that hit; 0 when nothing was looked up yet."""
        total = self.hits + self.misses
        if total == 0:
            return 0
        return self.hits / total

    def get_stats(self):
        """Snapshot of hit/miss counters and the derived hit rate."""
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": self.get_hit_rate(),
        }
3. 智能缓存决策
python
def should_cache(response, prompt):
    """Decide whether an AI response is worth caching.

    Relies on the external helpers count_tokens and is_time_sensitive
    (defined elsewhere). Caching is skipped for error responses, short
    prompts (< 50 tokens, cheap to recompute) and time-sensitive prompts
    (answers would go stale).
    """
    # Never cache failed responses.
    if response.get("error"):
        return False
    # Same short-circuit order as before: token count first, then recency.
    if count_tokens(prompt) < 50 or is_time_sensitive(prompt):
        return False
    return True
4. 缓存压缩
python
import gzip
import json
class CompressedCache:
    """Cache wrapper that gzip-compresses JSON-serialized values.

    Trades CPU for space: useful when the backend stores large text blobs.
    Values must be JSON-serializable.
    """

    def __init__(self, cache_backend):
        self.backend = cache_backend  # must expose get(key) / set(key, value)

    def set(self, key, value):
        """JSON-encode and gzip value, then store the compressed bytes."""
        payload = gzip.compress(json.dumps(value).encode())
        self.backend.set(key, payload)

    def get(self, key):
        """Fetch, decompress and decode a value; None when absent."""
        payload = self.backend.get(key)
        if not payload:
            return None
        return json.loads(gzip.decompress(payload).decode())
小结
有效的缓存策略需要考虑:
- 选择合适的缓存类型:精确匹配、语义相似、分层缓存
- 实施缓存策略:TTL、LRU、LFU等淘汰策略
- AI特定优化:模板缓存、响应缓存、部分缓存
- 缓存失效管理:主动失效、版本控制
- 监控与优化:命中率监控、智能决策、压缩存储
通过合理运用缓存策略,可以显著提升AI应用的响应速度,降低API调用成本,改善用户体验。建议根据具体场景选择合适的缓存方案,并持续监控和优化缓存效果。