go
// Context is the core interface defined in context.go.
type Context interface {
	// Deadline reports the time at which this context is cancelled
	// automatically; ok is false when no deadline has been set.
	Deadline() (deadline time.Time, ok bool)

	// Done returns a channel that is closed once the context is cancelled.
	// Closing this channel is how cancellation is broadcast to every
	// listener. It may be nil for a context that can never be cancelled.
	Done() <-chan struct{}

	// Err reports why the context was cancelled.
	Err() error

	// Value returns the value associated with key.
	Value(key interface{}) interface{}
}
go
// Overview of the four concrete context types.
type (
	// emptyCtx is the root node behind Background and TODO.
	emptyCtx int

	// cancelCtx is a context that can be cancelled.
	cancelCtx struct {
		Context
		mu       sync.Mutex
		done     atomic.Value
		children map[canceler]struct{}
		err      error
	}

	// timerCtx is a context with a deadline; it embeds cancelCtx.
	timerCtx struct {
		cancelCtx
		timer    *time.Timer
		deadline time.Time
	}

	// valueCtx is a context carrying a single key/value pair.
	valueCtx struct {
		Context
		key, val interface{}
	}
)
plain text
             emptyCtx (Background/TODO)
                        │
          ┌─────────────┼─────────────┐
          │             │             │
      cancelCtx      timerCtx      valueCtx
          │             │
          └─────────────┘
   (timerCtx 嵌入 cancelCtx,复用其取消逻辑)
go
// emptyCtx is the root context: it has no deadline, is never cancelled and
// carries no values. Background and TODO each hand out their own *emptyCtx.
type emptyCtx int

// Deadline always reports that no deadline is set.
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
	var zero time.Time
	return zero, false
}

// Done returns nil: a nil channel never becomes ready, so this context is
// never cancelled.
func (*emptyCtx) Done() <-chan struct{} { return nil }

// Err always returns nil because the context is never cancelled.
func (*emptyCtx) Err() error { return nil }

// Value always returns nil; the root carries no values.
func (*emptyCtx) Value(key interface{}) interface{} { return nil }

var background, todo = new(emptyCtx), new(emptyCtx)

// Background returns the shared root context.
func Background() Context { return background }

