Skip to content

错误处理与重试策略

概述

在生产环境中,API调用不可避免会遇到各种错误。如何优雅地处理错误、实现智能重试、设计降级方案,是构建稳定AI应用的关键。本章将详细介绍错误处理与重试策略,帮助你构建健壮的AI系统。

常见错误类型

错误分类

python
from enum import Enum

class ErrorType(Enum):
    """Normalized failure categories, independent of any vendor SDK."""
    RATE_LIMIT = "rate_limit"              # throttled (mapped from HTTP 429 below)
    AUTHENTICATION = "authentication"      # bad/expired credentials (401/403)
    INVALID_REQUEST = "invalid_request"    # malformed request (400/404)
    MODEL_OVERLOAD = "model_overload"      # upstream capacity issues (500/502/503)
    TIMEOUT = "timeout"                    # request deadline exceeded (504)
    NETWORK = "network"                    # connection-level failure
    CONTENT_FILTER = "content_filter"      # blocked by content policy (not mapped by the visible handlers)
    INSUFFICIENT_QUOTA = "insufficient_quota"  # account quota exhausted (not mapped by the visible handlers)
    UNKNOWN = "unknown"                    # anything unclassified

from typing import Optional

class APIError(Exception):
    """Unified application-level API error.

    Wraps any vendor-specific failure with a normalized ErrorType so retry
    and fallback logic does not depend on SDK exception classes.
    """

    def __init__(self, error_type: "ErrorType", message: str,
                 retry_after: Optional[int] = None):
        # error_type: normalized category (drives is_retryable()).
        # retry_after: server-suggested wait in seconds, when known.
        self.error_type = error_type
        self.message = message
        self.retry_after = retry_after
        super().__init__(message)

    def is_retryable(self) -> bool:
        """Return True for transient failures that a retry may fix."""
        # Set membership; built lazily so the class has no import-time
        # dependency on where ErrorType is defined.
        return self.error_type in {
            ErrorType.RATE_LIMIT,
            ErrorType.MODEL_OVERLOAD,
            ErrorType.TIMEOUT,
            ErrorType.NETWORK,
        }

OpenAI错误处理

python
# NOTE: the SDK's own `APIError` must be aliased here — importing it under
# its original name would shadow the custom `APIError` class defined above,
# so every `APIError(ErrorType..., ...)` call below would construct the
# wrong type. The alias keeps the SDK class available without the collision.
from openai import (
    APIError as OpenAIAPIError,
    APIConnectionError,
    RateLimitError,
    AuthenticationError,
    BadRequestError,
    APITimeoutError
)

def handle_openai_error(error: Exception) -> APIError:
    """Normalize an OpenAI SDK exception into the application's APIError."""
    if isinstance(error, RateLimitError):
        # Honor the server-provided backoff hint when present.
        wait_seconds = getattr(error, 'retry_after', 60)
        return APIError(
            ErrorType.RATE_LIMIT,
            f"速率限制,{wait_seconds}秒后重试",
            wait_seconds
        )

    if isinstance(error, AuthenticationError):
        return APIError(ErrorType.AUTHENTICATION, "API密钥无效或已过期")

    if isinstance(error, BadRequestError):
        return APIError(ErrorType.INVALID_REQUEST, f"请求参数错误: {str(error)}")

    # Timeout is checked before the broader connection error on purpose.
    if isinstance(error, APITimeoutError):
        return APIError(ErrorType.TIMEOUT, "请求超时")

    if isinstance(error, APIConnectionError):
        return APIError(ErrorType.NETWORK, "网络连接失败")

    return APIError(ErrorType.UNKNOWN, f"未知错误: {str(error)}")

def safe_chat(messages: list):
    """Make one chat call; on failure, log the normalized error and re-raise it."""
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages
        )
        return response.choices[0].message.content

    except Exception as exc:
        normalized = handle_openai_error(exc)
        print(f"错误类型: {normalized.error_type.value}")
        print(f"错误信息: {normalized.message}")
        print(f"是否可重试: {normalized.is_retryable()}")
        raise normalized

错误码映射

