Skip to content

SSE 流式响应

概述

本章将深入介绍 Hertz 的 SSE(Server-Sent Events)流式响应。SSE 是一种服务器向客户端推送数据的技术,适用于实时通知、日志流、AI 对话等场景。Hertz 提供了完善的 SSE 支持,可以轻松实现服务端推送功能。

核心内容

SSE 基础

SSE 协议格式

data: 第一条消息

data: 第二条消息
id: 2
event: message

data: 第三条消息
id: 3
event: custom
retry: 3000

SSE 消息格式说明:

  • data: 消息数据(必需)
  • id: 消息 ID(可选)
  • event: 事件类型(可选,默认为 message)
  • retry: 重连时间(可选,毫秒)
  • 以冒号开头的行是注释,客户端会忽略,常用于心跳保活
  • 每条消息以一个空行结束,即消息末尾是两个连续的换行符

基本 SSE 响应

go
package main

import (
    "context"
    "fmt"
    "time"
    
    "github.com/cloudwego/hertz/pkg/app"
    "github.com/cloudwego/hertz/pkg/app/server"
    "github.com/cloudwego/hertz/pkg/protocol/sse"
)

func main() {
    h := server.Default()
    
    h.GET("/events", func(c context.Context, ctx *app.RequestContext) {
        ctx.SetContentType("text/event-stream")
        ctx.SetHeader("Cache-Control", "no-cache")
        ctx.SetHeader("Connection", "keep-alive")
        
        for i := 0; i < 10; i++ {
            fmt.Fprintf(ctx, "data: Message %d\n\n", i+1)
            ctx.Flush()
            time.Sleep(1 * time.Second)
        }
    })
    
    h.Spin()
}

SSE 工具库

使用 sse 包

go
import (
    "github.com/cloudwego/hertz/pkg/protocol/sse"
)

// main starts a server whose /sse route streams ten numbered messages through
// the sse helper, one every 500ms, stopping early if a publish fails.
func main() {
    srv := server.Default()

    srv.GET("/sse", func(c context.Context, ctx *app.RequestContext) {
        stream := sse.NewSSE(ctx)

        for n := 1; n <= 10; n++ {
            msg := fmt.Sprintf("Message %d", n)
            if stream.Publish(&sse.Event{Data: []byte(msg)}) != nil {
                return // nothing follows the loop, so returning equals breaking
            }
            time.Sleep(500 * time.Millisecond)
        }
    })

    srv.Spin()
}

发送带 ID 的消息

go
h.GET("/sse-with-id", func(c context.Context, ctx *app.RequestContext) {
    stream := sse.NewSSE(ctx)

    // Ten events numbered 1..10, 500ms apart. The ID is what an EventSource
    // client reports back in Last-Event-ID when it reconnects.
    for seq := 1; seq <= 10; seq++ {
        id := fmt.Sprint(seq)
        err := stream.Publish(&sse.Event{
            ID:   id,
            Data: []byte("Message " + id),
        })
        if err != nil {
            return
        }
        time.Sleep(500 * time.Millisecond)
    }
})

发送自定义事件

go
h.GET("/sse-events", func(c context.Context, ctx *app.RequestContext) {
    stream := sse.NewSSE(ctx)

    // (event type, payload) pairs published in order, one second apart.
    samples := [][2]string{
        {"message", "普通消息"},
        {"notification", "通知消息"},
        {"alert", "警告消息"},
        {"update", "更新消息"},
    }

    for _, sample := range samples {
        kind, payload := sample[0], sample[1]
        if err := stream.Publish(&sse.Event{Event: kind, Data: []byte(payload)}); err != nil {
            return
        }
        time.Sleep(1 * time.Second)
    }
})

实时数据推送

实时日志推送

go
// LogEntry is one synthetic log line pushed to /logs/stream clients as the
// JSON payload of a "log" event.
type LogEntry struct {
    // Time is formatted with time.RFC3339 by the generator goroutine.
    Time    string `json:"time"`
    // Level cycles through INFO, WARN, ERROR and DEBUG in this demo.
    Level   string `json:"level"`
    Message string `json:"message"`
}

// main serves /logs/stream, which pushes every generated log entry to every
// connected client as a JSON "log" event.
//
// Entries flow producer -> logChan -> hub -> per-client channel -> handler.
// The hub is needed because a single channel read directly by each handler
// hands every entry to exactly one reader: concurrent clients would split
// the log between them. Also, with no client connected the producer would
// block once the 100-entry buffer filled, and the next client would get a
// stale backlog.
func main() {
    h := server.Default()

    logChan := make(chan LogEntry, 100)

    // Handlers register / deregister their private channel with the hub.
    subscribe := make(chan chan LogEntry)
    unsubscribe := make(chan chan LogEntry)

    // Producer: one synthetic entry every 500ms.
    go func() {
        levels := []string{"INFO", "WARN", "ERROR", "DEBUG"}
        for i := 0; ; i++ {
            logChan <- LogEntry{
                Time:    time.Now().Format(time.RFC3339),
                Level:   levels[i%len(levels)],
                Message: fmt.Sprintf("Log message %d", i+1),
            }
            time.Sleep(500 * time.Millisecond)
        }
    }()

    // Hub: owns the subscriber set (no lock needed, only this goroutine
    // touches it) and copies each entry to every subscriber. It always drains
    // logChan, so the producer never stalls when nobody is listening.
    go func() {
        clients := make(map[chan LogEntry]struct{})
        for {
            select {
            case ch := <-subscribe:
                clients[ch] = struct{}{}
            case ch := <-unsubscribe:
                delete(clients, ch)
            case entry := <-logChan:
                for ch := range clients {
                    select {
                    case ch <- entry:
                    default: // slow client: drop the entry rather than stall the others
                    }
                }
            }
        }
    }()

    h.GET("/logs/stream", func(c context.Context, ctx *app.RequestContext) {
        s := sse.NewSSE(ctx)

        ch := make(chan LogEntry, 10)
        subscribe <- ch
        defer func() { unsubscribe <- ch }()

        for {
            select {
            case entry := <-ch:
                data, err := json.Marshal(entry)
                if err != nil {
                    continue // not expected: LogEntry has only string fields
                }
                event := &sse.Event{
                    Event: "log",
                    Data:  data,
                }

                if err := s.Publish(event); err != nil {
                    return
                }
            // NOTE(review): Hertz only cancels c on client disconnect when
            // server.WithSenseClientDisconnection(true) is set; otherwise a
            // disconnect is noticed at the next failed Publish.
            case <-c.Done():
                return
            }
        }
    })

    h.Spin()
}

实时监控数据

go
// Metric is one sample pushed to /metrics/stream clients as the JSON payload
// of a "metric" event. In this demo every value is randomly generated.
type Metric struct {
    // Timestamp is Unix time in seconds.
    Timestamp int64   `json:"timestamp"`
    // CPU and Memory are generated as fractions in [0, 0.99].
    CPU       float64 `json:"cpu"`
    Memory    float64 `json:"memory"`
    Requests  int     `json:"requests"`
}

// main serves /metrics/stream, which emits one randomly generated Metric per
// second as a JSON "metric" event until the client goes away.
func main() {
    srv := server.Default()

    srv.GET("/metrics/stream", func(c context.Context, ctx *app.RequestContext) {
        stream := sse.NewSSE(ctx)

        tick := time.NewTicker(1 * time.Second)
        defer tick.Stop()

        for {
            select {
            case <-c.Done():
                return
            case <-tick.C:
                var sample Metric
                sample.Timestamp = time.Now().Unix()
                sample.CPU = float64(rand.Intn(100)) / 100
                sample.Memory = float64(rand.Intn(100)) / 100
                sample.Requests = rand.Intn(1000)

                // Metric holds only numeric fields, so Marshal cannot fail.
                payload, _ := json.Marshal(sample)
                if stream.Publish(&sse.Event{Event: "metric", Data: payload}) != nil {
                    return
                }
            }
        }
    })

    srv.Spin()
}

AI 对话流式响应

模拟 AI 对话

go
// main serves POST /chat, which streams a canned reply one character per
// event (50ms apart) and finishes with a "done" event carrying "[DONE]".
// A request body that fails JSON binding gets a 400 with the error message.
func main() {
    h := server.Default()

    h.POST("/chat", func(c context.Context, ctx *app.RequestContext) {
        var req struct {
            Message string `json:"message"`
        }

        if err := ctx.BindJSON(&req); err != nil {
            ctx.JSON(400, map[string]string{"error": err.Error()})
            return
        }

        s := sse.NewSSE(ctx)

        response := "这是一个模拟的 AI 回复,我们会逐字发送这条消息,模拟真实的流式响应效果。"

        // With an empty separator strings.Split cuts after every UTF-8
        // sequence, so each chunk is one whole character, never half a rune.
        words := strings.Split(response, "")

        for i, word := range words {
            event := &sse.Event{
                ID:   fmt.Sprintf("%d", i+1),
                Data: []byte(word),
            }

            if err := s.Publish(event); err != nil {
                // The client is gone; sending "done" to it would only fail again.
                return
            }

            time.Sleep(50 * time.Millisecond)
        }

        // Signal completion. The handler ends either way, so a failure here
        // is deliberately ignored.
        _ = s.Publish(&sse.Event{
            Event: "done",
            Data:  []byte("[DONE]"),
        })
    })

    h.Spin()
}

结构化 AI 响应

go
// ChatChunk is the JSON payload of each "chat" event sent by /chat/stream.
type ChatChunk struct {
    // Content is one sentence of the reply.
    Content string `json:"content"`
    // Done is true only on the final chunk of the reply.
    Done    bool   `json:"done"`
}

// main serves POST /chat/stream, which streams a canned reply one sentence per
// "chat" event (200ms apart). Each payload is a JSON ChatChunk whose Done
// flag marks the last sentence. Bad JSON input gets a 400.
func main() {
    srv := server.Default()

    srv.POST("/chat/stream", func(c context.Context, ctx *app.RequestContext) {
        var req struct {
            Prompt string `json:"prompt"`
        }
        if err := ctx.BindJSON(&req); err != nil {
            ctx.JSON(400, map[string]string{"error": err.Error()})
            return
        }

        stream := sse.NewSSE(ctx)

        sentences := []string{
            "这是第一句话。",
            "这是第二句话。",
            "这是第三句话。",
            "回复完成。",
        }
        last := len(sentences) - 1

        for idx, text := range sentences {
            // ChatChunk holds a string and a bool, so Marshal cannot fail.
            payload, _ := json.Marshal(ChatChunk{Content: text, Done: idx == last})

            err := stream.Publish(&sse.Event{
                ID:    fmt.Sprint(idx + 1),
                Event: "chat",
                Data:  payload,
            })
            if err != nil {
                return
            }

            time.Sleep(200 * time.Millisecond)
        }
    })

    srv.Spin()
}

广播推送

消息广播系统

go
// Broker fans SSE events out to every subscribed client. Each subscriber owns
// a buffered channel; the subscriber set is guarded by mu.
type Broker struct {
    clients map[chan *sse.Event]struct{}
    mu      sync.RWMutex
}

// NewBroker returns an empty, ready-to-use Broker.
func NewBroker() *Broker {
    return &Broker{
        clients: make(map[chan *sse.Event]struct{}),
    }
}

// Subscribe registers a new subscriber and returns the channel its events
// arrive on (buffer of 10). Callers must call Unsubscribe when done, or the
// channel stays in the set forever.
func (b *Broker) Subscribe() chan *sse.Event {
    b.mu.Lock()
    defer b.mu.Unlock()

    ch := make(chan *sse.Event, 10)
    b.clients[ch] = struct{}{}
    return ch
}

// Unsubscribe removes ch from the set and closes it. Calling it again for the
// same channel, or with a channel the Broker never handed out, is a no-op
// rather than a "close of closed channel" panic.
func (b *Broker) Unsubscribe(ch chan *sse.Event) {
    b.mu.Lock()
    defer b.mu.Unlock()

    if _, ok := b.clients[ch]; !ok {
        return
    }
    delete(b.clients, ch)
    close(ch)
}

// Broadcast offers event to every subscriber without blocking. A subscriber
// whose buffer is full misses this event, so one slow client cannot stall the
// rest. Holding the read lock while sending keeps Unsubscribe (which needs
// the write lock) from closing a channel mid-send. The same *sse.Event is
// delivered to all subscribers, so receivers should treat it as read-only.
func (b *Broker) Broadcast(event *sse.Event) {
    b.mu.RLock()
    defer b.mu.RUnlock()

    for ch := range b.clients {
        select {
        case ch <- event:
        default:
        }
    }
}

// main serves a broadcast demo:
//   - GET /broadcast streams every broadcast to the connected client.
//   - POST /broadcast/send broadcasts the posted message to all subscribers.
//   - A background goroutine also broadcasts a numbered message every 5s.
func main() {
    // NOTE(review): Hertz cancels the handler's context.Context on client
    // disconnect only when WithSenseClientDisconnection is enabled (it is off
    // by default). Without it the <-c.Done() case below never fires, and a
    // departed subscriber is only noticed at its next failed Publish.
    h := server.Default(server.WithSenseClientDisconnection(true))
    broker := NewBroker()

    go func() {
        for i := 0; ; i++ {
            event := &sse.Event{
                ID:    fmt.Sprintf("%d", i+1),
                Event: "broadcast",
                Data:  []byte(fmt.Sprintf("Broadcast message %d", i+1)),
            }
            broker.Broadcast(event)
            time.Sleep(5 * time.Second)
        }
    }()

    h.GET("/broadcast", func(c context.Context, ctx *app.RequestContext) {
        s := sse.NewSSE(ctx)
        ch := broker.Subscribe()
        defer broker.Unsubscribe(ch)

        for {
            select {
            case event, ok := <-ch:
                if !ok {
                    // Channel closed by an Unsubscribe elsewhere; receiving
                    // again would only yield nil events.
                    return
                }
                if err := s.Publish(event); err != nil {
                    return
                }
            case <-c.Done():
                return
            }
        }
    })

    h.POST("/broadcast/send", func(c context.Context, ctx *app.RequestContext) {
        var req struct {
            Message string `json:"message"`
        }

        if err := ctx.BindJSON(&req); err != nil {
            ctx.JSON(400, map[string]string{"error": err.Error()})
            return
        }

        event := &sse.Event{
            Event: "message",
            Data:  []byte(req.Message),
        }
        broker.Broadcast(event)

        ctx.JSON(200, map[string]string{"status": "sent"})
    })

    h.Spin()
}

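客户端断开处理

注意:Hertz 默认不会在客户端断开时取消请求的 context.Context。下文示例中的 case <-c.Done() 只有在创建服务器时开启 server.WithSenseClientDisconnection(true) 才会触发;未开启时,只能在下一次 Publish 返回错误时发现客户端已断开。

检测客户端断开与断线续传(Last-Event-ID)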

go
h.GET("/sse-reconnect", func(c context.Context, ctx *app.RequestContext) {
    // A reconnecting EventSource sends the ID of the last event it received
    // in Last-Event-ID; numbering resumes right after it.
    lastEventID := string(ctx.GetHeader("Last-Event-ID"))

    s := sse.NewSSE(ctx)

    startID := 1
    if lastEventID != "" {
        // Malformed or negative IDs are ignored and the stream starts over
        // from 1 (a negative ID would otherwise yield IDs <= 0).
        if id, err := strconv.Atoi(lastEventID); err == nil && id >= 0 {
            startID = id + 1
        }
    }

    for i := startID; ; i++ {
        event := &sse.Event{
            ID:   fmt.Sprintf("%d", i),
            Data: []byte(fmt.Sprintf("Message %d", i)),
        }

        if err := s.Publish(event); err != nil {
            fmt.Println("Publish error:", err)
            return
        }

        // Wait between events, but stop immediately on cancellation instead
        // of sleeping through it.
        select {
        case <-c.Done():
            fmt.Println("Client disconnected")
            return
        case <-time.After(1 * time.Second):
        }
    }
})

心跳保活

go
h.GET("/sse-heartbeat", func(c context.Context, ctx *app.RequestContext) {
    s := sse.NewSSE(ctx)
    
    ticker := time.NewTicker(15 * time.Second)
    defer ticker.Stop()
    
    messageTicker := time.NewTicker(2 * time.Second)
    defer messageTicker.Stop()
    
    msgCount := 0
    
    for {
        select {
        case <-ticker.C:
            heartbeat := &sse.Event{
                Event: "heartbeat",
                Data:  []byte("ping"),
            }
            if err := s.Publish(heartbeat); err != nil {
                return
            }
            
        case <-messageTicker.C:
            msgCount++
            event := &sse.Event{
                ID:   fmt.Sprintf("%d", msgCount),
                Data: []byte(fmt.Sprintf("Message %d", msgCount)),
            }
            if err := s.Publish(event); err != nil {
                return
            }
            
        case <-c.Done():
            return
        }
    }
})

完整示例

go
package main

import (
    "context"
    "encoding/json"
    "fmt"
    "math/rand"
    "strconv"
    "strings"
    "sync"
    "time"
    
    "github.com/cloudwego/hertz/pkg/app"
    "github.com/cloudwego/hertz/pkg/app/server"
    "github.com/cloudwego/hertz/pkg/protocol/sse"
)

// Message is the JSON payload of the events sent by the /sse/events route.
type Message struct {
    // ID mirrors the SSE event ID of the event carrying this message.
    ID        string    `json:"id"`
    Content   string    `json:"content"`
    // Timestamp is encoded by time.Time's MarshalJSON (RFC 3339).
    Timestamp time.Time `json:"timestamp"`
}

// Broker fans SSE events out to every subscribed client. Each subscriber owns
// a buffered channel; the subscriber set is guarded by mu.
type Broker struct {
    clients map[chan *sse.Event]struct{}
    mu      sync.RWMutex
}

// NewBroker returns an empty, ready-to-use Broker.
func NewBroker() *Broker {
    return &Broker{
        clients: make(map[chan *sse.Event]struct{}),
    }
}

// Subscribe registers a new subscriber and returns the channel its events
// arrive on (buffer of 10). Callers must call Unsubscribe when done, or the
// channel stays in the set forever.
func (b *Broker) Subscribe() chan *sse.Event {
    b.mu.Lock()
    defer b.mu.Unlock()

    ch := make(chan *sse.Event, 10)
    b.clients[ch] = struct{}{}
    return ch
}

// Unsubscribe removes ch from the set and closes it. Calling it again for the
// same channel, or with a channel the Broker never handed out, is a no-op
// rather than a "close of closed channel" panic.
func (b *Broker) Unsubscribe(ch chan *sse.Event) {
    b.mu.Lock()
    defer b.mu.Unlock()

    if _, ok := b.clients[ch]; !ok {
        return
    }
    delete(b.clients, ch)
    close(ch)
}

// Broadcast offers event to every subscriber without blocking. A subscriber
// whose buffer is full misses this event, so one slow client cannot stall the
// rest. Holding the read lock while sending keeps Unsubscribe (which needs
// the write lock) from closing a channel mid-send. The same *sse.Event is
// delivered to all subscribers, so receivers should treat it as read-only.
func (b *Broker) Broadcast(event *sse.Event) {
    b.mu.RLock()
    defer b.mu.RUnlock()

    for ch := range b.clients {
        select {
        case ch <- event:
        default:
        }
    }
}

// broker is the process-wide hub shared by the /sse/subscribe route (which
// consumes events) and the /sse/broadcast route (which produces them).
var broker = NewBroker()

// main wires up every SSE demo route on a single Hertz server:
//   - GET  /sse             ten numbered events
//   - GET  /sse/events      twenty JSON Messages with rotating event types
//   - GET  /sse/subscribe   long-lived stream fed by the package-level broker
//   - POST /sse/broadcast   publish a message to all subscribers
//   - POST /chat/stream     simulated AI reply streamed per character
//   - GET  /metrics/stream  one random metric sample per second
func main() {
    // NOTE(review): Hertz cancels the handler's context.Context on client
    // disconnect only when WithSenseClientDisconnection is enabled (it is off
    // by default). Without it the <-c.Done() cases below never fire, and an
    // idle /sse/subscribe client that leaves is only noticed at the next
    // failed Publish.
    h := server.Default(server.WithSenseClientDisconnection(true))

    h.GET("/sse", func(c context.Context, ctx *app.RequestContext) {
        s := sse.NewSSE(ctx)

        for i := 0; i < 10; i++ {
            event := &sse.Event{
                ID:   fmt.Sprintf("%d", i+1),
                Data: []byte(fmt.Sprintf("Message %d", i+1)),
            }

            if err := s.Publish(event); err != nil {
                return // client gone
            }

            time.Sleep(500 * time.Millisecond)
        }
    })

    h.GET("/sse/events", func(c context.Context, ctx *app.RequestContext) {
        s := sse.NewSSE(ctx)

        eventTypes := []string{"message", "notification", "alert", "update"}

        for i := 0; i < 20; i++ {
            msg := Message{
                ID:        fmt.Sprintf("%d", i+1),
                Content:   fmt.Sprintf("Event content %d", i+1),
                Timestamp: time.Now(),
            }

            // Strings plus a current time.Time: Marshal does not fail here.
            data, _ := json.Marshal(msg)

            event := &sse.Event{
                ID:    msg.ID,
                Event: eventTypes[i%len(eventTypes)],
                Data:  data,
            }

            if err := s.Publish(event); err != nil {
                return // client gone
            }

            time.Sleep(300 * time.Millisecond)
        }
    })

    h.GET("/sse/subscribe", func(c context.Context, ctx *app.RequestContext) {
        s := sse.NewSSE(ctx)
        ch := broker.Subscribe()
        defer broker.Unsubscribe(ch)

        for {
            select {
            case event, ok := <-ch:
                if !ok {
                    // Closed by an Unsubscribe elsewhere; further receives
                    // would only yield nil events.
                    return
                }
                if err := s.Publish(event); err != nil {
                    return
                }
            case <-c.Done():
                return
            }
        }
    })

    h.POST("/sse/broadcast", func(c context.Context, ctx *app.RequestContext) {
        var req struct {
            Message string `json:"message"`
        }

        if err := ctx.BindJSON(&req); err != nil {
            ctx.JSON(400, map[string]string{"error": err.Error()})
            return
        }

        event := &sse.Event{
            Event: "broadcast",
            Data:  []byte(req.Message),
        }
        broker.Broadcast(event)

        ctx.JSON(200, map[string]string{"status": "broadcast sent"})
    })

    h.POST("/chat/stream", func(c context.Context, ctx *app.RequestContext) {
        var req struct {
            Prompt string `json:"prompt"`
        }

        if err := ctx.BindJSON(&req); err != nil {
            ctx.JSON(400, map[string]string{"error": err.Error()})
            return
        }

        s := sse.NewSSE(ctx)

        response := "这是一个模拟的 AI 流式回复,我们会逐词发送这条消息,模拟真实的 AI 对话效果。"

        // Split after every UTF-8 sequence. strings.Fields only splits on
        // whitespace, which Chinese text barely has, so it produced a few long
        // fragments; re-appending " " also added a trailing space that was not
        // in the reply. These chunks concatenate back to response exactly.
        chunks := strings.Split(response, "")

        for i, chunk := range chunks {
            event := &sse.Event{
                ID:    fmt.Sprintf("%d", i+1),
                Event: "chat",
                Data:  []byte(chunk),
            }

            if err := s.Publish(event); err != nil {
                return // client gone; no point sending "done"
            }

            time.Sleep(100 * time.Millisecond)
        }

        // Best-effort completion marker; the handler ends either way.
        _ = s.Publish(&sse.Event{
            Event: "done",
            Data:  []byte("[DONE]"),
        })
    })

    h.GET("/metrics/stream", func(c context.Context, ctx *app.RequestContext) {
        s := sse.NewSSE(ctx)

        ticker := time.NewTicker(1 * time.Second)
        defer ticker.Stop()

        for i := 0; ; i++ {
            select {
            case <-ticker.C:
                // cpu and memory are random percentages in [0, 100).
                metric := map[string]interface{}{
                    "timestamp": time.Now().Unix(),
                    "cpu":       rand.Float64() * 100,
                    "memory":    rand.Float64() * 100,
                    "requests":  rand.Intn(1000),
                }

                data, _ := json.Marshal(metric)
                event := &sse.Event{
                    ID:    strconv.Itoa(i + 1),
                    Event: "metric",
                    Data:  data,
                }

                if err := s.Publish(event); err != nil {
                    return
                }
            case <-c.Done():
                return
            }
        }
    })

    h.Spin()
}

小结

本章介绍了 Hertz 的 SSE 流式响应:

  1. SSE 基础:协议格式、基本响应
  2. SSE 工具库:使用 sse 包发送消息
  3. 实时数据推送:日志推送、监控数据
  4. AI 对话流式响应:模拟 AI 对话、结构化响应
  5. 广播推送:消息广播系统
  6. 客户端断开处理:检测断开、心跳保活

至此,Hertz 系列教程全部完成。