Skip to content

Kitex 中间件机制

概述

中间件是 Kitex 的核心特性之一,用于在请求处理过程中插入通用逻辑,如日志记录、监控统计、认证鉴权等。本章将深入介绍 Kitex 中间件的原理、使用方式和最佳实践。

核心内容

中间件原理

Kitex 中间件采用洋葱模型:

请求 → Middleware1 → Middleware2 → Handler → Middleware2 → Middleware1 → 响应
         ↓            ↓            ↓            ↑            ↑
       前置处理      前置处理      业务处理     后置处理     后置处理

中间件签名

go
type Middleware func(next endpoint.Endpoint) endpoint.Endpoint

type Endpoint func(ctx context.Context, req, resp interface{}) error

基本使用

服务端中间件

go
package main

import (
    "context"
    "time"
    
    "github.com/cloudwego/kitex/pkg/klog"
    "github.com/cloudwego/kitex/server"
)

// 日志中间件
func LoggingMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        start := time.Now()
        
        // 前置处理
        klog.CtxInfof(ctx, "[Server] request: %v", req)
        
        // 调用下一个中间件或处理器
        err := next(ctx, req, resp)
        
        // 后置处理
        klog.CtxInfof(ctx, "[Server] response: %v, cost: %v, err: %v",
            resp, time.Since(start), err)
        
        return err
    }
}

func main() {
    svr := helloservice.NewServer(&HelloHandler{},
        server.WithMiddleware(LoggingMiddleware),
    )
    svr.Run()
}

客户端中间件

go
package main

import (
    "github.com/cloudwego/kitex/client"
)

func main() {
    cli := helloservice.MustNewClient("hello",
        client.WithHostPorts("127.0.0.1:8888"),
        client.WithMiddleware(LoggingMiddleware),
    )
}

常用中间件实现

1. 日志中间件

go
func LoggingMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        start := time.Now()
        
        // 获取方法名
        ri := rpcinfo.GetRPCInfo(ctx)
        methodName := ri.To().Method()
        
        klog.CtxInfof(ctx, "[%s] request: %+v", methodName, req)
        
        err := next(ctx, req, resp)
        
        cost := time.Since(start)
        if err != nil {
            klog.CtxErrorf(ctx, "[%s] error: %v, cost: %v", methodName, err, cost)
        } else {
            klog.CtxInfof(ctx, "[%s] success, cost: %v", methodName, cost)
        }
        
        return err
    }
}

2. 监控中间件

go
import (
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

var (
    requestCounter = promauto.NewCounterVec(prometheus.CounterOpts{
        Name: "rpc_requests_total",
        Help: "Total number of RPC requests",
    }, []string{"service", "method", "status"})
    
    requestLatency = promauto.NewHistogramVec(prometheus.HistogramOpts{
        Name:    "rpc_request_latency_seconds",
        Help:    "RPC request latency in seconds",
        Buckets: []float64{.001, .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10},
    }, []string{"service", "method"})
)

func MetricsMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        start := time.Now()
        
        ri := rpcinfo.GetRPCInfo(ctx)
        service := ri.To().ServiceName()
        method := ri.To().Method()
        
        err := next(ctx, req, resp)
        
        status := "success"
        if err != nil {
            status = "error"
        }
        
        requestCounter.WithLabelValues(service, method, status).Inc()
        requestLatency.WithLabelValues(service, method).Observe(time.Since(start).Seconds())
        
        return err
    }
}

3. 认证中间件

go
func AuthMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        // 获取认证信息
        token := transmeta.GetMeta(ctx, "authorization")
        if token == "" {
            return errors.New("missing authorization token")
        }
        
        // 验证 token
        userId, err := validateToken(token)
        if err != nil {
            return fmt.Errorf("invalid token: %w", err)
        }
        
        // 设置用户信息到上下文
        ctx = context.WithValue(ctx, "user-id", userId)
        
        return next(ctx, req, resp)
    }
}

4. 限流中间件

go
import "golang.org/x/time/rate"

func RateLimitMiddleware(limiter *rate.Limiter) endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, req, resp interface{}) error {
            if !limiter.Allow() {
                return errors.New("rate limit exceeded")
            }
            return next(ctx, req, resp)
        }
    }
}

// 使用
func main() {
    limiter := rate.NewLimiter(rate.Limit(1000), 100) // 1000 QPS,突发 100
    
    svr := helloservice.NewServer(&HelloHandler{},
        server.WithMiddleware(RateLimitMiddleware(limiter)),
    )
}

5. 恢复中间件

