【golang】sync.WaitGroup详解

1、前言

Go语言在设计上对同步(Synchronization,数据同步和线程同步)提供大量的支持,好比 goroutine和channel同步原语,库层面有

  
- sync:提供基本的同步原语(好比Mutex、RWMutex、Locker)和 工具类(Once、WaitGroup、Cond、Pool、Map)
- sync/atomic:提供变量的原子操做(基于硬件指令 compare-and-swap)

-- 引用自《Golang package sync 剖析(一): sync.Once》数据库

上一期中,咱们介绍了 sync.Once 如何保障 exactly once 语义,本期文章咱们介绍 package sync 下的另外一个工具类:sync.WaitGroupsegmentfault

2、为何须要 WaitGroup

想象一个场景:咱们有一个用户画像服务,当一个请求到来时,须要函数

  1. 从 request 里解析出 user_id 和 画像维度参数
  2. 根据 user_id 从 ABCDE 五个子服务(数据库服务、存储服务、rpc服务等)拉取不一样维度的信息
  3. 将读取的信息进行整合,返回给调用方

假设 ABCDE 五个服务的响应时间 p99 是 20~50ms 之间。若是咱们顺序调用 ABCDE 读取信息,不考虑数据整合消耗时间,服务端总体响应时间 p99 是:工具

sum(A, B, C, D, E) => [100ms, 250ms]

先不说业务上能不能接受,响应时间上显然有很大的优化空间。最直观的优化方向就是,取数逻辑的总时间消耗:oop

sum(A, B, C, D, E) -> max(A, B, C, D, E)

具体到 coding 上,咱们须要并行调用 ABCDE 五个子服务,待调用所有返回之后,进行数据整合。如何保障所有返回呢?优化

此时,sync.WaitGroup 闪耀登场。ui

3、WaitGroup 用法

官方文档对 WaitGroup 的描述是:一个 WaitGroup 对象能够等待一组协程结束。使用方法是:atom

  1. main协程经过调用 wg.Add(delta int) 设置worker协程的个数,而后建立worker协程;
  2. worker协程执行结束之后,都要调用 wg.Done()
  3. main协程调用 wg.Wait() 且被block,直到全部worker协程所有执行结束后返回。

这里先看一个典型的例子:线程

// src/cmd/compile/internal/ssa/gen/main.go
func  main() {
  // 省略部分代码 ...
  var wg sync.WaitGroup
  for _, task := range tasks {
    task  := task
    wg.Add(1)
    go func() {
      task()
      wg.Done()
    }()
  }
  wg.Wait()
  // 省略部分代码...
}

这个例子具有了 WaitGroup 正确使用的大部分要素,包括:设计

  1. wg.Done 必须在全部 wg.Add 以后执行,因此要保证两个函数都在main协程中调用;
  2. wg.Done 在 worker协程里调用,尤为要保证调用一次,不能由于 panic 或任何缘由致使没有执行(建议使用 defer wg.Done());
  3. wg.Donewg.Wait 在时序上是没有前后。

细心的朋友可能会发现一行很是诡异的代码:

task  := task

Go 对 array/slice 进行遍历时,runtime 会把 task[i] 拷贝到 task 的内存地址,下标 i 会变,而 task 的内存地址不会变。若是不进行此次赋值操做,全部 goroutine 可能读到的都是最后一个task。为了让你们有一个直观的感受,咱们用下面这段代码作实验:

package main

import (
  "fmt"
  "unsafe"
)

func main() {
  tasks := []func(){
    func() { fmt.Printf("1. ") },
    func() { fmt.Printf("2. ") },
  }

  for idx, task := range tasks {
    task()
    fmt.Printf("遍历 = %v, ", unsafe.Pointer(&task))
    fmt.Printf("下标 = %v, ", unsafe.Pointer(&tasks[idx]))
    task  := task
    fmt.Printf("局部变量 = %vn", unsafe.Pointer(&task))
  }
}

这段代码的打印结果是:

1. 遍历 = 0x40c140, 下标 = 0x40c138, 局部变量 = 0x40c150
2. 遍历 = 0x40c140, 下标 = 0x40c13c, 局部变量 = 0x40c158

不一样机器上执行打印结果有所不一样,但共同点是:

  1. 遍历时,数据的内存地址不变
  2. 经过下标取数时,内存地址不一样
  3. for-loop 内建立的局部变量,即使名字相同,内存地址也不会复用

使用 WaitGroup 时,除了上面提到的注意事项,还须要解决数据回收和异常处理的问题。这里咱们也提供两种方式供参考:

  1. 对于 rpc 调用,能够经过 data channel 和 error channel 搜集信息,或者二合一的channel
  2. 共享变量,好比加锁的 map

4、WaitGroup 实现

在讨论这个主题以前,建议读者先思考一下:若是让你去实现 WaitGroup,你会怎么作?

锁?确定不行!

信号量?怎么实现?

------------切入正题------------

在 Go 源码里,WaitGroup 在逻辑上包含:

  1. worker 计数器:main协程调用 wg.Add(delta int) 时增长 delta,调用 wg.Done时减一。
  2. waiter 计数器:调用 wg.Wait 时,计数器加一; worker计数器下降到0时,重置waiter计数器
  3. 信号量:用于阻塞 main协程。调用 wg.Wait 时,经过 runtime_Semacquire 获取信号量;下降 waiter 计数器时,经过 runtime_Semrelease 释放信号量。

为了便于演示,咱们魔改一下上面的例子:

