【多線程】CountDownLatch實現原理

前言

CountDownLatch是多線程中一個比較重要的概念,它可以使得一個或多個線程等待其他線程執行完畢之後再執行。它內部有一個計數器和一個阻塞隊列,每當一個線程調用countDown()方法後,計數器的值減少1。當計數器的值不爲0時,調用await()方法的線程將會被加入到阻塞隊列,一直阻塞到計數器的值爲0。


常用方法

public class CountDownLatch {

    //構造一個值爲count的計數器
    public CountDownLatch(int count);

    //阻塞當前線程直到計數器爲0
    public void await() throws InterruptedException;

    //在單位爲unit的timeout時間之內阻塞當前線程
    public boolean await(long timeout, TimeUnit unit);

    //將計數器的值減1,當計數器的值爲0時,阻塞隊列內的線程纔可以運行
    public void countDown();      

}

下面給一個簡單的示例:

package com.yang.testCountDownLatch;

import java.util.concurrent.CountDownLatch;

public class Main {
    private static final int NUM = 3;

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(NUM);
        for (int i = 0; i < NUM; i++) {
            new Thread(() -> {
                try {
                    Thread.sleep(2000);
                    System.out.println(Thread.currentThread().getName() + "運行完畢");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                } finally {
                    latch.countDown();
                }
            }).start();
        }
        latch.await();
        System.out.println("主線程運行完畢");
    }
}

輸出如下:

看得出來,主線程會等到3個子線程執行完畢纔會執行。


原理解析

類圖
 

可以看得出來,CountDownLatch裏面有一個繼承AQS的內部類Sync,其實是AQS來支持CountDownLatch的各項操作的。

CountDownLatch(int count)

new CountDownLatch(int count)用來創建一個AQS同步隊列,並將計數器的值賦給了AQS的state。

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    private static final class Sync extends AbstractQueuedSynchronizer {     
        Sync(int count) {
            setState(count);
        }

    }

countDown()

countDown()方法會對計數器進行減1的操作,當計數器值爲0時,將會喚醒在阻塞隊列中等待的所有線程。其內部調用了Sync的releaseShared(1)方法

    public void countDown() {
        sync.releaseShared(1);
    }

    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            //此時計數器的值爲0,喚醒所有被阻塞的線程
            doReleaseShared();
            return true;
        }
        return false;
    }

tryReleaseShared(arg)內部使用了自旋+CAS操將計數器的值減1,當減爲0時,方法返回true,將會調用doReleaseShared()方法。對CAS機制不瞭解的同學,可以先參考我的另外一篇文章淺探CAS實現原理

        protected boolean tryReleaseShared(int releases) {
            //自旋
            for (;;) {
                int c = getState();
                if (c == 0)
                    //此時計數器的值已經爲0了,其他線程早就執行完畢了,當前線程也已經再執行了,不需要再次喚醒了
                    return false;
                int nextc = c-1;
                //使用CAS機制,將state的值變爲state-1
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

doReleaseShared()是AQS中的方法,該方法會喚醒隊列中所有被阻塞的線程。

    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

這段方法比較難理解,會另外篇幅介紹。這裏只要認爲該段方法會喚醒所有因調用await()方法而阻塞的線程。

await()

當計數器的值不爲0時,該方法會將當前線程加入到阻塞隊列中,並把當前線程掛起。

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

同樣是委託內部類Sync,調用其acquireSharedInterruptibly()方法

    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }

接着看Sync內的tryAcquireShared()方法,如果當前計數器的值爲0,則返回1,最終將導致await()不會將線程阻塞。如果當前計數器的值不爲0,則返回-1。

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

tryAcquireShared方法返回一個負值時,將會調用AQS中的doAcquireSharedInterruptibly()方法,將調用await()方法的線程加入到阻塞隊列中,並將此線程掛起。

    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        //將當前線程構造成一個共享模式的節點,並加入到阻塞隊列中
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {        
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

同樣,以上的代碼位於AQS中,在沒有了解AQS結構的情況下去理解上述代碼,有些困難,關於AQS源碼,會另開篇幅介紹。


使用場景

CountDownLatch的使用場景很廣泛,一般用於分頭最某些事,再彙總的情景。例如:

數據報表:當前的微服務架構十分流行,大多數項目都會被拆成若干的子服務,那麼報表服務在進行統計時,需要向各個服務抽取數據。此時可以創建與服務數相同的線程數,交由線程池處理,每個線程去對應服務中抽取數據,注意需要在finally語句塊中進行countDown()操作。主線程調用await()阻塞,直到所有數據抽取成功,最後主線程再進行對數據的過濾組裝等,形成直觀的報表。

風險評估:客戶端的一個同步請求查詢用戶的風險等級,服務端收到請求後會請求多個子系統獲取數據,然後使用風險評估規則模型進行風險評估。如果使用單線程去完成這些操作,這個同步請求超時的可能性會很大,因爲服務端請求多個子系統是依次排隊的,請求子系統獲取數據的時間是線性累加的。此時可以使用CountDownLatch,讓多個線程併發請求多個子系統,當獲取到多個子系統數據之後,再進行風險評估,這樣請求子系統獲取數據的時間就等於最耗時的那個請求的時間,可以大大減少處理時間。

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章