go
func RecoveryMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) (err error) {
        defer func() {
            if r := recover(); r != nil {
                klog.CtxErrorf(ctx, "panic recovered: %v\n%s", r, debug.Stack())
                err = fmt.Errorf("internal error: %v", r)
            }
        }()
        return next(ctx, req, resp)
    }
}

6. 缓存中间件

go
import "sync"

type CacheItem struct {
    value      interface{}
    expiration time.Time
}

type Cache struct {
    items map[string]*CacheItem
    mu    sync.RWMutex
}

func NewCache() *Cache {
    return &Cache{
        items: make(map[string]*CacheItem),
    }
}

func (c *Cache) Get(key string) (interface{}, bool) {
    c.mu.RLock()
    defer c.mu.RUnlock()
    
    item, ok := c.items[key]
    if !ok {
        return nil, false
    }
    
    if time.Now().After(item.expiration) {
        return nil, false
    }
    
    return item.value, true
}

func (c *Cache) Set(key string, value interface{}, ttl time.Duration) {
    c.mu.Lock()
    defer c.mu.Unlock()
    
    c.items[key] = &CacheItem{
        value:      value,
        expiration: time.Now().Add(ttl),
    }
}

func CacheMiddleware(cache *Cache, ttl time.Duration) endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, req, resp interface{}) error {
            // 生成缓存键
            key := generateCacheKey(ctx, req)
            
            // 检查缓存
            if cached, found := cache.Get(key); found {
                // 复制缓存值到 resp
                if err := copyResponse(resp, cached); err == nil {
                    klog.CtxInfof(ctx, "cache hit for key: %s", key)
                    return nil
                }
            }
            
            // 调用下一个中间件
            err := next(ctx, req, resp)
            
            // 缓存响应
            if err == nil {
                cache.Set(key, resp, ttl)
                klog.CtxInfof(ctx, "cache set for key: %s", key)
            }
            
            return err
        }
    }
}

func generateCacheKey(ctx context.Context, req interface{}) string {
    ri := rpcinfo.GetRPCInfo(ctx)
    method := ri.To().Method()
    return fmt.Sprintf("%s:%v", method, req)
}

func copyResponse(dst, src interface{}) error {
    // 实现响应复制逻辑
    // 可以使用反射或序列化/反序列化
    return nil
}

7. 超时控制中间件

go
func TimeoutMiddleware(timeout time.Duration) endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, req, resp interface{}) error {
            // 创建带超时的上下文
            ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
            defer cancel()
            
            // 通道用于接收结果
            errCh := make(chan error, 1)
            
            go func() {
                errCh <- next(ctxWithTimeout, req, resp)
            }()
            
            select {
            case err := <-errCh:
                return err
            case <-ctxWithTimeout.Done():
                return fmt.Errorf("request timeout after %v", timeout)
            }
        }
    }
}

8. CORS 中间件

go
func CORSMiddleware() endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, req, resp interface{}) error {
            // 获取 HTTP 上下文
            if httpCtx, ok := httpx.GetHTTPContext(ctx); ok {
                // 设置 CORS 头部
                httpCtx.Response.Header.Set("Access-Control-Allow-Origin", "*")
                httpCtx.Response.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
                httpCtx.Response.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
                
                // 处理 OPTIONS 请求
                if httpCtx.Request.Method == "OPTIONS" {
                    httpCtx.Response.StatusCode = 204
                    return nil
                }
            }
            
            return next(ctx, req, resp)
        }
    }
}

9. 灰度发布中间件

go
func CanaryMiddleware(canaryRatio float64) endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, req, resp interface{}) error {
            // 检查是否应该路由到灰度版本
            if shouldRouteToCanary(canaryRatio) {
                klog.CtxInfof(ctx, "routing to canary version")
                // 这里可以修改上下文或请求,路由到灰度服务
            }
            
            return next(ctx, req, resp)
        }
    }
}

func shouldRouteToCanary(ratio float64) bool {
    return rand.Float64() < ratio
}

10. 链路追踪中间件

go
import "go.opentelemetry.io/otel"

func TracingMiddleware() endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, req, resp interface{}) error {
            ri := rpcinfo.GetRPCInfo(ctx)
            method := ri.To().Method()
            service := ri.To().ServiceName()
            
            // 创建 span
            tracer := otel.Tracer("kitex")
            ctx, span := tracer.Start(ctx, method)
            defer span.End()
            
            // 设置 span 属性
            span.SetAttributes(
                attribute.String("service.name", service),
                attribute.String("method.name", method),
            )
            
            // 调用下一个中间件
            err := next(ctx, req, resp)
            
            // 记录错误
            if err != nil {
                span.RecordError(err)
            }
            
            return err
        }
    }
}

中间件顺序

中间件的执行顺序很重要:

