路由与中间件
概述
本章将深入介绍 Hertz 的路由系统和中间件机制。Hertz 提供了灵活的路由功能,支持静态路由、命名参数、通配参数以及路由组等多种路由方式;中间件机制则允许在请求处理前后执行通用逻辑,用于实现认证、日志、限流等横切关注点。
核心内容
路由基础
基本路由
go
package main
import (
"context"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
)
func main() {
	h := server.Default()

	// Collection routes on /users.
	h.GET("/users", listUsers)
	h.HEAD("/users", headUsers)
	h.POST("/users", createUser)
	h.OPTIONS("/users", optionsUsers)

	// Single-resource routes; ":id" is a named path parameter.
	h.PUT("/users/:id", updateUser)
	h.PATCH("/users/:id", patchUser)
	h.DELETE("/users/:id", deleteUser)

	// Start serving.
	h.Spin()
}
// listUsers responds 200 with a fixed, hard-coded list of user names.
func listUsers(c context.Context, ctx *app.RequestContext) {
	names := []string{"Alice", "Bob"}
	ctx.JSON(200, map[string]interface{}{"users": names})
}
func createUser(c context.Context, ctx *app.RequestContext) {
ctx.JSON(201, map[string]interface{}{
"id": 1,
"name": "New User",
})
}路由参数
go
func main() {
h := server.Default()
// 命名参数
h.GET("/users/:id", func(c context.Context, ctx *app.RequestContext) {
id := ctx.Param("id")
ctx.String(200, "User ID: %s", id)
})
// 通配参数
h.GET("/files/*filepath", func(c context.Context, ctx *app.RequestContext) {
filepath := ctx.Param("filepath")
ctx.String(200, "File: %s", filepath)
})
// 正则参数
h.GET("/users/:id:^\\d+$", func(c context.Context, ctx *app.RequestContext) {
id := ctx.Param("id")
ctx.String(200, "User ID (numeric): %s", id)
})
h.Spin()
}路由组
go
func main() {
h := server.Default()
// API 路由组
api := h.Group("/api")
{
api.GET("/status", statusHandler)
}
// 版本化路由组
v1 := h.Group("/api/v1")
{
v1.GET("/users", listUsers)
v1.POST("/users", createUser)
v1.GET("/users/:id", getUser)
}
v2 := h.Group("/api/v2")
{
v2.GET("/users", listUsersV2)
v2.POST("/users", createUserV2)
}
// 带中间件的路由组
auth := h.Group("/api", AuthMiddleware())
{
auth.GET("/profile", getProfile)
auth.PUT("/profile", updateProfile)
}
h.Spin()
}高级路由
Any 方法
go
// Any registers one handler for any request method on /api; it echoes the
// method and path the request arrived with.
h.Any("/api", func(c context.Context, ctx *app.RequestContext) {
	method, path := string(ctx.Method()), string(ctx.URI().Path())
	ctx.JSON(200, map[string]interface{}{
		"method": method,
		"path":   path,
	})
})
兜底路由(NoRoute / NoMethod)
go
// Fallback for requests whose path matches no registered route: JSON 404.
h.NoRoute(func(c context.Context, ctx *app.RequestContext) {
	ctx.JSON(404, map[string]interface{}{"code": 404, "message": "Not Found"})
})

// Fallback for a path that exists but not for the request's method: JSON 405.
// NOTE(review): per the Hertz server options, this handler only fires when the
// server is created with server.WithHandleMethodNotAllowed(true) (off by
// default); otherwise such requests go to NoRoute — confirm for your version.
h.NoMethod(func(c context.Context, ctx *app.RequestContext) {
	ctx.JSON(405, map[string]interface{}{"code": 405, "message": "Method Not Allowed"})
})
路由重定向
go
// The legacy URL answers with a 301 redirect pointing at /new-path.
h.GET("/old-path", func(c context.Context, ctx *app.RequestContext) {
	target := []byte("/new-path")
	ctx.Redirect(301, target)
})