// TODO returns a distinct root context for call sites whose proper context
// has not been decided yet.
func TODO() Context { return todo }
go
type (
	// cancelCtx is a context that can be cancelled; cancelling it also
	// cancels every child registered in children.
	cancelCtx struct {
		Context                        // parent context
		mu       sync.Mutex            // guards the fields below
		done     atomic.Value          // lazily created chan struct{}
		children map[canceler]struct{} // set of child contexts
		err      error                 // cancellation reason; nil until cancelled
	}

	// canceler is the interface every directly cancelable context implements.
	canceler interface {
		cancel(removeFromParent bool, err error)
		Done() <-chan struct{}
	}
)
go
// Done returns the channel that is closed when c is cancelled.
//
// The channel is created lazily. A lock-free Load serves the common case
// where it already exists. Otherwise the mutex is taken and the value is
// re-checked (double-checked locking), so only one channel is ever stored
// even when several goroutines race on the first call.
func (c *cancelCtx) Done() <-chan struct{} {
	ch := c.done.Load()
	if ch == nil {
		c.mu.Lock()
		if ch = c.done.Load(); ch == nil {
			ch = make(chan struct{})
			c.done.Store(ch)
		}
		c.mu.Unlock()
	}
	return ch.(chan struct{})
}
go
// cancel marks c as cancelled with reason err, closes its done channel,
// cancels every registered child, and, when removeFromParent is true,
// detaches c from its parent's children set.
//
// Calling it on an already-cancelled context does nothing. err must be
// non-nil; a nil err is an internal bug and panics.
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
	if err == nil {
		panic("context: internal error: missing cancel error")
	}
	c.mu.Lock()
	if c.err != nil {
		c.mu.Unlock()
		return // already cancelled
	}
	c.err = err
	// Close the done channel to wake every listener. If Done() was never
	// called, no channel exists yet, so store one that is already closed.
	d, _ := c.done.Load().(chan struct{})
	if d == nil {
		c.done.Store(closedchan) // closedchan is a pre-closed channel
	} else {
		close(d)
	}
	// Cancel all children while still holding c.mu. They are told not to
	// remove themselves from c because the whole set is discarded just below.
	for child := range c.children {
		child.cancel(false, err)
	}
	c.children = nil
	c.mu.Unlock()
	// Detach from the parent only after c.mu is released.
	// NOTE(review): removeChild is defined elsewhere; presumably it locks the
	// parent, so calling it outside c.mu avoids holding two context locks at
	// once — confirm.
	if removeFromParent {
		removeChild(c.Context, c)
	}
}
go
// WithCancel returns a cancelable child of parent together with the function
// that cancels it. The child is also cancelled when parent is. It panics if
// parent is nil.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
	if parent == nil {
		panic("cannot create context from nil parent")
	}
	child := newCancelCtx(parent)
	// Hook the child into parent's cancellation tree.
	propagateCancel(parent, &child)

	ctx = &child
	cancel = func() { child.cancel(true, Canceled) }
	return ctx, cancel
}
// propagateCancel is the key function that links child to parent so that
// cancelling parent also cancels child. Four cases are handled:
//   - parent can never be cancelled (Done returns nil): nothing to do;
//   - parent is already cancelled: cancel child immediately with its error;
//   - a cancelable ancestor is found: register child in its children set,
//     or cancel child right away if that ancestor has already been cancelled;
//   - otherwise (parent is a custom Context type): start a goroutine that
//     waits for either parent or child to finish.
func propagateCancel(parent Context, child canceler) {
	parentDone := parent.Done()
	if parentDone == nil {
		return // parent is never cancelled
	}

	select {
	case <-parentDone:
		child.cancel(false, parent.Err()) // parent already cancelled
		return
	default:
	}

	// Look up the nearest cancelable ancestor.
	ancestor, found := parentCancelCtx(parent)
	if !found {
		// Unknown parent type: watch it from a goroutine, which exits when
		// either side is done.
		go func() {
			select {
			case <-parent.Done():
				child.cancel(false, parent.Err())
			case <-child.Done():
			}
		}()
		return
	}

	ancestor.mu.Lock()
	if ancestor.err == nil {
		if ancestor.children == nil {
			ancestor.children = make(map[canceler]struct{})
		}
		ancestor.children[child] = struct{}{}
	} else {
		// The ancestor was cancelled after the check above.
		child.cancel(false, ancestor.err)
	}
	ancestor.mu.Unlock()
}
go
// timerCtx is a cancelCtx that is also cancelled when its deadline passes.
// It embeds cancelCtx and so reuses that type's mutex, done channel and
// children bookkeeping.
type timerCtx struct {
	cancelCtx
	timer    *time.Timer // fires the deadline cancellation; guarded by mu
	deadline time.Time   // when the context expires
}

// Deadline reports the fixed deadline; a timerCtx always has one.
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
	deadline, ok = c.deadline, true
	return
}

// cancel cancels the embedded cancelCtx, optionally detaches c from its
// parent, and stops the pending timer so it cannot fire later.
func (c *timerCtx) cancel(removeFromParent bool, err error) {
	// The parent's children set holds this *timerCtx, not the embedded
	// cancelCtx, so the inner cancel must not remove itself; c is removed
	// explicitly instead.
	c.cancelCtx.cancel(false, err)
	if removeFromParent {
		removeChild(c.cancelCtx.Context, c)
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	if t := c.timer; t != nil {
		t.Stop()
		c.timer = nil
	}
}
go
// WithTimeout is shorthand for WithDeadline(parent, time.Now().Add(timeout)).
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
	deadline := time.Now().Add(timeout)
	return WithDeadline(parent, deadline)
}

// WithDeadline returns a child of parent that is cancelled at d, when the
// returned CancelFunc is called, or when parent is cancelled, whichever
// happens first. If parent already has an earlier deadline, d could never
// fire first, so a plain cancelable context is returned instead. It panics if
// parent is nil.
func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
	if parent == nil {
		panic("cannot create context from nil parent")
	}
	if parentDeadline, hasDeadline := parent.Deadline(); hasDeadline && parentDeadline.Before(d) {
		return WithCancel(parent)
	}

	tc := &timerCtx{
		cancelCtx: newCancelCtx(parent),
		deadline:  d,
	}
	propagateCancel(parent, tc)

	remaining := time.Until(d)
	if remaining <= 0 {
		// Deadline already passed: cancel now. The returned CancelFunc then
		// acts on an already-cancelled context.
		tc.cancel(true, DeadlineExceeded)
		return tc, func() { tc.cancel(false, Canceled) }
	}

	tc.mu.Lock()
	defer tc.mu.Unlock()
	// propagateCancel may already have cancelled tc (parent was done), in
	// which case no timer is needed.
	if tc.err == nil {
		tc.timer = time.AfterFunc(remaining, func() {
			tc.cancel(true, DeadlineExceeded)
		})
	}
	return tc, func() { tc.cancel(true, Canceled) }
}
go
// valueCtx is a context carrying a single key/value pair on top of its parent.
type valueCtx struct {
	Context
	key, val interface{}
}