go
func main() {
    svr := helloservice.NewServer(&HelloHandler{},
        // 执行顺序:Recovery → Metrics → Auth → Logging → Handler
        server.WithMiddleware(RecoveryMiddleware),   // 最外层,捕获 panic
        server.WithMiddleware(MetricsMiddleware),    // 监控统计
        server.WithMiddleware(AuthMiddleware),       // 认证鉴权
        server.WithMiddleware(LoggingMiddleware),    // 日志记录
    )
}

中间件分组

go
// 业务中间件组
func BusinessMiddlewareGroup() []endpoint.Middleware {
    return []endpoint.Middleware{
        RecoveryMiddleware,
        LoggingMiddleware,
        MetricsMiddleware,
    }
}

func main() {
    var middlewares []endpoint.Middleware
    middlewares = append(middlewares, BusinessMiddlewareGroup()...)
    
    svr := helloservice.NewServer(&HelloHandler{},
        server.WithMiddlewares(middlewares),
    )
}

条件中间件

go
func ConditionalMiddleware(condition func(ctx context.Context) bool, mw endpoint.Middleware) endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, req, resp interface{}) error {
            if condition(ctx) {
                return mw(next)(ctx, req, resp)
            }
            return next(ctx, req, resp)
        }
    }
}

// 使用示例:仅对特定方法启用认证
func main() {
    svr := helloservice.NewServer(&HelloHandler{},
        server.WithMiddleware(
            ConditionalMiddleware(
                func(ctx context.Context) bool {
                    ri := rpcinfo.GetRPCInfo(ctx)
                    return ri.To().Method() != "HealthCheck"
                },
                AuthMiddleware,
            ),
        ),
    )
}

中间件设计原则

1. 单一职责

  • 每个中间件只负责一个功能
  • 避免在一个中间件中实现多个功能
  • 保持中间件逻辑简洁清晰

2. 可组合性

  • 中间件应该可以任意组合
  • 不依赖于其他中间件的存在
  • 支持链式调用

3. 性能考虑

  • 避免在中间件中执行耗时操作
  • 减少内存分配和复制
  • 合理使用缓存
  • 避免不必要的网络调用

4. 错误处理

  • 正确处理和传递错误
  • 不要吞掉错误
  • 提供清晰的错误信息

5. 上下文管理

  • 合理使用上下文传递信息
  • 避免在上下文中存储过多数据
  • 清理不再需要的上下文数据

中间件性能优化

1. 内存优化

go
// 预分配内存
func MemoryOptimizedMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        // 预分配切片容量
        buffer := make([]byte, 0, 1024)
        // 使用 buffer 进行操作
        // ...
        return next(ctx, req, resp)
    }
}

2. 并发优化

go
// 使用 sync.Pool 复用对象
var bufferPool = sync.Pool{
    New: func() interface{} {
        return make([]byte, 1024)
    },
}

func ConcurrentOptimizedMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        buffer := bufferPool.Get().([]byte)
        defer bufferPool.Put(buffer)
        // 使用 buffer
        // ...
        return next(ctx, req, resp)
    }
}

3. 缓存优化

go
// 使用本地缓存
var localCache = sync.Map{}

func CacheOptimizedMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        key := generateKey(req)
        if val, ok := localCache.Load(key); ok {
            // 使用缓存值
            return nil
        }
        // 调用下一个中间件
        err := next(ctx, req, resp)
        // 缓存结果
        localCache.Store(key, resp)
        return err
    }
}

中间件最佳实践

1. 中间件注册顺序

推荐顺序

  1. Recovery:最外层,捕获 panic
  2. Metrics:监控统计
  3. Tracing:链路追踪
  4. CORS:跨域处理
  5. Auth:认证鉴权
  6. RateLimit:限流
  7. Cache:缓存
  8. Logging:日志记录
  9. Business Logic:业务逻辑

2. 中间件分组

go
// 通用中间件组
func CommonMiddlewareGroup() []endpoint.Middleware {
    return []endpoint.Middleware{
        RecoveryMiddleware,
        MetricsMiddleware,
        TracingMiddleware,
        LoggingMiddleware,
    }
}

// 安全中间件组
func SecurityMiddlewareGroup() []endpoint.Middleware {
    return []endpoint.Middleware{
        AuthMiddleware,
        RateLimitMiddleware(rate.NewLimiter(1000, 100)),
        CORSMiddleware(),
    }
}

func main() {
    middlewares := append(CommonMiddlewareGroup(), SecurityMiddlewareGroup()...)
    
    svr := helloservice.NewServer(&HelloHandler{},
        server.WithMiddlewares(middlewares),
    )
}

3. 条件中间件