package main

import (
  "fmt"
  "sync"
  "time"
)

func main() {
  tasks  := []func(){
    func() { time.Sleep(time.Second); fmt.Println("1 sec later") },
    func() { time.Sleep(time.Second *  2); fmt.Println("2 sec later") },
}

  var wg sync.WaitGroup // 1-1
  wg.Add(len(tasks))    // 1-2
  for _, task := range tasks {
    task  := task
    go func() {       // 1-3-1
      defer wg.Done() // 1-3-2
      task()          // 1-3-3
    }()               // 1-3-1
  }
  wg.Wait()           // 1-4
  fmt.Println("exit")
}

上面这段代码中,

  1. 1-1 建立一个 WaitGroup 对象,worker计数器和waiter计数器默认值均为0。
  2. 1-2 设置 worker计数器为 len(tasks)
  3. 1-3-1 建立 worker协程,并启动任务。
  4. 1-4 设置 waiter计数器,获取信号量,main协程被阻塞。
  5. 1-3-3 中执行结束后,1-3-2 下降worker计数器。当worker计数器下降到0时,

    • 重置 waiter计数器
    • 释放信号量,main 协程被激活,1-4 wg.Wait 返回

尽管 Add(delta int) 里 delta 能够是正数、0、负数。咱们在使用时,delta 老是正数。

wg.Done 等价于 wg.Add(-1)。在本文中,咱们提到 wg.Add时,默认 delta > 0

了解了 WaitGroup 的原理之后,咱们看下它的源码。为了便于理解,我只保留了核心逻辑。对于这部分逻辑,咱们分三部分讲解:

  1. WaitGroup 结构
  2. AddDone
  3. Wait

提示:若是只想了解 WaitGroup 的正确用法,本文读到这儿就足够了。对底层有兴趣的朋友能够继续读,不过最好打开IDE,参考源码一块儿读。

4.1 WaitGroup 结构

type WaitGroup struct {
  noCopy noCopy
  state1 [3]uint32
}

WaitGroup 结构体里有 noCopystate1 两个字段。

编译代码时,go vet 工具会检查 noCopy 字段,避免 WaitGroup 对象被拷贝。

state1 字段比较秀,在逻辑上它包含了 worker计数器、waiter计数器和信号量。具体如何读这三个变量,参考下面代码:

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
  if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
    return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
  } else {
    return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
  }
}

// 读取计数器和信号量
statep, semap := wg.state()
state  := atomic.LoadUint64(statep)
v := int32(state >> 32)
w := uint32(state)

三个变量的取数逻辑是:

  • worker计数器:vstatep *uint64左32位
  • waiter计数器:wstatep *uint64右32位
  • 信号量:semapstate1 [3]uint32 的第一个字节/最后一个字节

因此,更新worker计数器,须要这样作:

state := atomic.AddUint64(statep, uint64(delta)<<32)

更新waiter计数器,须要这样作:

statep, semap := wg.state()
for {
  state := atomic.LoadUint64(statep)
  if atomic.CompareAndSwapUint64(statep, state, state+1)   {
    // 忽略其余逻辑
    return
  }
}

细心的朋友可能会发现,worker计数器的更新是直接累加,而 waiter计数器的更新是 CompareAndSwap。这是由于在 main协程中执行 wg.Add 时,只有main协程对 state1 作修改;而 wg.Wait 中修改waiter计数器时,可能有不少个协程在更新 state1。若是你还不太理解这段话,不妨先往下走,了解 wg.Addwg.Wait 的细节以后再回头看。

4.2 Add 和 Done

wg.Add 操做的核心逻辑比较简单,即修改 worker计数器,根据worker计数器的状态进行后续操做。简化版的代码以下:

func (wg *WaitGroup) Add(delta int) {
  statep, semap := wg.state()
  // 1. 修改worker计数器
  state := atomic.AddUint64(statep, uint64(delta)<<32)
  v := int32(state >> 32)
  w := uint32(state)
  if v <  0 {
    panic("sync: negative WaitGroup counter")
  }
  if w != 0 && delta > 0 && v == int32(delta) {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
  }
  // 2. 判断计数器
  if v > 0 || w == 0 {
    return
  }
  
  // 3. 当 worker计数器下降到0时
  // 重置 waiter计数器,并释放信号量
  *statep = 0
  for ; w != 0; w-- {
    runtime_Semrelease(semap, false)
  }
}

func (wg *WaitGroup) Done() {
  wg.Add(-1)
}

4.3 Wait

wg.Wait 的逻辑是修改waiter计数器,并等待信号量被释放。简化版的代码以下:

func (wg *WaitGroup) Wait() {
  statep, semap  := wg.state()
  for {
    // 1. 读取计数器
    state := atomic.LoadUint64(statep)
    v := int32(state >> 32)
    w := uint32(state)
    if v == 0 {
      return
    }

    // 2. 增长waiter计数器
    if atomic.CompareAndSwapUint64(statep, state, state+1) {
      // 3. 获取信号量
      runtime_Semacquire(semap)
      if *statep != 0 {
        panic("sync: WaitGroup is reused before previous Wait has returned")
      }
    
      // 4. 信号量获取成功
      return
    }
  }
}

因为源码比较长,包含了不少校验逻辑和注释,本文中在引用时,在保留核心逻辑的同时均作了不一样程度的删减。最后,推荐各位把源码下载下来,细细研读一番,从细节上对 WaitGroup 的设计有更深刻的理解。

相关文章
相关标签/搜索