// Value returns the value for key from the nearest valueCtx in the chain that
// holds it. The first context that is not a *valueCtx answers through its own
// Value method. Walking the chain in a loop is equivalent to recursing into
// c.Context.Value.
func (c *valueCtx) Value(key interface{}) interface{} {
	var ctx Context = c
	for {
		vc, ok := ctx.(*valueCtx)
		if !ok {
			return ctx.Value(key)
		}
		if vc.key == key {
			return vc.val
		}
		ctx = vc.Context
	}
}

// WithValue returns a child of parent that carries key→val. It panics if
// parent is nil, key is nil, or key's type is not comparable.
func WithValue(parent Context, key, val interface{}) Context {
	switch {
	case parent == nil:
		panic("cannot create context from nil parent")
	case key == nil:
		panic("nil key")
	case !reflectlite.TypeOf(key).Comparable():
		panic("key is not comparable")
	}
	return &valueCtx{Context: parent, key: key, val: val}
}
go
// Once a channel is closed, every receive blocked on it returns immediately
// with the zero value. That is why a single cancel wakes every goroutine
// waiting on ctx.Done().
// A select with only one case is just a receive (staticcheck S1000), so the
// plain form is used. Note: for a never-cancelled context (e.g. Background),
// Done() is nil and this blocks forever.
<-ctx.Done()
// The context has been cancelled; Err reports why.
fmt.Println(ctx.Err())
plain text
          parent
         /      \
     child1    child2
     /    \
   gc1    gc2
取消 parent → 自动取消 child1, child2, gc1, gc2
取消 child1 → 取消 child1 自身及其子节点 gc1, gc2,不影响 parent 和 child2
go
// done channel 使用 atomic.Value 实现延迟创建
// 只有在第一次调用 Done() 时才创建 channel
// 节省内存,因为很多 context 可能永远不会被监听
go
ctx1 := context.WithValue(ctx, "key1", "val1")
ctx2 := context.WithValue(ctx1, "key2", "val2")
ctx3 := context.WithValue(ctx2, "key3", "val3")
// 查找 "key1" 时:ctx3 → ctx2 → ctx1 (找到)
val := ctx3.Value("key1")
go
// ❌ Do not use context values to pass optional parameters.
func DoSomething(ctx context.Context) {
	// Don't do this: the dependency is hidden from the signature, and the
	// unchecked assertion panics when "timeout" is absent.
	timeout := ctx.Value("timeout").(time.Duration)
}

// ✅ Pass such parameters explicitly.
func DoSomething(ctx context.Context, timeout time.Duration) {
}

// ✅ Use context values only for request-scoped data.
func Handler(ctx context.Context) {
	// Comma-ok assertions: a missing or mistyped value yields "" instead of
	// panicking the handler.
	userID, _ := ctx.Value(userIDKey).(string)   // user ID extracted from the request
	traceID, _ := ctx.Value(traceIDKey).(string) // distributed-tracing ID
	_, _ = userID, traceID                       // placeholder for real request handling
}
go
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
)
// Business scenario: e-commerce order-detail lookup.
// Requirement: the whole request must finish within 3 seconds, including:
//  1. querying basic order info (500ms)
//  2. querying product details (800ms)
//  3. querying shipping info (600ms)

// OrderDetail is the JSON body returned by the order-detail endpoint.
type OrderDetail struct {
	OrderID  string `json:"order_id"`
	Products string `json:"products"`
	Shipping string `json:"shipping"`
}