// The redirect target itself.
h.GET("/new-path", func(c context.Context, ctx *app.RequestContext) {
	ctx.String(200, "New Path")
})
中间件基础
创建中间件
go
// Logger returns a middleware that, once the rest of the handler chain has
// finished, logs the request method, path, response status code and latency.
func Logger() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		start := time.Now()
		// Capture the path before Next so the log shows the path as received.
		path := string(ctx.URI().Path())
		ctx.Next(c)
		latency := time.Since(start)
		// hlog (github.com/cloudwego/hertz/pkg/common/hlog) is Hertz's own
		// logger; klog belongs to Kitex and would add an unrelated dependency.
		hlog.Infof("[%s] %s %d %v",
			ctx.Method(),
			path,
			ctx.Response.StatusCode(),
			latency,
		)
	}
}
func main() {
h := server.Default()
h.Use(Logger())
h.GET("/hello", func(c context.Context, ctx *app.RequestContext) {
ctx.String(200, "Hello!")
})
h.Spin()
}中间件执行顺序
go
// Middleware1 prints a marker before and after handing control to the rest of
// the chain, making the nested (onion-style) execution order visible.
func Middleware1() app.HandlerFunc {
	const tag = "Middleware1"
	return func(c context.Context, ctx *app.RequestContext) {
		fmt.Println(tag + " - Before")
		ctx.Next(c)
		fmt.Println(tag + " - After")
	}
}
// Middleware2 is the inner counterpart of Middleware1: its "Before" prints after
// Middleware1's and its "After" prints before Middleware1's.
func Middleware2() app.HandlerFunc {
	name := "Middleware2"
	return func(c context.Context, ctx *app.RequestContext) {
		fmt.Printf("%s - Before\n", name)
		ctx.Next(c)
		fmt.Printf("%s - After\n", name)
	}
}
func main() {
	h := server.Default()
	// Registration order is execution order: Middleware1 wraps Middleware2,
	// which wraps the handler.
	h.Use(Middleware1(), Middleware2())

	h.GET("/test", func(c context.Context, ctx *app.RequestContext) {
		fmt.Println("Handler")
		ctx.String(200, "OK")
	})

	h.Spin()
}
// 输出顺序:
// Middleware1 - Before
// Middleware2 - Before
// Handler
// Middleware2 - After
// Middleware1 - After
常用中间件
认证中间件
go
// AuthMiddleware rejects requests with a 401 JSON body when the Authorization
// header is missing or validateToken fails; otherwise it stores the resolved
// user ID under the "userID" key and continues the chain.
func AuthMiddleware() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		reject := func(msg string) {
			ctx.AbortWithStatusJSON(401, map[string]interface{}{
				"code":    401,
				"message": msg,
			})
		}

		token := string(ctx.GetHeader("Authorization"))
		if token == "" {
			reject("Unauthorized")
			return
		}

		userID, err := validateToken(token)
		if err != nil {
			reject("Invalid token")
			return
		}

		ctx.Set("userID", userID)
		ctx.Next(c)
	}
}
// validateToken is a placeholder: it accepts any token, including an empty one,
// and always reports the fixed user "user123" with a nil error.
// NOTE(review): replace with real verification (signature/expiry/session
// lookup) before using this outside a demo.
func validateToken(token string) (string, error) {
	_ = token // unused until real verification is implemented
	const demoUser = "user123"
	return demoUser, nil
}
func main() {
h := server.Default()
protected := h.Group("/api")
protected.Use(AuthMiddleware())
{
protected.GET("/profile", func(c context.Context, ctx *app.RequestContext) {
userID := ctx.GetString("userID")
ctx.JSON(200, map[string]interface{}{
"user_id": userID,
})
})
}
h.Spin()
}跨域中间件
go
// CORSMiddleware sets permissive CORS response headers on every request and
// answers preflight (OPTIONS) requests itself with 204, without running the
// remaining handlers.
func CORSMiddleware() app.HandlerFunc {
	// Built once per middleware instance and only read afterwards.
	corsHeaders := [][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization"},
		{"Access-Control-Max-Age", "86400"},
	}
	return func(c context.Context, ctx *app.RequestContext) {
		for _, kv := range corsHeaders {
			ctx.Header(kv[0], kv[1])
		}
		if string(ctx.Method()) != "OPTIONS" {
			ctx.Next(c)
			return
		}
		ctx.AbortWithStatus(204)
	}
}
func main() {
h := server.Default()
h.Use(CORSMiddleware())
h.Spin()
}限流中间件
go
import (
"sync"
"time"
)
// RateLimiter is a sliding-window-log rate limiter keyed by an arbitrary client
// identifier (RateLimitMiddleware uses the client IP). Its methods are safe for
// concurrent use: mu guards every other field.
type RateLimiter struct {
	mu        sync.Mutex
	requests  map[string][]time.Time // accepted-request timestamps per key, oldest first
	limit     int                    // max accepted requests per key within window
	window    time.Duration          // length of the sliding window
	lastSweep time.Time              // when expired keys were last purged from requests
}