python
# HTTP status code -> normalized ErrorType.
# 429 maps to RATE_LIMIT (retryable per APIError.is_retryable); 5xx are
# treated as upstream overload, except 504 which is classified as TIMEOUT.
ERROR_CODE_MAPPING = {
    400: ErrorType.INVALID_REQUEST,
    401: ErrorType.AUTHENTICATION,
    403: ErrorType.AUTHENTICATION,
    404: ErrorType.INVALID_REQUEST,
    429: ErrorType.RATE_LIMIT,
    500: ErrorType.MODEL_OVERLOAD,
    502: ErrorType.MODEL_OVERLOAD,
    503: ErrorType.MODEL_OVERLOAD,
    504: ErrorType.TIMEOUT
}

def map_status_code(status_code: int) -> ErrorType:
    """Return the ErrorType for an HTTP status code; UNKNOWN if unmapped."""
    return ERROR_CODE_MAPPING.get(status_code, ErrorType.UNKNOWN)

重试策略

基础重试

python
import time
from typing import Callable, Any

def basic_retry(
    func: Callable,
    max_retries: int = 3,
    delay: float = 1.0
) -> Any:
    """Call *func*, retrying with a fixed delay between attempts.

    Args:
        func: zero-argument callable to execute.
        max_retries: total number of attempts (must be >= 1).
        delay: seconds to sleep between attempts.

    Returns:
        Whatever *func* returns on the first successful attempt.

    Raises:
        ValueError: if max_retries < 1.
        Exception: the last error raised by *func* when every attempt fails.
    """
    if max_retries < 1:
        # Guard: with no attempts the original loop fell through to
        # `raise None`, producing a confusing TypeError.
        raise ValueError("max_retries must be >= 1")

    last_error = None

    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            last_error = e
            print(f"第{attempt + 1}次尝试失败: {e}")

            # Sleep only between attempts, not after the final failure.
            if attempt < max_retries - 1:
                time.sleep(delay)

    raise last_error

# Example: wrap one chat-completion call in basic_retry (3 attempts, fixed
# 2 s pause). `client` is assumed to be an OpenAI client initialized earlier.
result = basic_retry(
    lambda: client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "你好"}]
    ),
    max_retries=3,
    delay=2.0
)

指数退避重试

python
import random

def exponential_backoff_retry(
    func: Callable,
    max_retries: int = 5,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    jitter: bool = True
) -> Any:
    """Call *func* with exponential backoff (base_delay * 2**attempt, capped).

    Args:
        func: zero-argument callable to execute.
        max_retries: total number of attempts (must be >= 1).
        base_delay: initial backoff in seconds.
        max_delay: upper bound for a single backoff.
        jitter: add up to 10% random noise to each delay to avoid
            synchronized retry storms across clients.

    Raises:
        ValueError: if max_retries < 1.
        Exception: the last error from *func* when all attempts fail.
    """
    if max_retries < 1:
        # Without this guard an empty loop would end in `raise None`.
        raise ValueError("max_retries must be >= 1")

    last_error = None

    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            last_error = e

            if attempt == max_retries - 1:
                break  # no sleep after the final attempt

            delay = min(base_delay * (2 ** attempt), max_delay)

            if jitter:
                delay = delay + random.uniform(0, delay * 0.1)

            print(f"第{attempt + 1}次失败,{delay:.2f}秒后重试...")
            time.sleep(delay)

    raise last_error

自适应重试

python
class AdaptiveRetry:
    """Retry helper whose backoff grows with *consecutive* failures.

    Unlike plain per-call exponential backoff, the failure counter survives
    across calls to execute(), so a struggling service is probed less
    aggressively over time.
    """

    def __init__(
        self,
        max_retries: int = 5,
        base_delay: float = 1.0,
        max_delay: float = 60.0
    ):
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.consecutive_failures = 0       # reset on every success
        self.last_success_time = time.time()

    def calculate_delay(self, error: Exception) -> float:
        """Pick the wait time before the next attempt."""
        # Prefer the server's own retry hint for rate-limit errors.
        if isinstance(error, RateLimitError):
            hint = getattr(error, 'retry_after', None)
            if hint:
                return hint

        capped = min(
            self.base_delay * (2 ** self.consecutive_failures),
            self.max_delay
        )
        # Up to 10% random noise to de-synchronize concurrent clients.
        return capped + random.uniform(0, capped * 0.1)

    def execute(self, func: Callable) -> Any:
        """Run *func* with adaptive retries; re-raise non-retryable errors."""
        last_error = None

        for attempt in range(self.max_retries):
            try:
                outcome = func()
            except Exception as e:
                last_error = e
                self.consecutive_failures += 1

                # Classify first: permanent failures surface immediately.
                if not handle_openai_error(e).is_retryable():
                    raise e

                if attempt == self.max_retries - 1:
                    break

                delay = self.calculate_delay(e)
                print(f"第{attempt + 1}次失败,等待{delay:.2f}秒...")
                time.sleep(delay)
            else:
                self.consecutive_failures = 0
                self.last_success_time = time.time()
                return outcome

        raise last_error

