WaitGroup 也是最常用的 Go 同步原语之一, 用来做任务编排. 它要解决的就是并发-等待的问题: 现在一个 goroutine A 在检查点(checkpoint) 等待一组 goroutine 全部完成它们的任务, 如果这些 goroutine 还没全部完成任务, 那么 goroutine A 就会被阻塞在检查点, 直到所有的 goroutine 都完成任务后才能继续执行.

比如, 要完成一个大任务, 需要使用并行的 goroutine 执行三个小任务, 只有这三个任务都完成了, 才能执行后面的任务. 如果通过轮询的方式定时询问三个小任务是否完成, 则存在两个问题, 一是性能比较低, 因为三个小任务可能早就完成, 却要等很长时间才被轮询到; 二是会有很多无谓的轮询, 空耗 CPU 资源, 因此需要 WaitGroup 同步原语, 它可以阻塞等待的 goroutine, 等到三个小任务都完成了, 再及时唤醒它们.

WaitGroup 使用方法

WaitGroup 的功能就是等待一组 goroutine 都完成任务. 一般主 goroutine 会设置要等待的 goroutine 数量 n, 也就是将计数器的值设置为 n, 这些 goroutine 运行完毕后调用 Done 方法, 告诉 WaitGroup 已经完成任务了. 主 goroutine 调用 Wait 方法偶尔会被阻塞, 直到这 n 个 goroutine 全部完成任务.

import (
	"log"
	"net/http"
	"sync"
	"time"
)

func RunDemo() {
	var wg sync.WaitGroup

	var urls = []string{
		"<https://www.baidu.com>",
		"<https://www.google.com>",
		"<https://www.bing.com>",
	}
	http.DefaultClient.Timeout = time.Second
	wg.Add(3)
	for i := 0; i < 3; i++ {
		go func(url string) {
			defer wg.Done()

			log.Println("fetching", url)
			resp, err := http.Get(url)
			if err != nil {
				return
			}
			resp.Body.Close()
		}(urls[i])
	}

	wg.Wait()
	log.Println("done")
}

WaitGroup 在使用中有如下一些特点:

如果想获取等待的那些 goroutine 执行的结果, 则需要使用额外的变量, 而 WaitGroup 本身不会保存额外的信息的.

func RunDemo2() {
	var wg sync.WaitGroup

	var urls = []string{
		"<https://www.baidu.com>",
		"<https://www.google.com>",
		"<https://www.bing.com>",
	}
	var result = make([]bool, len(urls))
	http.DefaultClient.Timeout = time.Second

	wg.Add(3)
	for i := 0; i < 3; i++ {
		i := i
		go func(url string) {
			defer wg.Done()
			log.Println("fetching", url)
			resp, err := http.Get(url)
			if err != nil {
				result[i] = false
				return
			}
			result[i] = resp.StatusCode == http.StatusOK
			resp.Body.Close()
		}(urls[i])
	}

	wg.Wait()
	log.Println("done")
	for i := 0; i < len(urls); i++ {
		log.Println(urls[i], ":", result[i])
	}
}

另外, WaitGroup 本身没有控制这些执行任务的 goroutine 中止能力, 只能等这些 goroutine 执行完毕, 把计数器的值降为0

WaitGroup 的实现

首先看 WaitGroup 的结构体定义(Go 1.20)为例:

type WaitGroup struct {
	noCopy noCopy

	state  atomic.Unit64 // 高 32 位为计数器的值, 低 32 位为 waiter 的数量
	sema   uint32        // 信号量
}

第一个字段是 noCopy 是辅助字段, 主要辅助 vet 工具检查是否通过 copy 复制了这个 WaitGroup 实例.