Skip to content

缓存策略

概述

缓存是优化AI应用性能和降低成本的重要手段。通过合理使用缓存,可以避免重复计算、减少API调用次数、提升响应速度。本文将介绍多种缓存策略及其实现方法,帮助开发者构建高效的AI应用。

核心内容

缓存类型

1. 精确匹配缓存

适用于完全相同的请求,实现简单高效:

python
from functools import lru_cache
import hashlib

class ExactMatchCache:
    """Exact-match response cache keyed by a SHA-256 hash of (prompt, model, params).

    Eviction removes the entry with the lowest access count (least frequently
    used). Note: the original helper was named ``_evict_lru`` but its logic is
    frequency-based, so it is named ``_evict_lfu`` here to match behavior.
    """

    def __init__(self, maxsize=1000):
        self.cache = {}          # cache_key -> cached response
        self.maxsize = maxsize   # max number of entries before eviction
        self.access_count = {}   # cache_key -> hit count, drives eviction

    def get(self, prompt, model, params):
        """Return the cached response for this exact request, or None on a miss."""
        cache_key = self._generate_key(prompt, model, params)

        if cache_key in self.cache:
            self.access_count[cache_key] += 1
            return self.cache[cache_key]

        return None

    def set(self, prompt, model, params, response):
        """Store a response, evicting the least-used entry when at capacity."""
        cache_key = self._generate_key(prompt, model, params)

        # Only evict when inserting a NEW key at capacity; re-setting an
        # existing key must not push an unrelated entry out of the cache.
        if cache_key not in self.cache and len(self.cache) >= self.maxsize:
            self._evict_lfu()

        self.cache[cache_key] = response
        self.access_count[cache_key] = 1

    def _generate_key(self, prompt, model, params):
        """Build a stable digest of the request; params are sorted so that
        dict ordering does not change the key."""
        content = f"{prompt}|{model}|{str(sorted(params.items()))}"
        return hashlib.sha256(content.encode()).hexdigest()

    def _evict_lfu(self):
        """Drop the entry with the fewest recorded accesses."""
        victim = min(self.access_count, key=self.access_count.get)
        del self.cache[victim]
        del self.access_count[victim]

2. 语义相似缓存

适用于语义相近的请求,提高缓存命中率:

python
import numpy as np
from sentence_transformers import SentenceTransformer