go
// 根据环境启用不同中间件
func EnvironmentAwareMiddleware() endpoint.Middleware {
    if os.Getenv("APP_ENV") == "production" {
        return ProductionMiddleware()
    }
    return DevelopmentMiddleware()
}

// 根据方法启用中间件
func MethodSpecificMiddleware(methods []string, mw endpoint.Middleware) endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, req, resp interface{}) error {
            ri := rpcinfo.GetRPCInfo(ctx)
            method := ri.To().Method()
            
            for _, m := range methods {
                if m == method {
                    return mw(next)(ctx, req, resp)
                }
            }
            return next(ctx, req, resp)
        }
    }
}

完整示例

go
package main

import (
    "context"
    "errors"
    "fmt"
    "runtime/debug"
    "time"
    
    "github.com/cloudwego/kitex/pkg/klog"
    "github.com/cloudwego/kitex/pkg/rpcinfo"
    "github.com/cloudwego/kitex/server"
    "example/kitex_gen/hello/helloservice"
)

// 日志中间件
func LoggingMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        start := time.Now()
        ri := rpcinfo.GetRPCInfo(ctx)
        
        klog.CtxInfof(ctx, "[%s] request: %+v", ri.To().Method(), req)
        
        err := next(ctx, req, resp)
        
        klog.CtxInfof(ctx, "[%s] cost: %v, error: %v", ri.To().Method(), time.Since(start), err)
        
        return err
    }
}

// 恢复中间件
func RecoveryMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) (err error) {
        defer func() {
            if r := recover(); r != nil {
                klog.CtxErrorf(ctx, "panic recovered: %v\n%s", r, debug.Stack())
                err = fmt.Errorf("internal error: %v", r)
            }
        }()
        return next(ctx, req, resp)
    }
}

// 监控中间件
func MetricsMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, resp interface{}) error {
        start := time.Now()
        ri := rpcinfo.GetRPCInfo(ctx)
        
        err := next(ctx, req, resp)
        
        // 记录 metrics
        klog.CtxInfof(ctx, "[%s] latency: %v", ri.To().Method(), time.Since(start))
        
        return err
    }
}

type HelloHandler struct{}

func (h *HelloHandler) SayHello(ctx context.Context, req *hello.Request) (*hello.Response, error) {
    return &hello.Response{Message: "Hello, " + req.Name + "!"}, nil
}

func main() {
    // 中间件注册
    svr := helloservice.NewServer(&HelloHandler{},
        server.WithMiddleware(RecoveryMiddleware),
        server.WithMiddleware(MetricsMiddleware),
        server.WithMiddleware(LoggingMiddleware),
    )
    
    if err := svr.Run(); err != nil {
        klog.Fatal(err)
    }
}

小结

本章介绍了 Kitex 中间件机制的完整内容:

  1. 中间件原理

    • 洋葱模型:请求 → 中间件链 → 处理器 → 中间件链 → 响应
    • 中间件签名:Middleware func(next endpoint.Endpoint) endpoint.Endpoint
  2. 基本使用

    • 服务端中间件:处理服务端请求
    • 客户端中间件:处理客户端请求
  3. 常用中间件实现

    • 日志中间件:记录请求和响应信息
    • 监控中间件:收集性能指标
    • 认证中间件:验证请求身份
    • 限流中间件:控制请求速率
    • 恢复中间件:捕获和处理 panic
    • 缓存中间件:缓存响应结果
    • 超时控制中间件:控制请求超时
    • CORS 中间件:处理跨域请求
    • 灰度发布中间件:实现灰度发布
    • 链路追踪中间件:跟踪请求链路
  4. 中间件设计原则

    • 单一职责:每个中间件只负责一个功能
    • 可组合性:支持任意组合和链式调用
    • 性能考虑:避免耗时操作,优化内存使用
    • 错误处理:正确传递错误,提供清晰信息
    • 上下文管理:合理使用和清理上下文数据
  5. 中间件性能优化

    • 内存优化:预分配内存,减少分配
    • 并发优化:使用 sync.Pool 复用对象
    • 缓存优化:使用本地缓存提高性能
  6. 中间件最佳实践

    • 注册顺序:从外到内,先通用后具体
    • 中间件分组:按功能分组管理
    • 条件中间件:根据环境或方法启用
  7. 高级用法

    • 中间件分组:组织和管理多个中间件
    • 条件中间件:根据条件选择性启用
    • 环境感知中间件:根据环境配置不同行为

通过本章的学习,你应该掌握了如何设计、实现和使用 Kitex 中间件,以及如何通过中间件扩展和增强服务的功能。在下一章中,我们将学习 Kitex 流式处理,了解如何处理流式请求和响应。