Skip to content

路由与中间件

概述

本章将深入介绍 Hertz 的路由系统和中间件机制。Hertz 提供了灵活的路由功能，支持多种路由模式；中间件机制则允许在请求处理前后执行通用逻辑，用于实现认证、日志、限流等横切关注点。

核心内容

路由基础

基本路由

go
package main

import (
    "context"
    
    "github.com/cloudwego/hertz/pkg/app"
    "github.com/cloudwego/hertz/pkg/app/server"
)

// main registers one handler for each HTTP verb on the /users resource and
// then starts the server.
func main() {
    srv := server.Default()

    // Read / create on the collection.
    srv.GET("/users", listUsers)
    srv.POST("/users", createUser)
    // Modify / remove a single user identified by :id.
    srv.PUT("/users/:id", updateUser)
    srv.DELETE("/users/:id", deleteUser)
    srv.PATCH("/users/:id", patchUser)
    // Metadata-only verbs on the collection.
    srv.HEAD("/users", headUsers)
    srv.OPTIONS("/users", optionsUsers)

    srv.Spin()
}

// listUsers responds with status 200 and a fixed JSON list of user names.
func listUsers(c context.Context, ctx *app.RequestContext) {
    names := []string{"Alice", "Bob"}
    payload := map[string]interface{}{"users": names}
    ctx.JSON(200, payload)
}

// createUser responds with status 201 and a fixed placeholder user record.
func createUser(c context.Context, ctx *app.RequestContext) {
    created := map[string]interface{}{}
    created["id"] = 1
    created["name"] = "New User"
    ctx.JSON(201, created)
}

路由参数

go
// main demonstrates the parameter styles supported by the Hertz router:
// named parameters (:name) and catch-all parameters (*name).
//
// Hertz route patterns have no regular-expression syntax. A pattern such as
// "/users/:id:^\\d+$" does not restrict :id to digits, and it collides with
// the already-registered "/users/:id". To accept only numeric ids, capture
// the segment with a normal named parameter and validate it in the handler.
func main() {
    h := server.Default()

    // Named parameter: matches exactly one path segment, e.g. /users/42.
    h.GET("/users/:id", func(c context.Context, ctx *app.RequestContext) {
        id := ctx.Param("id")
        ctx.String(200, "User ID: %s", id)
    })

    // Catch-all parameter: captures the remainder of the path,
    // e.g. /files/docs/a.txt.
    h.GET("/files/*filepath", func(c context.Context, ctx *app.RequestContext) {
        filepath := ctx.Param("filepath")
        ctx.String(200, "File: %s", filepath)
    })

    // Constrained parameter: the router cannot enforce "digits only",
    // so the handler rejects anything else with 400.
    h.GET("/numeric-users/:id", func(c context.Context, ctx *app.RequestContext) {
        id := ctx.Param("id")
        if !isDigits(id) {
            ctx.String(400, "invalid user id: %s", id)
            return
        }
        ctx.String(200, "User ID (numeric): %s", id)
    })

    h.Spin()
}

// isDigits reports whether s is non-empty and consists only of ASCII digits.
func isDigits(s string) bool {
    if s == "" {
        return false
    }
    for i := 0; i < len(s); i++ {
        if s[i] < '0' || s[i] > '9' {
            return false
        }
    }
    return true
}

路由组

go
// main shows three ways to group routes: a plain prefix group, versioned
// groups, and a group whose routes all run AuthMiddleware before the handler.
func main() {
    engine := server.Default()

    // Plain prefix group: /api/status.
    apiGroup := engine.Group("/api")
    apiGroup.GET("/status", statusHandler)

    // Versioned groups: /api/v1/... and /api/v2/...
    v1Group := engine.Group("/api/v1")
    v1Group.GET("/users", listUsers)
    v1Group.POST("/users", createUser)
    v1Group.GET("/users/:id", getUser)

    v2Group := engine.Group("/api/v2")
    v2Group.GET("/users", listUsersV2)
    v2Group.POST("/users", createUserV2)

    // Same /api prefix, but AuthMiddleware is attached to the group, so it
    // runs in front of every handler registered on it.
    authGroup := engine.Group("/api", AuthMiddleware())
    authGroup.GET("/profile", getProfile)
    authGroup.PUT("/profile", updateProfile)

    engine.Spin()
}

高级路由

Any 方法

go
h.Any("/api", func(c context.Context, ctx *app.RequestContext) {
    ctx.JSON(200, map[string]interface{}{
        "method": string(ctx.Method()),
        "path":   string(ctx.URI().Path()),
    })
})

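未匹配路由处理（NoRoute / NoMethod）

注意：NoMethod 处理函数只有在创建服务器时启用 server.WithHandleMethodNotAllowed(true) 后才会生效；否则，路径存在但方法不匹配的请求会按 404 交由 NoRoute 处理。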

go
// NoRoute handles requests whose path matches no registered route and
// answers with a JSON 404 body.
h.NoRoute(func(c context.Context, ctx *app.RequestContext) {
    ctx.JSON(404, map[string]interface{}{
        "code":    404,
        "message": "Not Found",
    })
})

// NoMethod handles requests whose path exists, but not for the request's
// HTTP method, and answers with a JSON 405 body.
//
// NOTE: Hertz only consults this handler when the server is created with
// method-not-allowed handling enabled, e.g.
//
//     h := server.Default(server.WithHandleMethodNotAllowed(true))
//
// Without that option such requests are answered as 404 via NoRoute.
h.NoMethod(func(c context.Context, ctx *app.RequestContext) {
    ctx.JSON(405, map[string]interface{}{
        "code":    405,
        "message": "Method Not Allowed",
    })
})

路由重定向

go
h.GET("/old-path", func(c context.Context, ctx *app.RequestContext) {
    ctx.Redirect(301, []byte("/new-path"))
})

h.GET("/new-path", func(c context.Context, ctx *app.RequestContext) {
    ctx.String(200, "New Path")
})

中间件基础

创建中间件

go
// Logger returns a middleware that, after the rest of the handler chain has
// run, logs the request method, path, final status code and latency.
//
// Logging goes through Hertz's own hlog package, so the snippet needs no
// dependency outside Hertz.
func Logger() app.HandlerFunc {
    return func(c context.Context, ctx *app.RequestContext) {
        start := time.Now()
        // Snapshot the request path before downstream handlers run.
        path := string(ctx.URI().Path())

        ctx.Next(c)

        latency := time.Since(start)
        hlog.Infof("[%s] %s %d %v",
            ctx.Method(),
            path,
            ctx.Response.StatusCode(),
            latency,
        )
    }
}

// main installs Logger as global middleware and serves a single /hello route.
func main() {
    srv := server.Default()
    srv.Use(Logger())

    hello := func(c context.Context, ctx *app.RequestContext) {
        ctx.String(200, "Hello!")
    }
    srv.GET("/hello", hello)

    srv.Spin()
}

中间件执行顺序

go
// Middleware1 prints a marker before handing off to the rest of the chain
// and another once the chain has returned.
func Middleware1() app.HandlerFunc {
    const name = "Middleware1"
    return func(c context.Context, ctx *app.RequestContext) {
        fmt.Println(name + " - Before")
        ctx.Next(c)
        fmt.Println(name + " - After")
    }
}

// Middleware2 prints a marker before handing off to the rest of the chain
// and another once the chain has returned.
func Middleware2() app.HandlerFunc {
    const name = "Middleware2"
    return func(c context.Context, ctx *app.RequestContext) {
        fmt.Println(name + " - Before")
        ctx.Next(c)
        fmt.Println(name + " - After")
    }
}

// main registers Middleware1 and Middleware2 globally (in that order) plus a
// /test handler, to demonstrate the nested execution order of ctx.Next.
func main() {
    srv := server.Default()

    // Middleware runs in registration order on the way in.
    srv.Use(Middleware1(), Middleware2())

    srv.GET("/test", func(c context.Context, ctx *app.RequestContext) {
        fmt.Println("Handler")
        ctx.String(200, "OK")
    })

    srv.Spin()
}

// Printed for each GET /test request:
// Middleware1 - Before
// Middleware2 - Before
// Handler
// Middleware2 - After
// Middleware1 - After

常用中间件

认证中间件

go
// AuthMiddleware returns a middleware that requires an Authorization header.
// A missing header or one rejected by validateToken aborts the request with
// a 401 JSON body. Otherwise the resolved user ID is stored under "userID"
// and the chain continues.
func AuthMiddleware() app.HandlerFunc {
    // reject aborts the chain with a 401 JSON body carrying msg.
    reject := func(ctx *app.RequestContext, msg string) {
        ctx.AbortWithStatusJSON(401, map[string]interface{}{
            "code":    401,
            "message": msg,
        })
    }

    return func(c context.Context, ctx *app.RequestContext) {
        token := string(ctx.GetHeader("Authorization"))
        if token == "" {
            reject(ctx, "Unauthorized")
            return
        }

        userID, err := validateToken(token)
        if err != nil {
            reject(ctx, "Invalid token")
            return
        }

        ctx.Set("userID", userID)
        ctx.Next(c)
    }
}

// validateToken is a placeholder verifier. It accepts any token and always
// resolves it to the fixed user ID "user123" with a nil error. Replace it
// with real token verification.
func validateToken(token string) (string, error) {
    const placeholderUserID = "user123"
    return placeholderUserID, nil
}

// main protects every /api route with AuthMiddleware and exposes /api/profile,
// which echoes back the user ID the middleware resolved.
func main() {
    srv := server.Default()

    protected := srv.Group("/api")
    protected.Use(AuthMiddleware())

    protected.GET("/profile", func(c context.Context, ctx *app.RequestContext) {
        // AuthMiddleware stored the caller's ID under "userID".
        ctx.JSON(200, map[string]interface{}{
            "user_id": ctx.GetString("userID"),
        })
    })

    srv.Spin()
}

跨域中间件

go
// CORSMiddleware returns a middleware that adds permissive CORS headers to
// every request it runs for and answers preflight (OPTIONS) requests directly
// with 204 No Content.
//
// PATCH is included in the allowed methods because this chapter registers
// PATCH routes (e.g. PATCH /users/:id). Without it, browsers block
// cross-origin PATCH requests at the preflight stage.
func CORSMiddleware() app.HandlerFunc {
    return func(c context.Context, ctx *app.RequestContext) {
        // "*" admits any origin. Browsers refuse a wildcard origin for
        // credentialed (cookie) requests; list explicit origins if you need those.
        ctx.Header("Access-Control-Allow-Origin", "*")
        ctx.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
        ctx.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
        // Ask browsers to cache the preflight result for 86400 s
        // (browsers may cap this lower).
        ctx.Header("Access-Control-Max-Age", "86400")

        // A preflight only needs the headers above; don't run the handler.
        if string(ctx.Method()) == "OPTIONS" {
            ctx.AbortWithStatus(204)
            return
        }

        ctx.Next(c)
    }
}

// main installs CORSMiddleware as global middleware and starts the server.
func main() {
    srv := server.Default()

    srv.Use(CORSMiddleware())

    srv.Spin()
}

限流中间件

go
import (
    "sync"
    "time"
)

// RateLimiter is a sliding-window rate limiter keyed by an arbitrary string,
// such as a client IP. Each key is admitted at most limit times within any
// window-long interval. Allow serializes access with mu, so one RateLimiter
// can be shared by concurrent requests.
type RateLimiter struct {
    mu       sync.Mutex
    requests map[string][]time.Time // admitted-request timestamps per key, in admission order
    limit    int
    window   time.Duration

    // lastSweep records when idle keys were last purged from requests.
    lastSweep time.Time
}

// NewRateLimiter returns a RateLimiter that admits at most limit requests per
// key within any window-long interval. A limit <= 0 rejects every request.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
    return &RateLimiter{
        requests: make(map[string][]time.Time),
        limit:    limit,
        window:   window,
    }
}

// Allow reports whether a request for key ip may proceed and, if it may,
// records it.
//
// Besides trimming ip's own history, Allow purges keys with no request
// inside the window at most once per window. Without that sweep, every key
// ever seen would keep a map entry forever and memory would grow without
// bound. The sweep is O(number of keys) but amortized over a full window.
func (rl *RateLimiter) Allow(ip string) bool {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    now := time.Now()
    windowStart := now.Add(-rl.window)

    if now.Sub(rl.lastSweep) >= rl.window {
        rl.sweepIdle(windowStart)
        rl.lastSweep = now
    }

    // Drop expired timestamps in place, reusing the existing backing array.
    history := rl.requests[ip]
    kept := history[:0]
    for _, t := range history {
        if t.After(windowStart) {
            kept = append(kept, t)
        }
    }

    if len(kept) >= rl.limit {
        if len(kept) == 0 {
            // Only possible when limit <= 0: nothing worth remembering.
            delete(rl.requests, ip)
        } else {
            rl.requests[ip] = kept
        }
        return false
    }

    rl.requests[ip] = append(kept, now)
    return true
}

// sweepIdle deletes every key with no timestamp after windowStart.
// The caller must hold rl.mu.
func (rl *RateLimiter) sweepIdle(windowStart time.Time) {
    for key, history := range rl.requests {
        active := false
        // Scan newest-first: timestamps are appended in admission order,
        // so a recent one is usually found immediately.
        for i := len(history) - 1; i >= 0; i-- {
            if history[i].After(windowStart) {
                active = true
                break
            }
        }
        if !active {
            delete(rl.requests, key)
        }
    }
}

// RateLimitMiddleware returns a middleware that keys limiter by client IP.
// Once that IP has exhausted its quota, it aborts with a 429 JSON body;
// otherwise it continues the chain.
func RateLimitMiddleware(limiter *RateLimiter) app.HandlerFunc {
    return func(c context.Context, ctx *app.RequestContext) {
        // ClientIP already returns a string, so it is used as the key directly.
        // NOTE(review): depending on the server's ClientIP options, the IP may
        // come from client-controlled headers (X-Forwarded-For / X-Real-IP).
        // Without a trusted proxy in front, clients could rotate that header
        // to evade the limit. Confirm the deployment configuration.
        ip := ctx.ClientIP()

        if !limiter.Allow(ip) {
            ctx.AbortWithStatusJSON(429, map[string]interface{}{
                "code":    429,
                "message": "Too Many Requests",
            })
            return
        }

        ctx.Next(c)
    }
}

// main applies one shared per-IP limiter (100 requests per minute) globally.
func main() {
    srv := server.Default()

    const (
        maxRequests = 100
        perWindow   = time.Minute
    )
    srv.Use(RateLimitMiddleware(NewRateLimiter(maxRequests, perWindow)))

    srv.Spin()
}

恢复中间件

go
// RecoveryMiddleware returns a middleware that recovers from a panic raised
// later in the handler chain. It logs the panic value and stack trace, then
// aborts the request with a 500 JSON body.
//
// Logging goes through Hertz's own hlog package, so the snippet needs no
// dependency outside Hertz.
func RecoveryMiddleware() app.HandlerFunc {
    return func(c context.Context, ctx *app.RequestContext) {
        defer func() {
            if err := recover(); err != nil {
                hlog.Errorf("Panic recovered: %v\n%s", err, debug.Stack())

                ctx.AbortWithStatusJSON(500, map[string]interface{}{
                    "code":    500,
                    "message": "Internal Server Error",
                })
            }
        }()

        ctx.Next(c)
    }
}

// main installs RecoveryMiddleware and registers /panic, whose handler always
// panics, so the recovery path can be exercised.
func main() {
    srv := server.Default()
    srv.Use(RecoveryMiddleware())

    alwaysPanics := func(c context.Context, ctx *app.RequestContext) {
        panic("something went wrong")
    }
    srv.GET("/panic", alwaysPanics)

    srv.Spin()
}

请求 ID 中间件

go
import (
    "github.com/google/uuid"
)

// RequestIDMiddleware ensures every request has an ID. It reuses the incoming
// X-Request-ID header when present and otherwise generates a UUID. The ID is
// stored under "request_id" and echoed back in the X-Request-ID response
// header.
func RequestIDMiddleware() app.HandlerFunc {
    const header = "X-Request-ID"
    return func(c context.Context, ctx *app.RequestContext) {
        id := string(ctx.GetHeader(header))
        if len(id) == 0 {
            id = uuid.New().String()
        }

        ctx.Set("request_id", id)
        ctx.Header(header, id)

        ctx.Next(c)
    }
}

// main installs RequestIDMiddleware and exposes /test, which returns the
// request ID the middleware assigned.
func main() {
    srv := server.Default()
    srv.Use(RequestIDMiddleware())

    srv.GET("/test", func(c context.Context, ctx *app.RequestContext) {
        body := map[string]string{
            "request_id": ctx.GetString("request_id"),
        }
        ctx.JSON(200, body)
    })

    srv.Spin()
}

中间件作用域

全局中间件

go
// main registers logging and recovery as global middleware; both apply to
// every route registered afterwards.
func main() {
    srv := server.Default()
    srv.Use(LoggerMiddleware(), RecoveryMiddleware())

    srv.GET("/public", publicHandler)
    srv.GET("/private", privateHandler)

    srv.Spin()
}

路由组中间件

go
// main scopes AuthMiddleware to the /api group only; /public routes do not
// run it.
func main() {
    srv := server.Default()

    // /api/*: every handler is preceded by AuthMiddleware.
    guarded := srv.Group("/api")
    guarded.Use(AuthMiddleware())
    guarded.GET("/users", listUsers)
    guarded.POST("/users", createUser)

    // /public/*: no extra middleware.
    publicGroup := srv.Group("/public")
    publicGroup.GET("/status", statusHandler)

    srv.Spin()
}

单路由中间件

go
// main attaches middleware to individual routes by listing it before the
// final handler. The handlers run left to right.
func main() {
    srv := server.Default()

    srv.GET("/public", publicHandler)                                  // no middleware
    srv.GET("/private", AuthMiddleware(), privateHandler)              // auth only
    srv.GET("/admin", AuthMiddleware(), AdminMiddleware(), adminHandler) // auth, then admin check

    srv.Spin()
}

完整示例

go
package main

import (
    "context"
    "fmt"
    "log"
    "runtime/debug"
    "time"
    
    "github.com/cloudwego/hertz/pkg/app"
    "github.com/cloudwego/hertz/pkg/app/server"
    "github.com/cloudwego/hertz/pkg/common/hlog"
)

// LoggerMiddleware returns a middleware that runs the rest of the chain and
// then logs "[METHOD] PATH STATUS LATENCY" through hlog.
func LoggerMiddleware() app.HandlerFunc {
    return func(c context.Context, ctx *app.RequestContext) {
        begin := time.Now()
        ctx.Next(c)

        method := ctx.Method()
        path := ctx.URI().Path()
        status := ctx.Response.StatusCode()
        hlog.Infof("[%s] %s %d %v", method, path, status, time.Since(begin))
    }
}

// RecoveryMiddleware returns a middleware that turns a panic in a downstream
// handler into a logged error (value plus stack trace) and a 500 JSON response.
func RecoveryMiddleware() app.HandlerFunc {
    return func(c context.Context, ctx *app.RequestContext) {
        defer func() {
            r := recover()
            if r == nil {
                return
            }
            hlog.Errorf("Panic: %v\n%s", r, debug.Stack())
            ctx.AbortWithStatusJSON(500, map[string]interface{}{
                "code":    500,
                "message": "Internal Server Error",
            })
        }()

        ctx.Next(c)
    }
}

// AuthMiddleware returns a middleware that aborts with a 401 JSON body when
// the Authorization header is absent or empty. Otherwise it stores the
// placeholder user ID "user123" under "userID" and continues. The token value
// itself is not verified in this example.
func AuthMiddleware() app.HandlerFunc {
    return func(c context.Context, ctx *app.RequestContext) {
        if len(ctx.GetHeader("Authorization")) != 0 {
            ctx.Set("userID", "user123")
            ctx.Next(c)
            return
        }

        ctx.AbortWithStatusJSON(401, map[string]interface{}{
            "code":    401,
            "message": "Unauthorized",
        })
    }
}

// main assembles the complete example:
//   - global logging and panic recovery,
//   - a public health check,
//   - an authenticated /api/v1 group,
//   - a JSON 404 handler.
//
// server.Default() listens on :8888 unless an address is configured, so the
// port is passed explicitly via server.WithHostPorts. Otherwise the startup
// log line would report a port the server is not actually using.
func main() {
    const port = 8080
    addr := fmt.Sprintf(":%d", port)

    h := server.Default(server.WithHostPorts(addr))

    // Logger is registered first, so it also observes the 500 that
    // RecoveryMiddleware writes after a panic.
    h.Use(LoggerMiddleware())
    h.Use(RecoveryMiddleware())

    // Public health check; not behind authentication.
    h.GET("/health", func(c context.Context, ctx *app.RequestContext) {
        ctx.JSON(200, map[string]string{"status": "ok"})
    })

    // Every route under /api/v1 requires an Authorization header.
    api := h.Group("/api/v1")
    api.Use(AuthMiddleware())
    {
        api.GET("/users", func(c context.Context, ctx *app.RequestContext) {
            // AuthMiddleware stored the caller's ID under "userID".
            userID := ctx.GetString("userID")
            ctx.JSON(200, map[string]interface{}{
                "users":   []string{"Alice", "Bob"},
                "user_id": userID,
            })
        })

        api.GET("/users/:id", func(c context.Context, ctx *app.RequestContext) {
            id := ctx.Param("id")
            ctx.JSON(200, map[string]interface{}{
                "id":   id,
                "name": "User " + id,
            })
        })
    }

    h.NoRoute(func(c context.Context, ctx *app.RequestContext) {
        ctx.JSON(404, map[string]interface{}{
            "code":    404,
            "message": "Not Found",
        })
    })

    log.Printf("Server starting on %s", addr)
    h.Spin()
}

小结

本章介绍了 Hertz 的路由与中间件：

  1. 路由基础：基本路由、路由参数、路由组
  2. 高级路由：Any 方法、NoRoute、NoMethod、重定向
  3. 中间件基础：创建中间件、执行顺序
  4. 常用中间件：认证、跨域、限流、恢复、请求 ID
  5. 中间件作用域：全局、路由组、单路由

在下一章中，我们将学习 Hertz 的参数绑定与验证。