class SemanticCache:
    """Cache that matches prompts by embedding similarity rather than equality.

    Lookup returns the first cached response whose prompt embedding exceeds
    the similarity threshold; entries are evicted FIFO once full.
    """

    def __init__(self, similarity_threshold=0.95, maxsize=500):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.cache = []  # list of (prompt, embedding, response) tuples
        self.threshold = similarity_threshold
        self.maxsize = maxsize

    def get(self, prompt):
        """Return a cached response for a semantically similar prompt, else None."""
        query_vec = self.model.encode(prompt)

        for _, cached_vec, cached_response in self.cache:
            if self._cosine_similarity(query_vec, cached_vec) > self.threshold:
                return cached_response

        return None

    def set(self, prompt, response):
        """Add a (prompt, embedding, response) entry, dropping the oldest when full."""
        vec = self.model.encode(prompt)

        if len(self.cache) >= self.maxsize:
            self.cache.pop(0)  # FIFO eviction of the oldest entry

        self.cache.append((prompt, vec, response))

    def _cosine_similarity(self, a, b):
        """Cosine similarity of two vectors."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return np.dot(a, b) / denom

3. 分层缓存

结合内存缓存和持久化存储:

python
import json
import time
from pathlib import Path

class TieredCache:
    """Two-tier cache: a small in-memory dict backed by JSON files on disk.

    Fixes over the original:
    - ``_promote_to_memory`` / ``_evict_from_memory`` were called but never
      defined (AttributeError at runtime); they are implemented here.
    - ``get`` returned the raw entry dict on a memory hit but the unwrapped
      value on a disk hit; it now always returns the value.
    - Memory hits now honor the entry's TTL instead of serving stale data.
    """

    def __init__(self, memory_size=100, disk_path="./cache"):
        self.memory_cache = {}        # key -> {"value": ..., "expires_at": ...}
        self.memory_size = memory_size
        self.disk_path = Path(disk_path)
        self.disk_path.mkdir(exist_ok=True)

    def get(self, key):
        """Return the cached value, or None when missing/expired."""
        entry = self.memory_cache.get(key)
        if entry is not None:
            if entry["expires_at"] > time.time():
                return entry["value"]
            del self.memory_cache[key]  # expired: drop from the hot tier

        entry = self._load_from_disk(key)
        if entry is not None:
            self._promote_to_memory(key, entry)
            return entry["value"]

        return None

    def set(self, key, value, ttl=3600):
        """Write the value to both tiers with the given TTL (seconds).

        Note: the value must be JSON-serializable for the disk tier.
        """
        cache_entry = {
            "value": value,
            "expires_at": time.time() + ttl
        }
        self._promote_to_memory(key, cache_entry)
        self._save_to_disk(key, cache_entry)

    def _promote_to_memory(self, key, entry):
        """Place an entry in the memory tier, evicting first if needed."""
        if key not in self.memory_cache and len(self.memory_cache) >= self.memory_size:
            self._evict_from_memory()
        self.memory_cache[key] = entry

    def _evict_from_memory(self):
        """Evict the entry closest to expiry (it is the least useful to keep)."""
        victim = min(self.memory_cache, key=lambda k: self.memory_cache[k]["expires_at"])
        del self.memory_cache[victim]

    def _load_from_disk(self, key):
        """Return the full on-disk entry dict if present and unexpired, else None."""
        file_path = self.disk_path / f"{key}.json"
        if file_path.exists():
            with open(file_path, 'r') as f:
                entry = json.load(f)
            if entry["expires_at"] > time.time():
                return entry
        return None

    def _save_to_disk(self, key, entry):
        """Persist the entry as JSON in the disk tier."""
        file_path = self.disk_path / f"{key}.json"
        with open(file_path, 'w') as f:
            json.dump(entry, f)

缓存策略

1. TTL(Time To Live)策略

python
class TTLCache:
    """Dict-backed cache whose entries expire after a per-entry time-to-live."""

    def __init__(self, default_ttl=3600):
        self.cache = {}            # key -> {"value", "expires_at", "created_at"}
        self.default_ttl = default_ttl

    def get(self, key):
        """Return the live value for key; expired entries are removed lazily."""
        entry = self.cache.get(key)
        if entry is None:
            return None
        if time.time() < entry["expires_at"]:
            return entry["value"]
        del self.cache[key]  # lazy expiry on access
        return None

    def set(self, key, value, ttl=None):
        """Store value with the given TTL (falls back to the default)."""
        ttl = ttl or self.default_ttl
        now = time.time()
        self.cache[key] = {
            "value": value,
            "expires_at": now + ttl,
            "created_at": now,
        }

    def cleanup_expired(self):
        """Sweep and delete every expired entry in one pass."""
        now = time.time()
        for key in [k for k, v in self.cache.items() if v["expires_at"] < now]:
            del self.cache[key]

2. LRU(Least Recently Used)策略

python
from collections import OrderedDict

class LRUCache:
    """Least-recently-used cache built on OrderedDict.

    The OrderedDict keeps keys in recency order: most recently used at the
    end, so eviction pops from the front.
    """

    def __init__(self, capacity=1000):
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key):
        """Return the value and mark the key as most recently used; None on miss."""
        try:
            value = self.cache[key]
        except KeyError:
            return None
        self.cache.move_to_end(key)
        return value

    def set(self, key, value):
        """Insert or update a key, evicting the least recently used when full."""
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)  # drop the oldest entry

        self.cache[key] = value

    def get_stats(self):
        """Current fill level and configured capacity."""
        return {"size": len(self.cache), "capacity": self.capacity}

3. LFU(Least Frequently Used)策略

python
import heapq

class LFUCache:
    """Least-frequently-used cache using a heap with lazy stale-entry skipping.

    Fixes over the original:
    - ``get`` bumped the frequency but never pushed an updated heap entry, so
      a key's only heap record went stale; ``_evict`` would skip it forever,
      the heap could drain without evicting, and the cache could exceed its
      capacity. We now push a fresh ``(freq, key)`` record on every bump.
    - ``set`` on an *existing* key at capacity evicted an unrelated entry and
      reset the key's frequency; updates now keep LFU semantics.
    """

    def __init__(self, capacity=1000):
        self.capacity = capacity
        self.cache = {}       # key -> value
        self.frequency = {}   # key -> current access count (source of truth)
        self.heap = []        # (freq, key) records; stale ones skipped lazily

    def get(self, key):
        """Return the value and record one more access; None on miss."""
        if key not in self.cache:
            return None

        self.frequency[key] += 1
        # Lazy update: push the new count; the old record becomes stale.
        heapq.heappush(self.heap, (self.frequency[key], key))
        return self.cache[key]

    def set(self, key, value):
        """Insert or update a key, evicting the least-used entry when full."""
        if key in self.cache:
            self.cache[key] = value
            self.frequency[key] += 1
            heapq.heappush(self.heap, (self.frequency[key], key))
            return

        if len(self.cache) >= self.capacity:
            self._evict()

        self.cache[key] = value
        self.frequency[key] = 1
        heapq.heappush(self.heap, (1, key))

    def _evict(self):
        """Pop heap records until one matches the key's current frequency."""
        while self.heap:
            freq, key = heapq.heappop(self.heap)
            if self.frequency.get(key) == freq:
                del self.cache[key]
                del self.frequency[key]
                return

