深入理解Go語言(08):sync.WaitGroup源碼分析

一、sync.WaitGroup簡介

1.1 sync.WaitGroup 解決了什麼問題

在編程的時候,有時遇到一個大的任務,爲了提高計算速度,會用到併發程序,把一個大的任務拆分成幾個小的獨立的任務各自執行,因爲這幾個小任務相互沒有關係,可以獨立執行,這時候就可以用 Go 協程來處理這種併發任務。

但是這裏會有一個問題,協程的調度器調度主 goroutine 和子 goroutine 時,機會是均等的,萬一主 goroutine 運行完了,子 goroutine 還沒運行完,程序就結束了。子任務沒執行完程序結束,這種程序就有 bug,怎麼解決這種問題?能不能讓所有的子 goroutine 執行完,在讓主程序結束從而讓程序順利執行完。這時候,sync.WaitGroup 就出場了。

sync.WaitGroup() 可以等待一組 goroutine 執行完再讓剩下的程序執行完。

waitgroup 就是解決 go 中併發時,多個 goroutine 同步的問題。

1.2 用法

一般用法:

  1. 主 goroutine 通過調用 Add(i) 來設置需要等待的子 goroutine 數量,i 表示子 goroutine 數量。
  2. 子 goroutine 通過調用 Done() 來表示子 goroutine 執行完畢,goroutine 數量就減一,Add(-1)。
  3. 主 goroutine 通過調用 Wait() 來等待所有的子 goroutine 執行完畢。

一個小demo:

package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup

	wg.Add(2)

	go func() {
		defer wg.Done()
		fmt.Println("子 goroutine1")
	}()

	go func() {
		defer wg.Done()
		fmt.Println("子 goroutine2")
	}()

	wg.Wait() // 等待所有的子goroutine結束

	fmt.Println("程序運行結束")
}

程序運行輸出:

子 goroutine2
子 goroutine1
程序運行結束

二、waitgroup源碼分析

go1.17.10

2.1 數據結構WaitGroup

//https://github.com/golang/go/blob/go1.17.10/src/sync/waitgroup.go
// A WaitGroup waits for a collection of goroutines to finish.
// The main goroutine calls Add to set the number of
// goroutines to wait for. Then each of the goroutines
// runs and calls Done when finished. At the same time,
// Wait can be used to block until all goroutines have finished.
//
// A WaitGroup must not be copied after first use.
type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}
  • 第一個字段:noCopy

Go 源碼中檢測禁止複製的技術。這種寫法(noCopy)告訴 go vet 檢測工具,如果有複製行爲,那麼就違反了複製使用的規則。

進一步解釋下,如果在程序中,有對 WaitGroup 賦值的行爲,那麼 go vet 會檢測,發現它並報錯違反了複製使用的規則,但是 noCopy 並不會影響程序正常編譯和運行。這是 Go 語言中的一個小 trick。

  • 第二個字段:state1 數組

這個字段由 3 個數組組成,每個數組大小佔 32 bits。一個字段就包含了 WaitGroup 中使用到的 3 種數據 - counter、waiter、semaphore(信號量)。而且還兼容 64 位系統和 32 位系統中的內存對齊,內存對齊的好處是加快 CPU 對內存的訪問。內存對齊在 64 位系統中,變量一般佔據 64 位(8 byte),對齊就是指變量的起始地址是 8 的倍數。在 32 位系統中,變量一般佔據 32 位(4 byte),對齊就是指變量的起始地址是 4 的倍數。

它怎麼包含 3 種數據,是怎麼做到的?程序註釋有解釋:

  • 64 位系統中:state1[0] 是 counter 計數器,state1[1] 是 waiter 計數器,state1[2] 就是 semaphore。
  • 32 位系統中:state1[0] 是 semaphore,state1[1] 是 counter 計數器,state1[2] 就是 waiter 計數器。

image-20220926170807715

counter、waiter、semaphore 說明:

  • counter:計算協程 goroutine 個數的計數器,表示當前要執行的 goroutine 個數。waitgroup 中的函數 Add(i),counter = counter + i;函數 Done(),counter - 1。

  • waiter:等待協程 goroutine 的計數器,表示當前已經調用 Wait() 函數的 goroutine-group 個數,也就是需要結束的goroutine組數。waitgroup 中的 Wait(),waiter + 1,並掛起當前 goroutine。

  • semaphore:go runtime 內部信號量實現。waitgroup 中會用到 semaphore 的兩個相關函數,runtime_Semacquire 和 runtime_Semrelease。

  • runtime_Semacquire 表示增加一個信號量,並掛起當前 goroutine。

  • runtime_Semrelease 表示減少一個信號量,並喚醒 semaphore 上其中一個正在等待的 goroutine。

A. 字段 state1 設計的技巧:

在 32 位系統中,內存對齊時,可以把數組第 1 位 state1[0] 作爲對齊的 padding,因爲 state1 本身是 uint32 的數組,所以數組第一位也有 32 位。這樣就保證了把數組後兩位看做統一的 64 位整數時是64位對齊的。

只改變 semaphore 的位置順序,就既可以保證 counter+waiter 一定會 64 位對齊,也可以保證內存的高效利用。

B. 信號量是什麼?

前面提到了信號量 semaphore,下面簡單瞭解下:

信號量是unix/linux系統提供的一種保護共享資源的機制,用於防止多個線程同時訪問某個資源。它本質上是一個計數器。

信號量包含一個非負整型的變量,有兩個原子操作 wait(down) 和 signal(up)。wait 又可以稱爲 P 或 down 操作,減 1 操作;signal 也被稱爲 V 或 up 操作,加 1 操作。信號量通過原子操作實現的 加 1減 1 運算來實現對併發資源的控制。