# Module-level retry helper shared by robust_chat; note that its
# consecutive-failure counter persists across calls by design.
adaptive_retry = AdaptiveRetry()

def robust_chat(messages: list):
    # Runs one chat completion under the adaptive retry policy and returns
    # the raw API response object (not just the message text).
    return adaptive_retry.execute(
        lambda: client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages
        )
    )

条件重试

python
def conditional_retry(
    func: Callable,
    should_retry: Callable[[Exception], bool],
    max_retries: int = 3,
    delay: float = 1.0
) -> Any:
    """Retry *func* only when *should_retry(error)* approves the error.

    Args:
        func: zero-argument callable to execute.
        should_retry: predicate deciding whether a given exception is worth
            retrying; rejected errors are re-raised immediately.
        max_retries: total number of attempts (must be >= 1).
        delay: fixed sleep between attempts, in seconds.

    Raises:
        ValueError: if max_retries < 1.
        Exception: the offending error (immediately when the predicate
            rejects it, otherwise after the final failed attempt).
    """
    if max_retries < 1:
        # Guard against the degenerate `raise None` fall-through.
        raise ValueError("max_retries must be >= 1")

    last_error = None

    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            last_error = e

            if not should_retry(e):
                print(f"错误不可重试: {e}")
                raise e

            if attempt < max_retries - 1:
                print(f"第{attempt + 1}次失败,准备重试...")
                time.sleep(delay)

    raise last_error

def is_retryable_error(error: Exception) -> bool:
    """True when the normalized form of *error* is worth retrying."""
    return handle_openai_error(error).is_retryable()

# Example: only transient errors (rate limit / overload / timeout / network)
# trigger a retry; anything else is raised on the first failure.
result = conditional_retry(
    lambda: client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "测试"}]
    ),
    should_retry=is_retryable_error,
    max_retries=5
)

断路器模式

基础断路器

python
from enum import Enum
from datetime import datetime, timedelta

class CircuitState(Enum):
    """States of the circuit breaker's finite-state machine."""
    CLOSED = "closed"        # normal operation, requests flow through
    OPEN = "open"            # failing fast, requests are rejected
    HALF_OPEN = "half_open"  # probing: a limited number of trial requests