// simulateQuery stands in for a data-source call that takes latency to
// complete. It returns value once latency has elapsed, or ("", ctx.Err()) if
// ctx is done first.
//
// A stoppable timer is used instead of time.After so that an early
// cancellation releases the timer immediately instead of leaving it pending
// until it would have fired.
func simulateQuery(ctx context.Context, latency time.Duration, value string) (string, error) {
	timer := time.NewTimer(latency)
	defer timer.Stop()
	select {
	case <-timer.C:
		return value, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

// queryOrderInfo simulates loading the order's basic info (~500ms).
func queryOrderInfo(ctx context.Context, orderID string) (string, error) {
	return simulateQuery(ctx, 500*time.Millisecond, fmt.Sprintf("Order-%s", orderID))
}

// queryProductInfo simulates loading the ordered products (~800ms).
func queryProductInfo(ctx context.Context, orderID string) (string, error) {
	return simulateQuery(ctx, 800*time.Millisecond, "iPhone 15 Pro * 1")
}

// queryShippingInfo simulates loading the shipping status (~600ms).
func queryShippingInfo(ctx context.Context, orderID string) (string, error) {
	return simulateQuery(ctx, 600*time.Millisecond, "In Transit")
}
// handleOrderDetail serves /order/detail?order_id=...: it fetches order,
// product and shipping data concurrently and writes them as an OrderDetail
// JSON document.
//
// The request is bounded by a 3-second timeout derived from the request's
// own context. Responses:
//   - 400 if order_id is missing;
//   - 504 Gateway Timeout if the context ends before all three queries finish.
//     The slow party is the upstream data sources, not the client, so 408
//     Request Timeout would be the wrong status;
//   - 500 if a query fails for any other reason;
//   - 200 with the JSON body otherwise.
func handleOrderDetail(w http.ResponseWriter, r *http.Request) {
	// Validate input before starting any timers or goroutines.
	orderID := r.URL.Query().Get("order_id")
	if orderID == "" {
		http.Error(w, "missing order_id", http.StatusBadRequest)
		return
	}

	// The deadline is layered on the request context, so a client disconnect
	// also cancels the queries. cancel releases the timer and stops any query
	// still running when the handler returns early.
	ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
	defer cancel()

	type result struct {
		data string
		err  error
	}
	// Capacity 1 lets each query goroutine deliver its result and exit even
	// after the handler has stopped receiving.
	orderCh := make(chan result, 1)
	productCh := make(chan result, 1)
	shippingCh := make(chan result, 1)

	// Query the three data sources concurrently.
	go func() {
		data, err := queryOrderInfo(ctx, orderID)
		orderCh <- result{data, err}
	}()
	go func() {
		data, err := queryProductInfo(ctx, orderID)
		productCh <- result{data, err}
	}()
	go func() {
		data, err := queryShippingInfo(ctx, orderID)
		shippingCh <- result{data, err}
	}()

	detail := OrderDetail{OrderID: orderID}
	for received := 0; received < 3; received++ {
		select {
		case res := <-orderCh:
			if res.err != nil {
				writeQueryError(ctx, w, "order", res.err)
				return
			}
			// NOTE(review): this replaces the requested ID with the order-info
			// string (e.g. "Order-12345"); confirm that is the intended
			// order_id value in the response.
			detail.OrderID = res.data
		case res := <-productCh:
			if res.err != nil {
				writeQueryError(ctx, w, "product", res.err)
				return
			}
			detail.Products = res.data
		case res := <-shippingCh:
			if res.err != nil {
				writeQueryError(ctx, w, "shipping", res.err)
				return
			}
			detail.Shipping = res.data
		case <-ctx.Done():
			writeTimeout(w, ctx.Err())
			return
		}
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(detail); err != nil {
		// The response may already be partially written, so the failure can
		// only be logged.
		log.Printf("encode order detail: %v", err)
	}
}

// writeTimeout reports that the request context ended (deadline or
// cancellation) before all data was gathered.
func writeTimeout(w http.ResponseWriter, ctxErr error) {
	http.Error(w, fmt.Sprintf("request timeout: %v", ctxErr), http.StatusGatewayTimeout)
}

// writeQueryError reports a failed sub-query. If the request context has
// already ended, the query failed because of that, so the failure is reported
// as a timeout. Otherwise it is an internal error.
func writeQueryError(ctx context.Context, w http.ResponseWriter, source string, err error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		writeTimeout(w, ctxErr)
		return
	}
	http.Error(w, fmt.Sprintf("%s query failed: %v", source, err), http.StatusInternalServerError)
}
// main serves the order-detail endpoint on :8080.
//
// An explicit http.Server is used instead of http.ListenAndServe so that the
// server has header/read/write timeouts. The package-level helper sets none,
// which lets slow or stalled clients hold connections open indefinitely.
// WriteTimeout is kept above the handler's 3s budget so that timeout
// responses can still be delivered.
func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/order/detail", handleOrderDetail)

	srv := &http.Server{
		Addr:              ":8080",
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
	}

	log.Println("Server starting on :8080")
	log.Fatal(srv.ListenAndServe())
}

// Try it: curl "http://localhost:8080/order/detail?order_id=12345"
go
package main
import (
"context"
"database/sql"
"fmt"
"log"
"time"
_ "github.com/go-sql-driver/mysql"
)
// Business scenario: bulk-importing user records.
// Requirements:
//  1. commit one transaction per 1000 rows;
//  2. the caller can cancel the import at any time;
//  3. the whole import is bounded by a 30-minute timeout.

// User is one row to import. ID is not written by the INSERT below;
// presumably the table assigns it — confirm against the schema.
type User struct {
	ID    int
	Name  string
	Email string
}

// importUsers inserts users into the users table in batches of 1000, one
// transaction per batch. It stops with an error as soon as ctx is cancelled,
// the 30-minute overall timeout expires, or any batch fails. A failed batch is
// rolled back; earlier batches stay committed.
func importUsers(ctx context.Context, db *sql.DB, users []User) error {
	// Overall timeout for the whole import.
	ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
	defer cancel()

	const batchSize = 1000
	for start := 0; start < len(users); start += batchSize {
		// Check for cancellation between batches.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("import cancelled: %w", err)
		}

		end := start + batchSize
		if end > len(users) {
			end = len(users)
		}
		if err := importBatch(ctx, db, users[start:end]); err != nil {
			return err
		}
		log.Printf("Imported batch: %d-%d", start, end)
	}
	return nil
}