AI特定缓存优化

1. 提示词模板缓存

python
class PromptTemplateCache:
    def __init__(self):
        self.templates = {}
        self.compiled_templates = {}
    
    def register(self, name, template):
        self.templates[name] = template
        self.compiled_templates[name] = self._compile(template)
    
    def _compile(self, template):
        import re
        variables = re.findall(r'\{(\w+)\}', template)
        return {
            "template": template,
            "variables": variables
        }
    
    def render(self, name, **kwargs):
        if name not in self.templates:
            raise ValueError(f"Template {name} not found")
        
        return self.templates[name].format(**kwargs)
    
    def get_cache_key(self, name, **kwargs):
        template_info = self.compiled_templates[name]
        sorted_kwargs = sorted(kwargs.items())
        return f"{name}|{sorted_kwargs}"

2. 响应缓存装饰器

python
def cached_response(cache_instance, key_func=None):
    """Decorator factory that memoizes a function's results in ``cache_instance``.

    Args:
        cache_instance: any object with ``get(key)`` / ``set(key, value)``.
        key_func: optional callable mapping the call's (*args, **kwargs) to a
            cache key; defaults to ``str((args, kwargs))``.

    The original wrapper lost the decorated function's ``__name__`` and
    docstring; ``functools.wraps`` fixes that.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if key_func is not None:
                cache_key = key_func(*args, **kwargs)
            else:
                cache_key = str((args, kwargs))

            cached = cache_instance.get(cache_key)
            if cached is not None:
                return cached

            result = func(*args, **kwargs)
            cache_instance.set(cache_key, result)
            return result

        return wrapper
    return decorator

3. 部分响应缓存

python
class PartialResponseCache:
    """Stores streamed response chunks keyed by request id and chunk index."""

    def __init__(self):
        self.cache = {}  # request_id -> {chunk_index: content}

    def cache_chunk(self, request_id, chunk_index, content):
        """Record one chunk of a streaming response."""
        self.cache.setdefault(request_id, {})[chunk_index] = content

    def get_chunk(self, request_id, chunk_index):
        """Return a single chunk, or None when unknown."""
        chunks = self.cache.get(request_id, {})
        return chunks.get(chunk_index)

    def get_full_response(self, request_id):
        """Join all stored chunks in index order; None for an unknown request."""
        chunks = self.cache.get(request_id)
        if chunks is None:
            return None

        ordered_indices = sorted(chunks)
        return "".join(chunks[idx] for idx in ordered_indices)

缓存失效策略

1. 主动失效

python
class CacheInvalidator:
    """Tracks dependencies between cache keys and invalidates them in groups."""

    def __init__(self, cache):
        self.cache = cache          # backend exposing delete(key)
        self.dependencies = {}      # dependency -> [cache keys to drop]

    def register_dependency(self, cache_key, depends_on):
        """Declare that ``cache_key`` must be dropped when ``depends_on`` changes."""
        self.dependencies.setdefault(depends_on, []).append(cache_key)

    def invalidate(self, dependency):
        """Delete every cache key registered against this dependency."""
        if dependency not in self.dependencies:
            return  # nothing registered; no-op
        for cache_key in self.dependencies[dependency]:
            self.cache.delete(cache_key)
        del self.dependencies[dependency]

2. 版本控制失效

python
class VersionedCache:
    """Namespace-versioned cache: bumping a namespace's version orphans all
    of its existing entries, invalidating them without deleting anything."""

    def __init__(self):
        self.cache = {}     # "namespace:version:key" -> value
        self.versions = {}  # namespace -> current version (default 0)

    def set_version(self, namespace, version):
        """Pin a namespace to an explicit version number."""
        self.versions[namespace] = version

    def _full_key(self, key, namespace):
        """Compose the versioned storage key for (namespace, key)."""
        return f"{namespace}:{self.versions.get(namespace, 0)}:{key}"

    def get(self, key, namespace):
        """Return the value stored under the namespace's current version."""
        return self.cache.get(self._full_key(key, namespace))

    def set(self, key, value, namespace):
        """Store the value under the namespace's current version."""
        self.cache[self._full_key(key, namespace)] = value

    def invalidate_namespace(self, namespace):
        """Bump the version so every existing entry in the namespace is stale."""
        self.versions[namespace] = self.versions.get(namespace, 0) + 1