class CircuitBreaker:
    """Circuit breaker guarding an unreliable dependency.

    CLOSED: all calls pass; consecutive failures are counted.
    OPEN: calls are rejected until `recovery_timeout` seconds elapse.
    HALF_OPEN: probe calls are allowed; a single success closes the
    circuit again, repeated failures re-open it.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        half_open_max_calls: int = 3
    ):
        # failure_threshold: consecutive failures that trip CLOSED -> OPEN.
        # recovery_timeout: seconds to wait in OPEN before probing.
        # half_open_max_calls: failure budget while HALF_OPEN.
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls

        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.last_failure_time = None   # datetime of the most recent failure
        self.half_open_calls = 0        # NOTE: incremented on *failures* only

    def can_execute(self) -> bool:
        """Return True if a call may proceed; may transition OPEN -> HALF_OPEN."""
        if self.state == CircuitState.CLOSED:
            return True

        if self.state == CircuitState.OPEN:
            if self._should_attempt_recovery():
                self.state = CircuitState.HALF_OPEN
                self.half_open_calls = 0
                return True
            return False

        if self.state == CircuitState.HALF_OPEN:
            # NOTE(review): half_open_calls is only bumped in record_failure(),
            # so successful half-open probes are unlimited — confirm intended.
            return self.half_open_calls < self.half_open_max_calls

        return False

    def _should_attempt_recovery(self) -> bool:
        # No recorded failure yet means there is nothing to wait out.
        if self.last_failure_time is None:
            return True

        elapsed = datetime.now() - self.last_failure_time
        return elapsed > timedelta(seconds=self.recovery_timeout)

    def record_success(self):
        # A single successful probe is enough to close the circuit.
        if self.state == CircuitState.HALF_OPEN:
            self.state = CircuitState.CLOSED

        self.failure_count = 0

    def record_failure(self):
        self.failure_count += 1
        self.last_failure_time = datetime.now()

        if self.state == CircuitState.HALF_OPEN:
            # Re-open only after exhausting the half-open failure budget.
            self.half_open_calls += 1
            if self.half_open_calls >= self.half_open_max_calls:
                self.state = CircuitState.OPEN

        elif self.failure_count >= self.failure_threshold:
            self.state = CircuitState.OPEN

    def execute(self, func: Callable) -> Any:
        """Run *func* through the breaker, updating state from the outcome.

        Raises a plain Exception when the circuit is open; callers that need
        to distinguish this from *func*'s own errors must wrap the call.
        """
        if not self.can_execute():
            raise Exception("断路器处于开启状态,拒绝请求")

        try:
            result = func()
            self.record_success()
            return result
        except Exception as e:
            self.record_failure()
            raise e

# Shared breaker: trips after 5 consecutive failures and stays open for
# 60 seconds before allowing probe requests.
circuit_breaker = CircuitBreaker(
    failure_threshold=5,
    recovery_timeout=60
)

def protected_chat(messages: list):
    # Routes one chat call through the circuit breaker; returns the raw
    # response object. Raises when the breaker rejects the call.
    return circuit_breaker.execute(
        lambda: client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages
        )
    )

带监控的断路器

python
from dataclasses import dataclass
from typing import List
import threading

@dataclass
class CircuitStats:
    """Counters describing traffic seen by a MonitoredCircuitBreaker."""
    total_requests: int = 0       # every call to execute()
    successful_requests: int = 0  # calls that returned normally
    failed_requests: int = 0      # calls whose func raised
    rejected_requests: int = 0    # calls refused while the circuit was open

class MonitoredCircuitBreaker:
    """CircuitBreaker wrapper that keeps thread-safe request statistics.

    The lock guards only the stats counters; the underlying breaker's own
    state transitions are not synchronized here.
    """

    def __init__(self, **kwargs):
        # kwargs are forwarded verbatim to CircuitBreaker
        # (failure_threshold, recovery_timeout, half_open_max_calls).
        self.breaker = CircuitBreaker(**kwargs)
        self.stats = CircuitStats()
        self.lock = threading.Lock()

    def execute(self, func: Callable) -> Any:
        """Run *func* through the breaker while counting the outcome.

        NOTE(review): can_execute() is consulted here *and* again inside
        breaker.execute(); the first call may already flip OPEN -> HALF_OPEN.
        Confirm the double-check is intended.
        """
        with self.lock:
            self.stats.total_requests += 1

        if not self.breaker.can_execute():
            with self.lock:
                self.stats.rejected_requests += 1
            raise Exception("断路器开启,请求被拒绝")

        try:
            result = self.breaker.execute(func)
            with self.lock:
                self.stats.successful_requests += 1
            return result
        except Exception as e:
            # Counts both func failures and rejections raised by the inner
            # breaker.execute().
            with self.lock:
                self.stats.failed_requests += 1
            raise e

    def get_stats(self) -> dict:
        """Snapshot of the current state and counters (success_rate in 0..1)."""
        with self.lock:
            return {
                "state": self.breaker.state.value,
                "total_requests": self.stats.total_requests,
                "successful_requests": self.stats.successful_requests,
                "failed_requests": self.stats.failed_requests,
                "rejected_requests": self.stats.rejected_requests,
                "success_rate": (
                    self.stats.successful_requests / self.stats.total_requests
                    if self.stats.total_requests > 0 else 0
                )
            }

# Shared monitored breaker used by the examples in this section.
monitored_breaker = MonitoredCircuitBreaker(
    failure_threshold=5,
    recovery_timeout=60
)

降级策略

多模型降级

python
class ModelFallback:
    """Try a list of models in rotation until one succeeds.

    The index of the last model that worked is remembered, so later calls
    start from the known-good model instead of the top of the list.
    """

    def __init__(self, models: list):
        self.models = models
        self.current_index = 0  # index of the most recently successful model

    def execute(self, messages: list, **kwargs):
        """Return the first successful completion text; raise if all fail."""
        errors = []
        total = len(self.models)

        for offset in range(total):
            idx = (self.current_index + offset) % total
            model = self.models[idx]

            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    **kwargs
                )
            except Exception as e:
                errors.append(f"{model}: {str(e)}")
                print(f"模型 {model} 失败: {e}")
                continue

            # Stick with this model for subsequent calls.
            self.current_index = idx
            return response.choices[0].message.content

        raise Exception(f"所有模型都失败:\n" + "\n".join(errors))

# Fallback chain ordered from most to least capable model.
fallback = ModelFallback([
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-3.5-turbo"
])

# Example call: tries gpt-4o first, then falls through the list on failure.
result = fallback.execute([{"role": "user", "content": "你好"}])

缓存降级

python
class CacheFallback:
    """Serve the last successful answer from cache when the API is down.

    Entries are keyed by the stringified message list and never evicted,
    so the cache grows for as long as the instance lives.
    """

    def __init__(self):
        self.cache = {}

    def execute(self, messages: list, use_cache: bool = True):
        """Call the API, caching successes; fall back to cache on failure."""
        cache_key = str(messages)

        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages
            )
            answer = response.choices[0].message.content
            self.cache[cache_key] = answer
            return answer

        except Exception as exc:
            print(f"API调用失败: {exc}")

            if use_cache and cache_key in self.cache:
                print("使用缓存响应")
                return self.cache[cache_key]

            raise exc

# Shared cache-backed fallback instance for the examples below.
cache_fallback = CacheFallback()

默认响应降级

python
class DefaultResponseFallback:
    """Last-resort fallback returning canned responses when the API fails."""

    def __init__(self):
        # Canned replies keyed by a coarse intent label.
        self.default_responses = {
            "greeting": "你好!我是AI助手,很高兴为你服务。",
            "error": "抱歉,系统暂时不可用,请稍后再试。",
            "unknown": "我暂时无法回答这个问题,请换个方式提问。"
        }

    def classify_intent(self, message: str) -> str:
        """Coarse keyword-based intent detection for offline use only.

        English greetings are matched as whole words so that words merely
        containing them (e.g. "this" and "which" contain "hi") are not
        misclassified — the previous substring test had that bug. Chinese
        "你好" stays a substring match since Chinese has no word boundaries.
        """
        import re  # local import keeps the snippet self-contained

        text = message.lower()

        if "你好" in text or re.search(r'\b(?:hi|hello)\b', text):
            return "greeting"

        return "unknown"

    def execute(self, messages: list):
        """Try the API; on any failure answer with a canned response."""
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages
            )
            return response.choices[0].message.content

        except Exception as e:
            print(f"API失败,使用默认响应: {e}")

            user_message = messages[-1]["content"] if messages else ""
            intent = self.classify_intent(user_message)

            return self.default_responses.get(intent, self.default_responses["error"])

# Shared default-response fallback instance.
default_fallback = DefaultResponseFallback()

综合降级策略

python
class ComprehensiveFallback:
    """Layered degradation: live models -> cached answers -> canned replies."""

    def __init__(self):
        self.model_fallback = ModelFallback([
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo"
        ])
        self.cache_fallback = CacheFallback()
        self.default_fallback = DefaultResponseFallback()
        self.circuit_breaker = MonitoredCircuitBreaker()

    def execute(self, messages: list):
        """Walk the degradation ladder; the canned reply can never fail."""
        layers = (
            (lambda: self.circuit_breaker.execute(
                lambda: self.model_fallback.execute(messages)
            ), "模型降级失败"),
            (lambda: self.cache_fallback.execute(messages), "缓存降级失败"),
        )

        for attempt, label in layers:
            try:
                return attempt()
            except Exception as e:
                print(f"{label}: {e}")

        return self.default_fallback.execute(messages)

# Module-level instance plus a thin convenience wrapper.
comprehensive_fallback = ComprehensiveFallback()

def robust_request(messages: list):
    # Single entry point the rest of the app can call safely.
    return comprehensive_fallback.execute(messages)

错误监控与告警

错误统计

python
from collections import defaultdict
from datetime import datetime, timedelta

class ErrorMonitor:
    """Sliding-window error tracker (default: the last 60 minutes).

    Entries older than the window are pruned on every record_error() call;
    the read-side methods report whatever is currently stored.
    """

    def __init__(self, window_minutes: int = 60):
        self.window_minutes = window_minutes
        self.errors = defaultdict(list)  # error type -> list of entries
        self.lock = threading.Lock()

    def record_error(self, error_type: ErrorType, error_message: str):
        """Store one error occurrence and drop entries outside the window."""
        with self.lock:
            self.errors[error_type].append({
                "timestamp": datetime.now(),
                "message": error_message
            })
            self._cleanup_old_errors()

    def _cleanup_old_errors(self):
        # Caller must hold self.lock.
        cutoff = datetime.now() - timedelta(minutes=self.window_minutes)

        for kind in self.errors:
            self.errors[kind] = [
                entry for entry in self.errors[kind]
                if entry["timestamp"] > cutoff
            ]

    def get_error_stats(self) -> dict:
        """Per-type counts plus the five most recent messages of each type."""
        with self.lock:
            return {
                kind.value: {
                    "count": len(entries),
                    "recent_errors": [entry["message"] for entry in entries[-5:]]
                }
                for kind, entries in self.errors.items()
            }

    def get_error_rate(self) -> float:
        """Average errors per second over the configured window."""
        with self.lock:
            total = sum(len(entries) for entries in self.errors.values())
            return total / (self.window_minutes * 60)

# Shared monitor instance for the examples below.
error_monitor = ErrorMonitor()

def monitored_request(messages: list):
    # One chat call; failures are normalized, recorded in the monitor,
    # and re-raised unchanged for the caller to handle.
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages
        )
        return response.choices[0].message.content
    
    except Exception as e:
        api_error = handle_openai_error(e)
        error_monitor.record_error(api_error.error_type, api_error.message)
        raise

告警机制

python
from typing import Callable

class AlertManager:
    """Rule-driven alerting: predicates are evaluated against a stats dict
    and every matching rule's message is delivered to all handlers."""

    def __init__(self):
        self.alert_rules = []     # each: {"condition": fn(stats)->bool, "message": str}
        self.alert_handlers = []  # each: fn(message) -> None

    def add_rule(
        self,
        condition: Callable[[dict], bool],
        message: str
    ):
        """Register a predicate and the alert text to emit when it fires."""
        self.alert_rules.append({"condition": condition, "message": message})

    def add_handler(self, handler: Callable[[str], None]):
        """Register a delivery channel (logging, email, ...)."""
        self.alert_handlers.append(handler)

    def check_alerts(self, stats: dict):
        """Evaluate every rule against *stats*; fan out matches to handlers."""
        for rule in self.alert_rules:
            if not rule["condition"](stats):
                continue
            for handler in self.alert_handlers:
                handler(rule["message"])

    def log_alert(self, message: str):
        # Simple stdout channel.
        print(f"[ALERT] {datetime.now()}: {message}")

    def send_email_alert(self, message: str):
        # Placeholder: a real implementation would send mail here.
        print(f"[EMAIL] 发送告警邮件: {message}")
