Context源码学习和最佳实践

Published on in Technology with 0 views and 0 comments
  • Go Context 源码深度解析与最佳实践

  • 目录

    1. Context 核心接口设计
    2. 源码结构解析
    3. 四大核心实现
    4. 取消机制的实现原理
    5. 值传递的设计
    6. 最佳实践与业务场景
  • 1. Context 核心接口设计

    • go
      // context.go - 核心接口定义
      type Context interface {
          // 返回 context 被取消的时间点
          Deadline() (deadline time.Time, ok bool)
      
          // 返回一个 channel,当 context 被取消时会被关闭
          // 这是实现取消通知的核心机制
          Done() <-chan struct{}
      
          // 返回 context 被取消的原因
          Err() error
      
          // 根据 key 获取值
          Value(key interface{}) interface{}
      }
      
    • 设计思想:
      • 返回只读 channel,利用 "关闭 channel 会通知所有接收者" 的特性实现广播
      • 接口最小化,只有 4 个方法,符合 Go 的简洁哲学
      • 不可变性:Context 是不可修改的,只能通过衍生创建新的 Context
  • 2. 源码结构解析

    • Context 包提供了 4 种基础实现:
    • go
      // emptyCtx - 根节点
      type emptyCtx int
      
      // cancelCtx - 可取消的 context
      type cancelCtx struct {
          Context
          mu       sync.Mutex
          done     atomic.Value
          children map[canceler]struct{}
          err      error
      }
      
      // timerCtx - 带超时的 context
      type timerCtx struct {
          cancelCtx
          timer *time.Timer
          deadline time.Time
      }
      
      // valueCtx - 携带键值对的 context
      type valueCtx struct {
          Context
          key, val interface{}
      }
      
    • 核心关系图:
    • plain text
                        emptyCtx (Background/TODO)
                            │
              ┌─────────────┼─────────────┐
              │             │             │
          cancelCtx     timerCtx      valueCtx
              │             │
              └─────────────┘
               (timerCtx 继承 cancelCtx)
      
  • 3. 四大核心实现

  • 3.1 emptyCtx - 根节点实现

    • go
      type emptyCtx int
      
      func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
          return // 返回零值
      }
      
      func (*emptyCtx) Done() <-chan struct{} {
          return nil // 永远不会被取消
      }
      
      func (*emptyCtx) Err() error {
          return nil
      }
      
      func (*emptyCtx) Value(key interface{}) interface{} {
          return nil
      }
      
      var (
          background = new(emptyCtx)
          todo       = new(emptyCtx)
      )
      
      func Background() Context {
          return background
      }
      
      func TODO() Context {
          return todo
      }
      
    • 设计要点:
      • 是最简单的实现,所有方法都返回零值
      • 用于主函数、初始化、测试等顶层 context
      • 用于不确定该用什么 context 的场景
  • 3.2 cancelCtx - 可取消实现

    • go
      type cancelCtx struct {
          Context
      
          mu       sync.Mutex            // 保护以下字段
          done     atomic.Value          // 延迟创建,存储 chan struct{}
          children map[canceler]struct{} // 存储所有子 context
          err      error                 // 取消时的错误信息
      }
      
      // canceler 接口:可被取消的 context 必须实现
      type canceler interface {
          cancel(removeFromParent bool, err error)
          Done() <-chan struct{}
      }
      
    • 核心方法 - Done():
    • go
      func (c *cancelCtx) Done() <-chan struct{} {
          // 使用 atomic.Value 实现延迟初始化
          d := c.done.Load()
          if d != nil {
              return d.(chan struct{})
          }
      
          c.mu.Lock()
          defer c.mu.Unlock()
      
          // double-check,避免重复创建
          d = c.done.Load()
          if d == nil {
              d = make(chan struct{})
              c.done.Store(d)
          }
          return d.(chan struct{})
      }
      
    • 核心方法 - cancel():
    • go
      func (c *cancelCtx) cancel(removeFromParent bool, err error) {
          if err == nil {
              panic("context: internal error: missing cancel error")
          }
      
          c.mu.Lock()
          if c.err != nil {
              c.mu.Unlock()
              return // 已经被取消了
          }
      
          c.err = err
      
          // 关闭 done channel,通知所有监听者
          d, _ := c.done.Load().(chan struct{})
          if d == nil {
              c.done.Store(closedchan) // closedchan 是预先关闭的 channel
          } else {
              close(d)
          }
      
          // 递归取消所有子 context
          for child := range c.children {
              child.cancel(false, err)
          }
          c.children = nil
          c.mu.Unlock()
      
          // 从父节点移除自己
          if removeFromParent {
              removeChild(c.Context, c)
          }
      }
      
    • WithCancel 函数:
    • go
      func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
          if parent == nil {
              panic("cannot create context from nil parent")
          }
      
          c := newCancelCtx(parent)
          propagateCancel(parent, &c) // 建立父子关系
          return &c, func() { c.cancel(true, Canceled) }
      }
      
      // 关键函数:建立取消传播关系
      func propagateCancel(parent Context, child canceler) {
          done := parent.Done()
          if done == nil {
              return // 父 context 永远不会被取消
          }
      
          select {
          case <-done:
              // 父 context 已经被取消
              child.cancel(false, parent.Err())
              return
          default:
          }
      
          // 找到最近的可取消的祖先
          if p, ok := parentCancelCtx(parent); ok {
              p.mu.Lock()
              if p.err != nil {
                  // 父 context 已经被取消
                  child.cancel(false, p.err)
              } else {
                  // 将自己添加到父 context 的 children 中
                  if p.children == nil {
                      p.children = make(map[canceler]struct{})
                  }
                  p.children[child] = struct{}{}
              }
              p.mu.Unlock()
          } else {
              // 父 context 是自定义类型,启动 goroutine 监听
              go func() {
                  select {
                  case <-parent.Done():
                      child.cancel(false, parent.Err())
                  case <-child.Done():
                  }
              }()
          }
      }
      
  • 3.3 timerCtx - 超时取消实现

    • go
      type timerCtx struct {
          cancelCtx
          timer *time.Timer    // 定时器
          deadline time.Time   // 截止时间
      }
      
      func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
          return c.deadline, true
      }
      
      func (c *timerCtx) cancel(removeFromParent bool, err error) {
          // 先调用 cancelCtx 的 cancel
          c.cancelCtx.cancel(false, err)
      
          if removeFromParent {
              removeChild(c.cancelCtx.Context, c)
          }
      
          c.mu.Lock()
          if c.timer != nil {
              c.timer.Stop() // 停止定时器
              c.timer = nil
          }
          c.mu.Unlock()
      }
      
    • WithTimeout 和 WithDeadline:
    • go
      func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
          return WithDeadline(parent, time.Now().Add(timeout))
      }
      
      func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
          if parent == nil {
              panic("cannot create context from nil parent")
          }
      
          // 如果父 context 的 deadline 更早,直接返回 cancelCtx
          if cur, ok := parent.Deadline(); ok && cur.Before(d) {
              return WithCancel(parent)
          }
      
          c := &timerCtx{
              cancelCtx: newCancelCtx(parent),
              deadline:  d,
          }
          propagateCancel(parent, c)
      
          dur := time.Until(d)
          if dur <= 0 {
              c.cancel(true, DeadlineExceeded) // 已经过期
              return c, func() { c.cancel(false, Canceled) }
          }
      
          c.mu.Lock()
          defer c.mu.Unlock()
          if c.err == nil {
              // 设置定时器,到期自动取消
              c.timer = time.AfterFunc(dur, func() {
                  c.cancel(true, DeadlineExceeded)
              })
          }
      
          return c, func() { c.cancel(true, Canceled) }
      }
      
  • 3.4 valueCtx - 值传递实现

    • go
      type valueCtx struct {
          Context
          key, val interface{}
      }
      
      func (c *valueCtx) Value(key interface{}) interface{} {
          if c.key == key {
              return c.val
          }
          // 递归向上查找
          return c.Context.Value(key)
      }
      
      func WithValue(parent Context, key, val interface{}) Context {
          if parent == nil {
              panic("cannot create context from nil parent")
          }
          if key == nil {
              panic("nil key")
          }
          if !reflectlite.TypeOf(key).Comparable() {
              panic("key is not comparable")
          }
      
          return &valueCtx{parent, key, val}
      }
      
    • 设计要点:
      • valueCtx 是链表结构,通过递归查找实现值的继承
      • 每个 valueCtx 只存储一个键值对
      • 查找时间复杂度 O(n),n 是 context 嵌套层数
  • 4. 取消机制的实现原理

  • 4.1 核心机制

    • Context 的取消机制基于 channel 关闭的广播特性
    • go
      // 当 channel 被关闭时,所有阻塞在该 channel 上的接收操作会立即返回零值
      done := ctx.Done()
      select {
      case <-done:
          // context 被取消了
          fmt.Println(ctx.Err())
      }
      
  • 4.2 树形传播

    • Context 形成树形结构,取消会自动传播:
    • plain text
              parent
             /      \
          child1   child2
          /    \
        gc1    gc2
      
      取消 parent → 自动取消 child1, child2, gc1, gc2
      取消 child1 → 只取消 gc1, gc2,不影响 parent 和 child2
      
  • 4.3 延迟初始化优化

    • go
      // done channel 使用 atomic.Value 实现延迟创建
      // 只有在第一次调用 Done() 时才创建 channel
      // 节省内存,因为很多 context 可能永远不会被监听
      
  • 5. 值传递的设计

  • 5.1 链表查找

    • go
      ctx1 := context.WithValue(ctx, "key1", "val1")
      ctx2 := context.WithValue(ctx1, "key2", "val2")
      ctx3 := context.WithValue(ctx2, "key3", "val3")
      
      // 查找 "key1" 时:ctx3 → ctx2 → ctx1 (找到)
      val := ctx3.Value("key1")
      
  • 5.2 为什么不用 map?

    • 不可变性:每次 WithValue 创建新的 context,不修改原有的
    • 并发安全:链表结构天然线程安全,无需加锁
    • 轻量级:大多数 context 只携带少量值,链表足够高效
  • 5.3 值传递的限制

    • go
      // ❌ 不要用于传递可选参数
      func DoSomething(ctx context.Context) {
          // 不要这样做
          timeout := ctx.Value("timeout").(time.Duration)
      }
      
      // ✅ 应该显式传参
      func DoSomething(ctx context.Context, timeout time.Duration) {
      }
      
      // ✅ 只用于传递请求域数据
      func Handler(ctx context.Context) {
          userID := ctx.Value(userIDKey).(string) // 从请求中提取的用户ID
          traceID := ctx.Value(traceIDKey).(string) // 链路追踪ID
      }
      
  • 6. 最佳实践与业务场景

  • 场景 1:HTTP 服务器 - 请求超时控制

    • go
      package main
      
      import (
          "context"
          "encoding/json"
          "fmt"
          "log"
          "net/http"
          "time"
      )
      
      // 业务场景:电商系统查询订单详情
      // 需求:整个请求必须在 3 秒内完成,包括:
      // 1. 查询订单基本信息(500ms)
      // 2. 查询商品详情(800ms)
      // 3. 查询物流信息(600ms)
      
      type OrderDetail struct {
          OrderID  string `json:"order_id"`
          Products string `json:"products"`
          Shipping string `json:"shipping"`
      }
      
      // 模拟数据库查询
      func queryOrderInfo(ctx context.Context, orderID string) (string, error) {
          // 使用 context 控制超时
          select {
          case <-time.After(500 * time.Millisecond):
              return fmt.Sprintf("Order-%s", orderID), nil
          case <-ctx.Done():
              return "", ctx.Err()
          }
      }
      
      func queryProductInfo(ctx context.Context, orderID string) (string, error) {
          select {
          case <-time.After(800 * time.Millisecond):
              return "iPhone 15 Pro * 1", nil
          case <-ctx.Done():
              return "", ctx.Err()
          }
      }
      
      func queryShippingInfo(ctx context.Context, orderID string) (string, error) {
          select {
          case <-time.After(600 * time.Millisecond):
              return "In Transit", nil
          case <-ctx.Done():
              return "", ctx.Err()
          }
      }
      
      func handleOrderDetail(w http.ResponseWriter, r *http.Request) {
          // 从请求中提取或创建 context
          ctx := r.Context()
      
          // 设置 3 秒超时
          ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
          defer cancel() // 确保释放资源
      
          orderID := r.URL.Query().Get("order_id")
          if orderID == "" {
              http.Error(w, "missing order_id", http.StatusBadRequest)
              return
          }
      
          // 使用 channel 收集结果
          type result struct {
              data string
              err  error
          }
      
          orderCh := make(chan result, 1)
          productCh := make(chan result, 1)
          shippingCh := make(chan result, 1)
      
          // 并发查询三个数据源
          go func() {
              data, err := queryOrderInfo(ctx, orderID)
              orderCh <- result{data, err}
          }()
      
          go func() {
              data, err := queryProductInfo(ctx, orderID)
              productCh <- result{data, err}
          }()
      
          go func() {
              data, err := queryShippingInfo(ctx, orderID)
              shippingCh <- result{data, err}
          }()
      
          // 收集结果
          detail := OrderDetail{OrderID: orderID}
      
          for i := 0; i < 3; i++ {
              select {
              case res := <-orderCh:
                  if res.err != nil {
                      http.Error(w, fmt.Sprintf("order query failed: %v", res.err), http.StatusInternalServerError)
                      return
                  }
                  detail.OrderID = res.data
      
              case res := <-productCh:
                  if res.err != nil {
                      http.Error(w, fmt.Sprintf("product query failed: %v", res.err), http.StatusInternalServerError)
                      return
                  }
                  detail.Products = res.data
      
              case res := <-shippingCh:
                  if res.err != nil {
                      http.Error(w, fmt.Sprintf("shipping query failed: %v", res.err), http.StatusInternalServerError)
                      return
                  }
                  detail.Shipping = res.data
      
              case <-ctx.Done():
                  // 超时或取消
                  http.Error(w, fmt.Sprintf("request timeout: %v", ctx.Err()), http.StatusRequestTimeout)
                  return
              }
          }
      
          // 返回结果
          w.Header().Set("Content-Type", "application/json")
          json.NewEncoder(w).Encode(detail)
      }
      
      func main() {
          http.HandleFunc("/order/detail", handleOrderDetail)
          log.Println("Server starting on :8080")
          log.Fatal(http.ListenAndServe(":8080", nil))
      }
      
      // 测试:curl "http://localhost:8080/order/detail?order_id=12345"
      
    • 关键点:
      • 使用 获取请求的 context,自动绑定请求生命周期
      • 确保 context 资源被释放
      • 所有子任务共享同一个 context,任一超时全部取消
  • 场景 2:数据库事务 - 手动取消控制

    • go
      package main
      
      import (
          "context"
          "database/sql"
          "fmt"
          "log"
          "time"
      
          _ "github.com/go-sql-driver/mysql"
      )
      
      // 业务场景:批量导入用户数据
      // 需求:
      // 1. 每 1000 条数据提交一次事务
      // 2. 用户可以随时取消导入
      // 3. 整个导入有 30 分钟超时限制
      
      type User struct {
          ID    int
          Name  string
          Email string
      }
      
      func importUsers(ctx context.Context, db *sql.DB, users []User) error {
          // 设置整体超时
          ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
          defer cancel()
      
          const batchSize = 1000
      
          for i := 0; i < len(users); i += batchSize {
              // 检查是否被取消
              select {
              case <-ctx.Done():
                  return fmt.Errorf("import cancelled: %w", ctx.Err())
              default:
              }
      
              end := i + batchSize
              if end > len(users) {
                  end = len(users)
              }
      
              batch := users[i:end]
      
              // 开始事务
              tx, err := db.BeginTx(ctx, nil)
              if err != nil {
                  return fmt.Errorf("begin transaction: %w", err)
              }
      
              // 使用 defer 确保事务回滚或提交
              err = func() error {
                  defer func() {
                      if err != nil {
                          tx.Rollback()
                      }
                  }()
      
                  stmt, err := tx.PrepareContext(ctx, 
                      "INSERT INTO users (name, email) VALUES (?, ?)")
                  if err != nil {
                      return fmt.Errorf("prepare statement: %w", err)
                  }
                  defer stmt.Close()
      
                  for _, user := range batch {
                      // 每次插入都检查 context
                      _, err := stmt.ExecContext(ctx, user.Name, user.Email)
                      if err != nil {
                          return fmt.Errorf("insert user: %w", err)
                      }
                  }
      
                  // 提交事务
                  if err := tx.Commit(); err != nil {
                      return fmt.Errorf("commit transaction: %w", err)
                  }
      
                  log.Printf("Imported batch: %d-%d", i, end)
                  return nil
              }()
      
              if err != nil {
                  return err
              }
          }
      
          return nil
      }
      
      func main() {
          db, err := sql.Open("mysql", "user:pass@tcp(localhost:3306)/testdb")
          if err != nil {
              log.Fatal(err)
          }
          defer db.Close()
      
          // 模拟生成 10000 个用户
          users := make([]User, 10000)
          for i := range users {
              users[i] = User{
                  ID:    i + 1,
                  Name:  fmt.Sprintf("User%d", i+1),
                  Email: fmt.Sprintf("user%d@example.com", i+1),
              }
          }
      
          // 创建可取消的 context
          ctx, cancel := context.WithCancel(context.Background())
      
          // 模拟用户在 5 秒后取消
          go func() {
              time.Sleep(5 * time.Second)
              log.Println("User cancelled the import")
              cancel()
          }()
      
          // 执行导入
          if err := importUsers(ctx, db, users); err != nil {
              log.Printf("Import failed: %v", err)
          } else {
              log.Println("Import completed successfully")
          }
      }
      
    • 关键点:
      • 使用 让事务支持取消
      • 在循环中定期检查
      • 使用 确保事务正确回滚
  • 场景 3:微服务调用链 - Context 传递

    • go
      package main
      
      import (
          "context"
          "fmt"
          "log"
          "math/rand"
          "time"
      )
      
      // 业务场景:电商下单流程
      // 调用链:API Gateway → Order Service → Inventory Service → Payment Service
      // 需求:
      // 1. 传递 traceID 用于链路追踪
      // 2. 传递 userID 用于权限验证
      // 3. 整个调用链 5 秒超时
      
      // 定义 context key 的类型,避免冲突
      type contextKey string
      
      const (
          traceIDKey contextKey = "trace_id"
          userIDKey  contextKey = "user_id"
      )
      
      // 从 context 中提取值的辅助函数
      func getTraceID(ctx context.Context) string {
          if v := ctx.Value(traceIDKey); v != nil {
              return v.(string)
          }
          return "unknown"
      }
      
      func getUserID(ctx context.Context) string {
          if v := ctx.Value(userIDKey); v != nil {
              return v.(string)
          }
          return "unknown"
      }
      
      // 模拟网络延迟
      func simulateNetworkDelay() {
          time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)
      }
      
      // Inventory Service - 库存服务
      func checkInventory(ctx context.Context, productID string, quantity int) error {
          traceID := getTraceID(ctx)
          log.Printf("[Inventory] traceID=%s, checking product=%s, quantity=%d", 
              traceID, productID, quantity)
      
          simulateNetworkDelay()
      
          select {
          case <-ctx.Done():
              return fmt.Errorf("inventory check timeout: %w", ctx.Err())
          default:
              log.Printf("[Inventory] traceID=%s, inventory check passed", traceID)
              return nil
          }
      }
      
      // Payment Service - 支付服务
      func processPayment(ctx context.Context, userID string, amount float64) error {
          traceID := getTraceID(ctx)
          log.Printf("[Payment] traceID=%s, userID=%s, processing payment amount=%.2f", 
              traceID, userID, amount)
      
          simulateNetworkDelay()
      
          select {
          case <-ctx.Done():
              return fmt.Errorf("payment timeout: %w", ctx.Err())
          default:
              log.Printf("[Payment] traceID=%s, payment successful", traceID)
              return nil
          }
      }
      
      // Order Service - 订单服务
      func createOrder(ctx context.Context, order Order) error {
          traceID := getTraceID(ctx)
          userID := getUserID(ctx)
      
          log.Printf("[Order] traceID=%s, userID=%s, creating order=%+v", 
              traceID, userID, order)
      
          // 1. 检查库存
          if err := checkInventory(ctx, order.ProductID, order.Quantity); err != nil {
              return fmt.Errorf("inventory check failed: %w", err)
          }
      
          // 2. 处理支付
          if err := processPayment(ctx, userID, order.Amount); err != nil {
              return fmt.Errorf("payment failed: %w", err)
          }
      
          // 3. 创建订单记录
          simulateNetworkDelay()
      
          select {
          case <-ctx.Done():
              return fmt.Errorf("order creation timeout: %w", ctx.Err())
          default:
              log.Printf("[Order] traceID=%s, order created successfully", traceID)
              return nil
          }
      }
      
      type Order struct {
          ProductID string
          Quantity  int
          Amount    float64
      }
      
      // API Gateway - 入口
      func handleCreateOrder(userID string, order Order) error {
          // 生成 traceID
          traceID := fmt.Sprintf("trace-%d", time.Now().UnixNano())
      
          // 创建带超时的 context
          ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
          defer cancel()
      
          // 注入 traceID 和 userID
          ctx = context.WithValue(ctx, traceIDKey, traceID)
          ctx = context.WithValue(ctx, userIDKey, userID)
      
          log.Printf("[Gateway] traceID=%s, userID=%s, received order request", traceID, userID)
      
          // 调用订单服务
          if err := createOrder(ctx, order); err != nil {
              log.Printf("[Gateway] traceID=%s, order creation failed: %v", traceID, err)
              return err
          }
      
          log.Printf("[Gateway] traceID=%s, order creation succeeded", traceID)
          return nil
      }
      
      func main() {
          rand.Seed(time.Now().UnixNano())
      
          order := Order{
              ProductID: "PROD-12345",
              Quantity:  2,
              Amount:    299.99,
          }
      
          if err := handleCreateOrder("USER-789", order); err != nil {
              log.Printf("Failed to create order: %v", err)
          }
      }
      
      // 输出示例:
      // [Gateway] traceID=trace-1234567890, userID=USER-789, received order request
      // [Order] traceID=trace-1234567890, userID=USER-789, creating order=...
      // [Inventory] traceID=trace-1234567890, checking product=PROD-12345, quantity=2
      // [Inventory] traceID=trace-1234567890, inventory check passed
      // [Payment] traceID=trace-1234567890, userID=USER-789, processing payment amount=299.99
      // [Payment] traceID=trace-1234567890, payment successful
      // [Order] traceID=trace-1234567890, order created successfully
      // [Gateway] traceID=trace-1234567890, order creation succeeded
      
    • 关键点:
      • 定义专用的 key 类型()避免冲突
      • 只传递请求域数据(traceID, userID),不传递业务参数
      • context 在整个调用链中传递,实现超时和取消的联动
  • 场景 4:并发任务协调 - WaitGroup + Context

    • go
      package main
      
      import (
          "context"
          "fmt"
          "log"
          "math/rand"
          "sync"
          "time"
      )
      
      // 业务场景:爬虫系统
      // 需求:
      // 1. 并发爬取多个网页
      // 2. 任意一个任务失败,取消所有任务
      // 3. 所有任务完成后汇总结果
      
      type CrawlResult struct {
          URL   string
          Title string
          Err   error
      }
      
      func crawlPage(ctx context.Context, url string) (*CrawlResult, error) {
          result := &CrawlResult{URL: url}
      
          // 模拟网络请求
          select {
          case <-time.After(time.Duration(500+rand.Intn(1000)) * time.Millisecond):
              // 模拟 10% 的失败率
              if rand.Float32() < 0.1 {
                  result.Err = fmt.Errorf("failed to crawl %s", url)
                  return result, result.Err
              }
              result.Title = fmt.Sprintf("Title of %s", url)
              return result, nil
      
          case <-ctx.Done():
              result.Err = ctx.Err()
              return result, result.Err
          }
      }
      
      func crawlWebsites(urls []string) ([]*CrawlResult, error) {
          // 创建可取消的 context
          ctx, cancel := context.WithCancel(context.Background())
          defer cancel()
      
          // 设置整体超时
          ctx, timeoutCancel := context.WithTimeout(ctx, 10*time.Second)
          defer timeoutCancel()
      
          var (
              wg      sync.WaitGroup
              mu      sync.Mutex
              results []*CrawlResult
              errOnce sync.Once
              firstErr error
          )
      
          // 并发爬取
          for _, url := range urls {
              wg.Add(1)
              go func(url string) {
                  defer wg.Done()
      
                  log.Printf("Starting crawl: %s", url)
                  result, err := crawlPage(ctx, url)
      
                  mu.Lock()
                  results = append(results, result)
                  mu.Unlock()
      
                  if err != nil {
                      // 记录第一个错误,并取消所有任务
                      errOnce.Do(func() {
                          firstErr = err
                          log.Printf("Error occurred, cancelling all tasks: %v", err)
                          cancel()
                      })
                      return
                  }
      
                  log.Printf("Completed crawl: %s - %s", url, result.Title)
              }(url)
          }
      
          // 等待所有 goroutine 完成
          wg.Wait()
      
          if firstErr != nil {
              return results, fmt.Errorf("crawl failed: %w", firstErr)
          }
      
          return results, nil
      }
      
      func main() {
          rand.Seed(time.Now().UnixNano())
      
          urls := []string{
              "https://example.com/page1",
              "https://example.com/page2",
              "https://example.com/page3",
              "https://example.com/page4",
              "https://example.com/page5",
          }
      
          log.Println("Starting crawl...")
          results, err := crawlWebsites(urls)
      
          log.Println("\n=== Results ===")
          for _, r := range results {
              if r.Err != nil {
                  log.Printf("❌ %s: %v", r.URL, r.Err)
              } else {
                  log.Printf("✅ %s: %s", r.URL, r.Title)
              }
          }
      
          if err != nil {
              log.Printf("\nCrawl failed: %v", err)
          } else {
              log.Println("\nAll crawls completed successfully!")
          }
      }
      

标题:Context源码学习和最佳实践
作者:wangzhaoo
地址:http://www.bangnimang.top/go-context