实用技巧

1. 缓存预热

python
def warmup_cache(common_queries, api_func, cache):
    """Pre-populate the cache with responses for frequently seen queries.

    Skips queries already cached; relies on a module-level
    ``generate_cache_key`` helper to derive keys.
    """
    for query in common_queries:
        key = generate_cache_key(query)
        if cache.get(key) is not None:
            continue  # already warm
        cache.set(key, api_func(query))

2. 缓存监控

python
class CacheMonitor:
    """Wraps a cache backend's ``get`` and tracks hit/miss statistics."""

    def __init__(self, cache):
        self.cache = cache
        self.hits = 0
        self.misses = 0

    def get(self, key):
        """Proxy the lookup, counting a hit for any non-None result."""
        value = self.cache.get(key)
        if value is None:
            self.misses += 1
        else:
            self.hits += 1
        return value

    def get_hit_rate(self):
        """Fraction of lookups that were hits; 0 before any lookup."""
        total = self.hits + self.misses
        if total == 0:
            return 0
        return self.hits / total

    def get_stats(self):
        """Snapshot of the current counters and hit rate."""
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": self.get_hit_rate(),
        }

3. 智能缓存决策

python
def should_cache(response, prompt):
    """Decide whether a response is worth caching.

    Skips error responses, short prompts (cheap to recompute; relies on a
    module-level ``count_tokens``), and time-sensitive prompts whose answers
    would go stale (``is_time_sensitive``, also defined elsewhere).
    """
    cacheable = (
        not response.get("error")
        and count_tokens(prompt) >= 50
        and not is_time_sensitive(prompt)
    )
    return cacheable

4. 缓存压缩

python
import gzip
import json

class CompressedCache:
    """Transparent gzip-over-JSON compression layer on top of a cache backend."""

    def __init__(self, cache_backend):
        self.backend = cache_backend  # exposes get(key) / set(key, bytes)

    def set(self, key, value):
        """Serialize the value to JSON, gzip it, and store the bytes."""
        payload = json.dumps(value).encode()
        self.backend.set(key, gzip.compress(payload))

    def get(self, key):
        """Fetch, decompress, and deserialize; None on a miss."""
        blob = self.backend.get(key)
        if not blob:
            return None
        return json.loads(gzip.decompress(blob).decode())

小结

有效的缓存策略需要考虑:

  1. 选择合适的缓存类型:精确匹配、语义相似、分层缓存
  2. 实施缓存策略:TTL、LRU、LFU等淘汰策略
  3. AI特定优化:模板缓存、响应缓存、部分缓存
  4. 缓存失效管理:主动失效、版本控制
  5. 监控与优化:命中率监控、智能决策、压缩存储

通过合理运用缓存策略,可以显著提升AI应用的响应速度,降低API调用成本,改善用户体验。建议根据具体场景选择合适的缓存方案,并持续监控和优化缓存效果。