# Wire up the default stdout handler plus two threshold rules; the stats
# dicts come from ErrorMonitor.get_error_stats() above.
alert_manager = AlertManager()
alert_manager.add_handler(alert_manager.log_alert)

# More than 10 rate-limit errors inside the monitor window -> alert.
alert_manager.add_rule(
    lambda stats: stats.get("rate_limit", {}).get("count", 0) > 10,
    "速率限制错误过多,请检查API使用情况"
)

# More than 5 network errors -> alert.
alert_manager.add_rule(
    lambda stats: stats.get("network", {}).get("count", 0) > 5,
    "网络错误频繁,请检查网络连接"
)

日志记录

结构化日志

python
import json
from datetime import datetime

from typing import Optional

class StructuredLogger:
    """Append-only JSON-lines logger that also echoes entries to stdout."""

    def __init__(self, log_file: str = "api_errors.log"):
        self.log_file = log_file

    def log(
        self,
        level: str,
        message: str,
        error_type: Optional[str] = None,
        retry_count: Optional[int] = None,
        **kwargs
    ):
        """Write one structured entry.

        Args:
            level: severity label such as "INFO" / "WARNING" / "ERROR".
            message: human-readable description.
            error_type: normalized error category, when known.
            retry_count: attempt number, when relevant.
            **kwargs: extra fields merged into the JSON record.
        """
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "level": level,
            "message": message,
            "error_type": error_type,
            "retry_count": retry_count,
            **kwargs
        }

        # ensure_ascii=False keeps Chinese text readable in the log file.
        log_line = json.dumps(log_entry, ensure_ascii=False)

        print(log_line)

        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(log_line + '\n')