wait(down) 操作,如果信號量的非負整型變量 S > 0,wait 將其減 1;如果 S = 0,wait 將該線程阻塞。

signal(up) 操作,如果有線程在信號量上阻塞(此時 S = 0),signal 會解除對某個等待線程的阻塞,恢復運行;如果沒有線程阻塞在信號量上,signal 將 S 加 1。

S 可以理解爲資源的數量,信號量即是通過控制資源數量加減來實現併發的互斥和同步。

內核信號量 struct semaphore ,包含 3 個字段:

count - 存放 atomic_t 類型的值,表示資源的數量。

wait - 存放等待隊列鏈表的地址,當前等待資源的所有睡眠進程都放在這個鏈表中。如果 count 大於0或等於0,等待隊列爲空。

sleepers - 存放一個標識,表示是否有一些進程在信號量上睡眠。

更多關於信號量知識,請去查看 linux 內核相關內容。

2.2 state()-從state1中取變量:

WaitGroup struct 中的字段 state1 裏面包含 3 種數據變量,怎麼取出來呢?看下面的函數 state():

// https://github.com/golang/go/blob/go1.17.10/src/sync/waitgroup.go#L31
// state returns pointers to the state and sema fields stored within wg.state1.
// 得到state的地址和信號量的地址
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        // 如果地址是64bit對齊,數組前兩個元素組成state,後一個元素是信號量
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
     // 如果地址是32bit對齊,數組後兩個元素組成state,第一個元素32bit就是信號量
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}

第一步,判斷是 32 位 還是 64 位系統:

Golang 中判斷當前變量是 32 位對齊還是 64 位對齊:https://go.dev/ref/spec#System_considerations,

uintptr(unsafe.Pointer(&x)) % unsafe.Alignof(x) == 0

image-20220926171002820

第二步,取出相應的數據

64位對齊:數組 state1 的 state[0] 是 counter 計數,state[2] 是信號量

(*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]

32位對齊:數組 state1 的 state[0] 是信號量,state[1] 是 counter 計數

(*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]

2.3 Add()-計數器

Add(delta int) 函數用來增加計數器的值,要運行的協程數量,它把 delta 值累加到 counter 中。二就是釋放信號量。

在程序中,delta 可以爲負值,也就是 counter 值可能變成 0 或者 負值,當 counter = 0,waiter 就釋放相等數量的信號量,把等待的 goroutine 全部喚醒。如果 counter < 0 負值了, 就 panic 報錯。

下面看看代碼,去掉 go runtime 裏相關競態代碼,

// https://github.com/golang/go/blob/go1.17.10/src/sync/waitgroup.go#L53
func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state() // 獲取state(counter+waiter)和semaphore信號量的指針
    
	... ...
    // uint64(delta)<<32 把 delta 左移32位,因爲counter在statep的高32位
    // 然後把delta原子的增加到counter中
	state := atomic.AddUint64(statep, uint64(delta)<<32)
    // v => counter, w => waiter
	v := int32(state >> 32)//獲取counter值
	w := uint32(state)     //獲取waiter值
    
	... ...
    //counter變爲負值了,panic報錯
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
    //waiter不等於0,說明已經執行了waiter,這時你又調用Add(),是不允許的
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
    //v->counter,counter>0,說明還有goroutine沒執行完,不需要釋放信號量,直接返回
    //w->waiter, waiter=0,沒有等待的goroutine,不需要釋放信號量,直接返回
	if v > 0 || w == 0 {
		return
	}
    
    // This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
    // Add()和Wait()不能並行操作
    // counter==0,也不能執行Wait()操作
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	
	*statep = 0 // 結束了將counter清零,下面在釋放waiter數的信號量
	for ; w != 0; w-- {// 循環釋放waiter個數的信號量
		runtime_Semrelease(semap, false, 0)// 一次釋放一個信號量,喚醒一個等待者
	}
}

2.4 Done()

// https://github.com/golang/go/blob/go1.17.10/src/sync/waitgroup.go#L97
// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Done() 函數直接調用Add(),然後傳入 -1 參數,將計數器減 1。

2.5 Wait()

Wait() 函數代碼,累加waiter數,增加信號量然後等待喚醒,

// https://github.com/golang/go/blob/go1.17.10/src/sync/waitgroup.go#L103
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state() //獲取state(counter+waiter)和semaphore信號量的指針
	... ...
	for {// 死循環
		state := atomic.LoadUint64(statep) //原子的獲取state值
		v := int32(state >> 32) // 獲取counter值
		w := uint32(state)      //獲取waiter值
        if v == 0 {// counter=0,不需要wait直接返回
			// Counter is 0, no need to wait.
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		... ...
		// Increment waiters count.
		if atomic.CompareAndSwapUint64(statep, state, state+1) {// 使用CAS累加wiater
			... ...
			runtime_Semacquire(semap) //增加信號量,等待信號量喚醒
            // 這時 *statep 還不等於 0,那麼使用過程肯定有誤,直接 panic
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			... ...
			return
		}
	}
}

2.6 小結

源碼分析完了,代碼量雖然很少,但是在 WaitGroup 包中的代碼做了很多異常情況判斷,對它的使用做了限制和規範。看看使用WaitGroup 時注意事項:

  • Add() 操作需要早於 Wait() 操作

  • 調用 Done() 次數要與 Add() 計數器值相等

  • 計數器 (counter) 的值小於 0,會 panic

  • Add() 和 Wait() 不能並行調用,比如在 2 個不同 goroutine 裏調用,會 panic

  • 要重複調用 WaitGroup,必須等 Wait() 執行完才能進行下一輪調用

三、參考

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章