// NewRateLimiter returns a limiter that accepts at most limit requests per key
// within any sliding window of the given duration.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	return &RateLimiter{
		requests:  make(map[string][]time.Time),
		limit:     limit,
		window:    window,
		lastSweep: time.Now(),
	}
}

// Allow reports whether a request for ip may proceed and, if so, records it.
// Rejected requests are not recorded, so repeated attempts do not extend a
// client's lockout.
func (rl *RateLimiter) Allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	windowStart := now.Add(-rl.window)

	// Without a sweep, every distinct key ever seen would stay in the map
	// forever. Purging at most once per window keeps the amortized cost low.
	if now.Sub(rl.lastSweep) >= rl.window {
		for key, stamps := range rl.requests {
			// Timestamps are appended in order, so the last one is the newest.
			if len(stamps) == 0 || !stamps[len(stamps)-1].After(windowStart) {
				delete(rl.requests, key)
			}
		}
		rl.lastSweep = now
	}

	// Drop expired timestamps in place. The slice is owned by the map, so its
	// backing array can be reused instead of allocating a new one per call.
	stamps := rl.requests[ip]
	valid := stamps[:0]
	for _, t := range stamps {
		if t.After(windowStart) {
			valid = append(valid, t)
		}
	}

	allowed := len(valid) < rl.limit
	if allowed {
		valid = append(valid, now)
	}
	if len(valid) == 0 {
		delete(rl.requests, ip) // don't keep empty entries around
	} else {
		rl.requests[ip] = valid
	}
	return allowed
}
// RateLimitMiddleware aborts with a 429 JSON body once the client identified by
// ctx.ClientIP() exceeds the limiter's quota; otherwise it continues the chain.
//
// NOTE(review): ClientIP may be derived from proxy headers such as
// X-Forwarded-For / X-Real-IP, which clients can forge to obtain fresh quotas.
// Restrict which proxies are trusted for your deployment — confirm the exact
// behavior in the Hertz RequestContext docs.
func RateLimitMiddleware(limiter *RateLimiter) app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		// ClientIP already returns a string, so no conversion is needed.
		if !limiter.Allow(ctx.ClientIP()) {
			ctx.AbortWithStatusJSON(429, map[string]interface{}{
				"code":    429,
				"message": "Too Many Requests",
			})
			return
		}
		ctx.Next(c)
	}
}
func main() {
h := server.Default()
limiter := NewRateLimiter(100, time.Minute)
h.Use(RateLimitMiddleware(limiter))
h.Spin()
}恢复中间件
go
// RecoveryMiddleware recovers from panics raised by downstream handlers in the
// same goroutine. It logs the panic value and stack trace, then aborts the
// request with a 500 JSON body.
func RecoveryMiddleware() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		defer func() {
			// recover must be called directly in the deferred function to work.
			if err := recover(); err != nil {
				// hlog is Hertz's own logger (pkg/common/hlog), used instead of
				// Kitex's klog for consistency within a Hertz service.
				hlog.Errorf("Panic recovered: %v\n%s", err, debug.Stack())
				ctx.AbortWithStatusJSON(500, map[string]interface{}{
					"code":    500,
					"message": "Internal Server Error",
				})
			}
		}()
		ctx.Next(c)
	}
}
func main() {
h := server.Default()
h.Use(RecoveryMiddleware())
h.GET("/panic", func(c context.Context, ctx *app.RequestContext) {
panic("something went wrong")
})
h.Spin()
}请求 ID 中间件
go
import (
"github.com/google/uuid"
)
// RequestIDMiddleware reuses the incoming X-Request-ID header, or generates a
// UUID when it is absent. It stores the ID under the "request_id" key and
// echoes it in the X-Request-ID response header.
func RequestIDMiddleware() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		id := string(ctx.GetHeader("X-Request-ID"))
		if len(id) == 0 {
			id = uuid.New().String()
		}
		ctx.Header("X-Request-ID", id)
		ctx.Set("request_id", id)
		ctx.Next(c)
	}
}
func main() {
h := server.Default()
h.Use(RequestIDMiddleware())
h.GET("/test", func(c context.Context, ctx *app.RequestContext) {
requestID := ctx.GetString("request_id")
ctx.JSON(200, map[string]string{
"request_id": requestID,
})
})
h.Spin()
}中间件作用域
全局中间件
go
func main() {
h := server.Default()
h.Use(LoggerMiddleware())
h.Use(RecoveryMiddleware())
h.GET("/public", publicHandler)
h.GET("/private", privateHandler)
h.Spin()
}路由组中间件
go
func main() {
h := server.Default()
api := h.Group("/api")
api.Use(AuthMiddleware())
{
api.GET("/users", listUsers)
api.POST("/users", createUser)
}
public := h.Group("/public")
{
public.GET("/status", statusHandler)
}
h.Spin()
}单路由中间件
go
func main() {
h := server.Default()
h.GET("/public", publicHandler)
h.GET("/private", AuthMiddleware(), privateHandler)
h.GET("/admin", AuthMiddleware(), AdminMiddleware(), adminHandler)
h.Spin()
}完整示例
go
package main
import (
"context"
"fmt"
"log"
"runtime/debug"
"time"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/common/hlog"
)
// LoggerMiddleware logs method, path, status code and latency of each request
// after the downstream handlers have run.
func LoggerMiddleware() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		begin := time.Now()
		ctx.Next(c)
		elapsed := time.Since(begin)
		hlog.Infof("[%s] %s %d %v",
			ctx.Method(), ctx.URI().Path(), ctx.Response.StatusCode(), elapsed)
	}
}
// RecoveryMiddleware turns a panic in a downstream handler into a logged error
// (with stack trace) and a 500 JSON response.
func RecoveryMiddleware() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			hlog.Errorf("Panic: %v\n%s", r, debug.Stack())
			ctx.AbortWithStatusJSON(500, map[string]interface{}{
				"code":    500,
				"message": "Internal Server Error",
			})
		}()
		ctx.Next(c)
	}
}
// AuthMiddleware (demo version) rejects requests without an Authorization
// header with 401. Any non-empty header is accepted and mapped to the fixed
// user "user123", stored under the "userID" key.
func AuthMiddleware() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		if len(ctx.GetHeader("Authorization")) == 0 {
			ctx.AbortWithStatusJSON(401, map[string]interface{}{
				"code":    401,
				"message": "Unauthorized",
			})
			return
		}
		// Demo only: the token's content is never checked.
		ctx.Set("userID", "user123")
		ctx.Next(c)
	}
}
func main() {
h := server.Default()
h.Use(LoggerMiddleware())
h.Use(RecoveryMiddleware())
h.GET("/health", func(c context.Context, ctx *app.RequestContext) {
ctx.JSON(200, map[string]string{"status": "ok"})
})
api := h.Group("/api/v1")
api.Use(AuthMiddleware())
{
api.GET("/users", func(c context.Context, ctx *app.RequestContext) {
userID := ctx.GetString("userID")
ctx.JSON(200, map[string]interface{}{
"users": []string{"Alice", "Bob"},
"user_id": userID,
})
})
api.GET("/users/:id", func(c context.Context, ctx *app.RequestContext) {
id := ctx.Param("id")
ctx.JSON(200, map[string]interface{}{
"id": id,
"name": "User " + id,
})
})
}
h.NoRoute(func(c context.Context, ctx *app.RequestContext) {
ctx.JSON(404, map[string]interface{}{
"code": 404,
"message": "Not Found",
})
})
log.Println("Server starting on :8080")
h.Spin()
}小结
本章介绍了 Hertz 的路由与中间件:
- 路由基础:基本路由、路由参数(命名参数、通配参数)、路由组
- 高级路由:Any 方法、NoRoute、NoMethod、重定向
- 中间件基础:创建中间件、执行顺序
- 常用中间件:认证、跨域、限流、恢复、请求 ID
- 中间件作用域:全局、路由组、单路由
在下一章中,我们将学习 Hertz 的参数绑定与验证。