// importBatch inserts batch inside a single transaction, committing on
// success and rolling back on any failure.
//
// The previous version deferred a rollback guarded by the outer err, which
// was only assigned after the batch function had already returned, and every
// inner err was a shadowing :=. So the rollback never ran and a failed batch
// leaked its transaction and pooled connection. Deferring Rollback
// unconditionally fixes this. After a successful Commit, Rollback is a no-op
// that returns sql.ErrTxDone, which is deliberately ignored.
func importBatch(ctx context.Context, db *sql.DB, batch []User) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	defer tx.Rollback()

	stmt, err := tx.PrepareContext(ctx,
		"INSERT INTO users (name, email) VALUES (?, ?)")
	if err != nil {
		return fmt.Errorf("prepare statement: %w", err)
	}
	defer stmt.Close()

	for _, user := range batch {
		// ExecContext aborts the insert if ctx is cancelled.
		if _, err := stmt.ExecContext(ctx, user.Name, user.Email); err != nil {
			return fmt.Errorf("insert user: %w", err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
// main imports 10000 generated users into MySQL. A background goroutine
// cancels the import after 5 seconds to mimic a user pressing "cancel".
func main() {
	db, err := sql.Open("mysql", "user:pass@tcp(localhost:3306)/testdb")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Generate 10000 synthetic users, numbered from 1.
	users := make([]User, 10000)
	for i := range users {
		n := i + 1
		users[i] = User{
			ID:    n,
			Name:  fmt.Sprintf("User%d", n),
			Email: fmt.Sprintf("user%d@example.com", n),
		}
	}

	// Cancelable context; the goroutine below plays the user who cancels.
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(5 * time.Second)
		log.Println("User cancelled the import")
		cancel()
	}()

	if err = importUsers(ctx, db, users); err != nil {
		log.Printf("Import failed: %v", err)
		return
	}
	log.Println("Import completed successfully")
}
go
package main
import (
"context"
"fmt"
"log"
"math/rand"
"time"
)
// Business scenario: the e-commerce order-placement flow.
// Call chain: API Gateway → Order Service → Inventory Service → Payment Service.
// Requirements:
//  1. propagate traceID for distributed tracing;
//  2. propagate userID for permission checks;
//  3. the whole call chain times out after 5 seconds.

// contextKey is a private key type for context values, so keys set here cannot
// collide with keys from other packages.
type contextKey string

const (
	traceIDKey contextKey = "trace_id"
	userIDKey  contextKey = "user_id"
)

// stringFromContext returns the string stored under key, or "unknown" when the
// key is absent or holds a non-string value. The comma-ok assertion keeps a
// mistyped value from panicking the request.
func stringFromContext(ctx context.Context, key contextKey) string {
	if s, ok := ctx.Value(key).(string); ok {
		return s
	}
	return "unknown"
}

// getTraceID returns the request's trace ID, or "unknown" if none is set.
func getTraceID(ctx context.Context) string {
	return stringFromContext(ctx, traceIDKey)
}

// getUserID returns the caller's user ID, or "unknown" if none is set.
func getUserID(ctx context.Context) string {
	return stringFromContext(ctx, userIDKey)
}

// simulateNetworkDelay sleeps for a random 0–499ms to mimic a remote call.
// It does not observe cancellation.
func simulateNetworkDelay() {
	time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)
}

// simulateNetworkDelayCtx waits a random 0–499ms like simulateNetworkDelay,
// but returns as soon as ctx is done. It then returns ctx.Err(), so nil means
// the context was still live when the wait ended.
func simulateNetworkDelayCtx(ctx context.Context) error {
	timer := time.NewTimer(time.Duration(rand.Intn(500)) * time.Millisecond)
	defer timer.Stop()
	select {
	case <-timer.C:
	case <-ctx.Done():
	}
	return ctx.Err()
}

// checkInventory is the Inventory Service: it checks stock for productID.
// A cancelled or expired ctx aborts the (simulated) remote call immediately
// instead of waiting out the full delay.
func checkInventory(ctx context.Context, productID string, quantity int) error {
	traceID := getTraceID(ctx)
	log.Printf("[Inventory] traceID=%s, checking product=%s, quantity=%d",
		traceID, productID, quantity)

	if err := simulateNetworkDelayCtx(ctx); err != nil {
		return fmt.Errorf("inventory check timeout: %w", err)
	}
	log.Printf("[Inventory] traceID=%s, inventory check passed", traceID)
	return nil
}

// processPayment is the Payment Service: it charges amount to userID.
// Like checkInventory, it stops waiting as soon as ctx is done.
func processPayment(ctx context.Context, userID string, amount float64) error {
	traceID := getTraceID(ctx)
	log.Printf("[Payment] traceID=%s, userID=%s, processing payment amount=%.2f",
		traceID, userID, amount)

	if err := simulateNetworkDelayCtx(ctx); err != nil {
		return fmt.Errorf("payment timeout: %w", err)
	}
	log.Printf("[Payment] traceID=%s, payment successful", traceID)
	return nil
}
// createOrder is the Order Service. It checks stock, takes payment, then
// persists the order, aborting as soon as a step fails or ctx ends. The
// trace and user IDs are read from ctx for logging and for the payment call.
func createOrder(ctx context.Context, order Order) error {
	traceID, userID := getTraceID(ctx), getUserID(ctx)
	log.Printf("[Order] traceID=%s, userID=%s, creating order=%+v",
		traceID, userID, order)

	// Step 1: inventory.
	err := checkInventory(ctx, order.ProductID, order.Quantity)
	if err != nil {
		return fmt.Errorf("inventory check failed: %w", err)
	}

	// Step 2: payment.
	if err = processPayment(ctx, userID, order.Amount); err != nil {
		return fmt.Errorf("payment failed: %w", err)
	}

	// Step 3: persist the order record, then confirm the deadline held.
	simulateNetworkDelay()
	if err = ctx.Err(); err != nil {
		return fmt.Errorf("order creation timeout: %w", err)
	}
	log.Printf("[Order] traceID=%s, order created successfully", traceID)
	return nil
}

// Order is the payload for placing an order.
type Order struct {
	ProductID string
	Quantity  int
	Amount    float64
}
// handleCreateOrder is the API Gateway entry point. It creates a fresh trace
// ID and a 5-second timeout for the whole call chain, attaches the trace and
// user IDs to the context, and hands off to the Order Service.
func handleCreateOrder(userID string, order Order) error {
	traceID := fmt.Sprintf("trace-%d", time.Now().UnixNano())

	timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	ctx := context.WithValue(
		context.WithValue(timeoutCtx, traceIDKey, traceID),
		userIDKey, userID,
	)

	log.Printf("[Gateway] traceID=%s, userID=%s, received order request", traceID, userID)

	err := createOrder(ctx, order)
	if err != nil {
		log.Printf("[Gateway] traceID=%s, order creation failed: %v", traceID, err)
		return err
	}
	log.Printf("[Gateway] traceID=%s, order creation succeeded", traceID)
	return nil
}
// main places one sample order through the simulated gateway.
//
// There is no rand.Seed call: it is deprecated, and since Go 1.20 the global
// math/rand source is seeded randomly at startup.
func main() {
	order := Order{
		ProductID: "PROD-12345",
		Quantity:  2,
		Amount:    299.99,
	}
	if err := handleCreateOrder("USER-789", order); err != nil {
		log.Printf("Failed to create order: %v", err)
	}
}

// Example output:
// [Gateway] traceID=trace-1234567890, userID=USER-789, received order request
// [Order] traceID=trace-1234567890, userID=USER-789, creating order=...
// [Inventory] traceID=trace-1234567890, checking product=PROD-12345, quantity=2
// [Inventory] traceID=trace-1234567890, inventory check passed
// [Payment] traceID=trace-1234567890, userID=USER-789, processing payment amount=299.99
// [Payment] traceID=trace-1234567890, payment successful
// [Order] traceID=trace-1234567890, order created successfully
// [Gateway] traceID=trace-1234567890, order creation succeeded
go
package main
import (
"context"
"fmt"
"log"
"math/rand"
"sync"
"time"
)
// Business scenario: a web crawler.
// Requirements:
//  1. crawl several pages concurrently;
//  2. if any task fails, cancel all tasks;
//  3. collect every result once all tasks have finished.

// CrawlResult is the outcome for one URL: Title on success, Err on failure.
type CrawlResult struct {
	URL   string
	Title string
	Err   error
}

// crawlPage simulates fetching url. It takes 500–1499ms and fails about 10%
// of the time, unless ctx is done first. The returned error always equals
// the result's Err field.
func crawlPage(ctx context.Context, url string) (*CrawlResult, error) {
	res := &CrawlResult{URL: url}
	latency := time.Duration(500+rand.Intn(1000)) * time.Millisecond

	select {
	case <-ctx.Done():
		res.Err = ctx.Err()
	case <-time.After(latency):
		if rand.Float32() < 0.1 {
			// Simulated 10% failure rate.
			res.Err = fmt.Errorf("failed to crawl %s", url)
		} else {
			res.Title = fmt.Sprintf("Title of %s", url)
		}
	}
	return res, res.Err
}
// crawlWebsites crawls every URL concurrently under a 10-second overall
// timeout. The first failure cancels all remaining crawls. All results are
// returned, including failed and cancelled ones, in completion order, along
// with the first error wrapped as "crawl failed".
func crawlWebsites(urls []string) ([]*CrawlResult, error) {
	// One context carries both the overall timeout and the
	// cancel-everything-on-first-error signal; calling cancel yields
	// context.Canceled for every crawl still in flight.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	var (
		wg       sync.WaitGroup
		mu       sync.Mutex // guards results
		results  []*CrawlResult
		failOnce sync.Once
		firstErr error // written once inside failOnce, read after wg.Wait
	)

	worker := func(url string) {
		defer wg.Done()
		log.Printf("Starting crawl: %s", url)
		res, err := crawlPage(ctx, url)

		mu.Lock()
		results = append(results, res)
		mu.Unlock()

		if err == nil {
			log.Printf("Completed crawl: %s - %s", url, res.Title)
			return
		}
		// Record only the first error and cancel all other tasks.
		failOnce.Do(func() {
			firstErr = err
			log.Printf("Error occurred, cancelling all tasks: %v", err)
			cancel()
		})
	}

	for _, url := range urls {
		wg.Add(1)
		go worker(url)
	}
	wg.Wait()

	if firstErr == nil {
		return results, nil
	}
	return results, fmt.Errorf("crawl failed: %w", firstErr)
}
// main crawls a fixed list of sample pages and logs each outcome.
//
// There is no rand.Seed call: it is deprecated, and since Go 1.20 the global
// math/rand source is seeded randomly at startup.
func main() {
	urls := []string{
		"https://example.com/page1",
		"https://example.com/page2",
		"https://example.com/page3",
		"https://example.com/page4",
		"https://example.com/page5",
	}

	log.Println("Starting crawl...")
	results, err := crawlWebsites(urls)

	log.Println("\n=== Results ===")
	for _, r := range results {
		if r.Err != nil {
			log.Printf("❌ %s: %v", r.URL, r.Err)
		} else {
			log.Printf("✅ %s: %s", r.URL, r.Title)
		}
	}

	if err != nil {
		log.Printf("\nCrawl failed: %v", err)
	} else {
		log.Println("\nAll crawls completed successfully!")
	}
}