# Module-level logger writing to api_errors.log in the working directory.
logger = StructuredLogger()

def logged_request(messages: list):
    # Chat call with up to 3 attempts; every outcome (success or failure)
    # is emitted as a structured log entry. Back-off doubles per retry
    # (2s, 4s, 8s, since the sleep runs after the counter increments).
    # NOTE(review): retries regardless of error type — compare with
    # conditional_retry above, which skips non-retryable errors.
    retry_count = 0
    
    while retry_count < 3:
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages
            )
            
            logger.log(
                "INFO",
                "请求成功",
                retry_count=retry_count
            )
            
            return response.choices[0].message.content
        
        except Exception as e:
            api_error = handle_openai_error(e)
            
            logger.log(
                "ERROR",
                api_error.message,
                error_type=api_error.error_type.value,
                retry_count=retry_count
            )
            
            retry_count += 1
            time.sleep(2 ** retry_count)
    
    raise Exception("重试次数耗尽")

最佳实践总结

完整的错误处理示例

python
class RobustAPIClient:
    """All-in-one client: retry + circuit breaker + monitoring + fallback.

    Per request: circuit breaker -> bounded exponential retry -> (on total
    failure) structured log entry, then the comprehensive fallback chain.
    """

    def __init__(self):
        self.client = OpenAI()
        self.circuit_breaker = MonitoredCircuitBreaker(
            failure_threshold=5,
            recovery_timeout=60
        )
        self.error_monitor = ErrorMonitor()
        self.logger = StructuredLogger()
        self.fallback = ComprehensiveFallback()
    
    def chat(self, messages: list, **kwargs):
        # Public entry point; degrades to the fallback chain instead of
        # raising once retries (or the breaker) give up.
        try:
            return self.circuit_breaker.execute(
                lambda: self._execute_with_retry(messages, **kwargs)
            )
        except Exception as e:
            self.logger.log(
                "ERROR",
                f"所有重试失败: {str(e)}",
                error_type="critical"
            )
            # Degradation ladder: models -> cache -> canned reply.
            return self.fallback.execute(messages)
    
    def _execute_with_retry(self, messages: list, **kwargs):
        # Up to 3 attempts with 1s/2s exponential back-off; every failure is
        # recorded in the monitor, non-retryable errors abort immediately.
        max_retries = 3
        last_error = None
        
        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages,
                    **kwargs
                )
                
                return response.choices[0].message.content
            
            except Exception as e:
                last_error = e
                api_error = handle_openai_error(e)
                
                # Record before deciding whether to retry, so even aborted
                # requests show up in the error statistics.
                self.error_monitor.record_error(
                    api_error.error_type,
                    api_error.message
                )
                
                if not api_error.is_retryable():
                    raise e
                
                if attempt < max_retries - 1:
                    delay = 2 ** attempt
                    self.logger.log(
                        "WARNING",
                        f"请求失败,{delay}秒后重试",
                        error_type=api_error.error_type.value,
                        retry_count=attempt
                    )
                    time.sleep(delay)
        
        raise last_error

# Example: a single call through the fully protected client.
robust_client = RobustAPIClient()

response = robust_client.chat([
    {"role": "user", "content": "你好"}
])

小结

本章详细介绍了错误处理与重试策略:

  1. 错误分类 - 识别不同类型的错误,判断是否可重试
  2. 重试策略 - 基础重试、指数退避、自适应重试
  3. 断路器模式 - 防止级联故障,保护系统稳定性
  4. 降级策略 - 多模型降级、缓存降级、默认响应降级
  5. 监控告警 - 错误统计、告警机制
  6. 日志记录 - 结构化日志,便于排查问题

通过合理的错误处理和重试策略,可以显著提高AI应用的稳定性和可用性。结合前面的API调用最佳实践,你已经掌握了构建生产级AI应用的核心技能。

参考资源