Kitex 中间件机制
概述
中间件是 Kitex 的核心特性之一,用于在请求处理过程中插入通用逻辑,如日志记录、监控统计、认证鉴权等。本章将深入介绍 Kitex 中间件的原理、使用方式和最佳实践。
核心内容
中间件原理
Kitex 中间件采用洋葱模型:
请求 → Middleware1 → Middleware2 → Handler → Middleware2 → Middleware1 → 响应
↓ ↓ ↓ ↑ ↑
前置处理 前置处理 业务处理 后置处理 后置处理中间件签名
go
type Middleware func(next endpoint.Endpoint) endpoint.Endpoint
type Endpoint func(ctx context.Context, req, resp interface{}) error基本使用
服务端中间件
go
package main
import (
"context"
"time"
"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/server"
)
// 日志中间件
func LoggingMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
start := time.Now()
// 前置处理
klog.CtxInfof(ctx, "[Server] request: %v", req)
// 调用下一个中间件或处理器
err := next(ctx, req, resp)
// 后置处理
klog.CtxInfof(ctx, "[Server] response: %v, cost: %v, err: %v",
resp, time.Since(start), err)
return err
}
}
func main() {
svr := helloservice.NewServer(&HelloHandler{},
server.WithMiddleware(LoggingMiddleware),
)
svr.Run()
}客户端中间件
go
package main
import (
"github.com/cloudwego/kitex/client"
)
func main() {
cli := helloservice.MustNewClient("hello",
client.WithHostPorts("127.0.0.1:8888"),
client.WithMiddleware(LoggingMiddleware),
)
}常用中间件实现
1. 日志中间件
go
func LoggingMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
start := time.Now()
// 获取方法名
ri := rpcinfo.GetRPCInfo(ctx)
methodName := ri.To().Method()
klog.CtxInfof(ctx, "[%s] request: %+v", methodName, req)
err := next(ctx, req, resp)
cost := time.Since(start)
if err != nil {
klog.CtxErrorf(ctx, "[%s] error: %v, cost: %v", methodName, err, cost)
} else {
klog.CtxInfof(ctx, "[%s] success, cost: %v", methodName, cost)
}
return err
}
}2. 监控中间件
go
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
requestCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "rpc_requests_total",
Help: "Total number of RPC requests",
}, []string{"service", "method", "status"})
requestLatency = promauto.NewHistogramVec(prometheus.HistogramOpts{
Name: "rpc_request_latency_seconds",
Help: "RPC request latency in seconds",
Buckets: []float64{.001, .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10},
}, []string{"service", "method"})
)
func MetricsMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
start := time.Now()
ri := rpcinfo.GetRPCInfo(ctx)
service := ri.To().ServiceName()
method := ri.To().Method()
err := next(ctx, req, resp)
status := "success"
if err != nil {
status = "error"
}
requestCounter.WithLabelValues(service, method, status).Inc()
requestLatency.WithLabelValues(service, method).Observe(time.Since(start).Seconds())
return err
}
}3. 认证中间件
go
func AuthMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
// 获取认证信息
token := transmeta.GetMeta(ctx, "authorization")
if token == "" {
return errors.New("missing authorization token")
}
// 验证 token
userId, err := validateToken(token)
if err != nil {
return fmt.Errorf("invalid token: %w", err)
}
// 设置用户信息到上下文
ctx = context.WithValue(ctx, "user-id", userId)
return next(ctx, req, resp)
}
}4. 限流中间件
go
import "golang.org/x/time/rate"
func RateLimitMiddleware(limiter *rate.Limiter) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
if !limiter.Allow() {
return errors.New("rate limit exceeded")
}
return next(ctx, req, resp)
}
}
}
// 使用
func main() {
limiter := rate.NewLimiter(rate.Limit(1000), 100) // 1000 QPS,突发 100
svr := helloservice.NewServer(&HelloHandler{},
server.WithMiddleware(RateLimitMiddleware(limiter)),
)
}5. 恢复中间件
go
func RecoveryMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
klog.CtxErrorf(ctx, "panic recovered: %v\n%s", r, debug.Stack())
err = fmt.Errorf("internal error: %v", r)
}
}()
return next(ctx, req, resp)
}
}6. 缓存中间件
go
import "sync"
type CacheItem struct {
value interface{}
expiration time.Time
}
type Cache struct {
items map[string]*CacheItem
mu sync.RWMutex
}
func NewCache() *Cache {
return &Cache{
items: make(map[string]*CacheItem),
}
}
func (c *Cache) Get(key string) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
item, ok := c.items[key]
if !ok {
return nil, false
}
if time.Now().After(item.expiration) {
return nil, false
}
return item.value, true
}
func (c *Cache) Set(key string, value interface{}, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.items[key] = &CacheItem{
value: value,
expiration: time.Now().Add(ttl),
}
}
func CacheMiddleware(cache *Cache, ttl time.Duration) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
// 生成缓存键
key := generateCacheKey(ctx, req)
// 检查缓存
if cached, found := cache.Get(key); found {
// 复制缓存值到 resp
if err := copyResponse(resp, cached); err == nil {
klog.CtxInfof(ctx, "cache hit for key: %s", key)
return nil
}
}
// 调用下一个中间件
err := next(ctx, req, resp)
// 缓存响应
if err == nil {
cache.Set(key, resp, ttl)
klog.CtxInfof(ctx, "cache set for key: %s", key)
}
return err
}
}
}
func generateCacheKey(ctx context.Context, req interface{}) string {
ri := rpcinfo.GetRPCInfo(ctx)
method := ri.To().Method()
return fmt.Sprintf("%s:%v", method, req)
}
func copyResponse(dst, src interface{}) error {
// 实现响应复制逻辑
// 可以使用反射或序列化/反序列化
return nil
}7. 超时控制中间件
go
func TimeoutMiddleware(timeout time.Duration) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
// 创建带超时的上下文
ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// 通道用于接收结果
errCh := make(chan error, 1)
go func() {
errCh <- next(ctxWithTimeout, req, resp)
}()
select {
case err := <-errCh:
return err
case <-ctxWithTimeout.Done():
return fmt.Errorf("request timeout after %v", timeout)
}
}
}
}8. CORS 中间件
go
func CORSMiddleware() endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
// 获取 HTTP 上下文
if httpCtx, ok := httpx.GetHTTPContext(ctx); ok {
// 设置 CORS 头部
httpCtx.Response.Header.Set("Access-Control-Allow-Origin", "*")
httpCtx.Response.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
httpCtx.Response.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
// 处理 OPTIONS 请求
if httpCtx.Request.Method == "OPTIONS" {
httpCtx.Response.StatusCode = 204
return nil
}
}
return next(ctx, req, resp)
}
}
}9. 灰度发布中间件
go
func CanaryMiddleware(canaryRatio float64) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
// 检查是否应该路由到灰度版本
if shouldRouteToCanary(canaryRatio) {
klog.CtxInfof(ctx, "routing to canary version")
// 这里可以修改上下文或请求,路由到灰度服务
}
return next(ctx, req, resp)
}
}
}
func shouldRouteToCanary(ratio float64) bool {
return rand.Float64() < ratio
}10. 链路追踪中间件
go
import "go.opentelemetry.io/otel"
func TracingMiddleware() endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
ri := rpcinfo.GetRPCInfo(ctx)
method := ri.To().Method()
service := ri.To().ServiceName()
// 创建 span
tracer := otel.Tracer("kitex")
ctx, span := tracer.Start(ctx, method)
defer span.End()
// 设置 span 属性
span.SetAttributes(
attribute.String("service.name", service),
attribute.String("method.name", method),
)
// 调用下一个中间件
err := next(ctx, req, resp)
// 记录错误
if err != nil {
span.RecordError(err)
}
return err
}
}
}中间件顺序
中间件的执行顺序很重要:
go
func main() {
svr := helloservice.NewServer(&HelloHandler{},
// 执行顺序:Recovery → Metrics → Auth → Logging → Handler
server.WithMiddleware(RecoveryMiddleware), // 最外层,捕获 panic
server.WithMiddleware(MetricsMiddleware), // 监控统计
server.WithMiddleware(AuthMiddleware), // 认证鉴权
server.WithMiddleware(LoggingMiddleware), // 日志记录
)
}中间件分组
go
// 业务中间件组
func BusinessMiddlewareGroup() []endpoint.Middleware {
return []endpoint.Middleware{
RecoveryMiddleware,
LoggingMiddleware,
MetricsMiddleware,
}
}
func main() {
var middlewares []endpoint.Middleware
middlewares = append(middlewares, BusinessMiddlewareGroup()...)
svr := helloservice.NewServer(&HelloHandler{},
server.WithMiddlewares(middlewares),
)
}条件中间件
go
func ConditionalMiddleware(condition func(ctx context.Context) bool, mw endpoint.Middleware) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
if condition(ctx) {
return mw(next)(ctx, req, resp)
}
return next(ctx, req, resp)
}
}
}
// 使用示例:仅对特定方法启用认证
func main() {
svr := helloservice.NewServer(&HelloHandler{},
server.WithMiddleware(
ConditionalMiddleware(
func(ctx context.Context) bool {
ri := rpcinfo.GetRPCInfo(ctx)
return ri.To().Method() != "HealthCheck"
},
AuthMiddleware,
),
),
)
}中间件设计原则
1. 单一职责
- 每个中间件只负责一个功能
- 避免在一个中间件中实现多个功能
- 保持中间件逻辑简洁清晰
2. 可组合性
- 中间件应该可以任意组合
- 不依赖于其他中间件的存在
- 支持链式调用
3. 性能考虑
- 避免在中间件中执行耗时操作
- 减少内存分配和复制
- 合理使用缓存
- 避免不必要的网络调用
4. 错误处理
- 正确处理和传递错误
- 不要吞掉错误
- 提供清晰的错误信息
5. 上下文管理
- 合理使用上下文传递信息
- 避免在上下文中存储过多数据
- 清理不再需要的上下文数据
中间件性能优化
1. 内存优化
go
// 预分配内存
func MemoryOptimizedMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
// 预分配切片容量
buffer := make([]byte, 0, 1024)
// 使用 buffer 进行操作
// ...
return next(ctx, req, resp)
}
}2. 并发优化
go
// 使用 sync.Pool 复用对象
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 1024)
},
}
func ConcurrentOptimizedMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
buffer := bufferPool.Get().([]byte)
defer bufferPool.Put(buffer)
// 使用 buffer
// ...
return next(ctx, req, resp)
}
}3. 缓存优化
go
// 使用本地缓存
var localCache = sync.Map{}
func CacheOptimizedMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
key := generateKey(req)
if val, ok := localCache.Load(key); ok {
// 使用缓存值
return nil
}
// 调用下一个中间件
err := next(ctx, req, resp)
// 缓存结果
localCache.Store(key, resp)
return err
}
}中间件最佳实践
1. 中间件注册顺序
推荐顺序:
- Recovery:最外层,捕获 panic
- Metrics:监控统计
- Tracing:链路追踪
- CORS:跨域处理
- Auth:认证鉴权
- RateLimit:限流
- Cache:缓存
- Logging:日志记录
- Business Logic:业务逻辑
2. 中间件分组
go
// 通用中间件组
func CommonMiddlewareGroup() []endpoint.Middleware {
return []endpoint.Middleware{
RecoveryMiddleware,
MetricsMiddleware,
TracingMiddleware,
LoggingMiddleware,
}
}
// 安全中间件组
func SecurityMiddlewareGroup() []endpoint.Middleware {
return []endpoint.Middleware{
AuthMiddleware,
RateLimitMiddleware(rate.NewLimiter(1000, 100)),
CORSMiddleware(),
}
}
func main() {
middlewares := append(CommonMiddlewareGroup(), SecurityMiddlewareGroup()...)
svr := helloservice.NewServer(&HelloHandler{},
server.WithMiddlewares(middlewares),
)
}3. 条件中间件
go
// 根据环境启用不同中间件
func EnvironmentAwareMiddleware() endpoint.Middleware {
if os.Getenv("APP_ENV") == "production" {
return ProductionMiddleware()
}
return DevelopmentMiddleware()
}
// 根据方法启用中间件
func MethodSpecificMiddleware(methods []string, mw endpoint.Middleware) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
ri := rpcinfo.GetRPCInfo(ctx)
method := ri.To().Method()
for _, m := range methods {
if m == method {
return mw(next)(ctx, req, resp)
}
}
return next(ctx, req, resp)
}
}
}完整示例
go
package main
import (
"context"
"errors"
"fmt"
"runtime/debug"
"time"
"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/server"
"example/kitex_gen/hello/helloservice"
)
// 日志中间件
func LoggingMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
start := time.Now()
ri := rpcinfo.GetRPCInfo(ctx)
klog.CtxInfof(ctx, "[%s] request: %+v", ri.To().Method(), req)
err := next(ctx, req, resp)
klog.CtxInfof(ctx, "[%s] cost: %v, error: %v", ri.To().Method(), time.Since(start), err)
return err
}
}
// 恢复中间件
func RecoveryMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
klog.CtxErrorf(ctx, "panic recovered: %v\n%s", r, debug.Stack())
err = fmt.Errorf("internal error: %v", r)
}
}()
return next(ctx, req, resp)
}
}
// 监控中间件
func MetricsMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, req, resp interface{}) error {
start := time.Now()
ri := rpcinfo.GetRPCInfo(ctx)
err := next(ctx, req, resp)
// 记录 metrics
klog.CtxInfof(ctx, "[%s] latency: %v", ri.To().Method(), time.Since(start))
return err
}
}
type HelloHandler struct{}
func (h *HelloHandler) SayHello(ctx context.Context, req *hello.Request) (*hello.Response, error) {
return &hello.Response{Message: "Hello, " + req.Name + "!"}, nil
}
func main() {
// 中间件注册
svr := helloservice.NewServer(&HelloHandler{},
server.WithMiddleware(RecoveryMiddleware),
server.WithMiddleware(MetricsMiddleware),
server.WithMiddleware(LoggingMiddleware),
)
if err := svr.Run(); err != nil {
klog.Fatal(err)
}
}小结
本章介绍了 Kitex 中间件机制的完整内容:
中间件原理:
- 洋葱模型:请求 → 中间件链 → 处理器 → 中间件链 → 响应
- 中间件签名:
Middleware func(next endpoint.Endpoint) endpoint.Endpoint
基本使用:
- 服务端中间件:处理服务端请求
- 客户端中间件:处理客户端请求
常用中间件实现:
- 日志中间件:记录请求和响应信息
- 监控中间件:收集性能指标
- 认证中间件:验证请求身份
- 限流中间件:控制请求速率
- 恢复中间件:捕获和处理 panic
- 缓存中间件:缓存响应结果
- 超时控制中间件:控制请求超时
- CORS 中间件:处理跨域请求
- 灰度发布中间件:实现灰度发布
- 链路追踪中间件:跟踪请求链路
中间件设计原则:
- 单一职责:每个中间件只负责一个功能
- 可组合性:支持任意组合和链式调用
- 性能考虑:避免耗时操作,优化内存使用
- 错误处理:正确传递错误,提供清晰信息
- 上下文管理:合理使用和清理上下文数据
中间件性能优化:
- 内存优化:预分配内存,减少分配
- 并发优化:使用 sync.Pool 复用对象
- 缓存优化:使用本地缓存提高性能
中间件最佳实践:
- 注册顺序:从外到内,先通用后具体
- 中间件分组:按功能分组管理
- 条件中间件:根据环境或方法启用
高级用法:
- 中间件分组:组织和管理多个中间件
- 条件中间件:根据条件选择性启用
- 环境感知中间件:根据环境配置不同行为
通过本章的学习,你应该掌握了如何设计、实现和使用 Kitex 中间件,以及如何通过中间件扩展和增强服务的功能。在下一章中,我们将学习 Kitex 流式处理,了解如何处理流式请求和响应。