CycleBarrier與CountDownLatch原理

CountDownLatch

衆所周知,它能解決一個任務必須在其他任務完成的情況下才能執行的問題,代碼層面來說就是隻有計數countDown到0的時候,await處的代碼才能繼續向下運行,例如:

import java.util.*;
import java.util.concurrent.*;

public class Main {
    public static void main(String[] args) throws Exception {

        CountDownLatch latch = new CountDownLatch(3);

        ThreadPoolExecutor executor = new ThreadPoolExecutor(10, 15, 60L, TimeUnit.SECONDS, new ArrayBlockingQueue<>(5));
        Future<Integer>[] futures = new Future[3];
        for (int i = 0; i < 3; i++){
            futures[i] = executor.submit(() -> {
                Random rand = new Random();
                int n = rand.nextInt(100);
                int result = 0;
                for (int j = 0; j < n; j++){
                    result += j;
                }
                System.out.println(result + "|" + Thread.currentThread().getName());
                latch.countDown();
                return result;
            });
        }
        latch.await();
        System.out.println("合計每個任務的結果:" + (futures[0].get()+futures[1].get()+futures[2].get()));
    }

}

運行結果:

源碼

實際上內部十分簡單,裏面只有一個AQS的子類

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    // 它把AQS的state(同步狀態)作爲計數器,在AQS裏,state是個volatile標記的int變量
    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    protected int tryAcquireShared(int acquires) {
        // 同步狀態爲0,則返回1,否則返回-1
        return (getState() == 0) ? 1 : -1;
    }

    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            // 如果狀態爲0則返回false
            if (c == 0)
                return false;
            // 計數器減1
            int nextc = c-1;
            // CAS操作,如果內存中的同步狀態值等於期望值c,那麼將同步狀態設置爲給定的更新值nextc
            if (compareAndSetState(c, nextc))
                return nextc == 0;  // 當計數器減到0,返回true
        }
    }
}

public void countDown() {
    sync.releaseShared(1);
}

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

下面看具體做了什麼事情

先來看await

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    // 當計數器不等於0,返回-1,證明還有任務未執行完,進入下面方法等待
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 把當前線程包裝成Node放入等待隊列
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            // 獲取當前線程的前驅節點,以檢查等待狀態
            final Node p = node.predecessor();
            if (p == head) {
                // 如果計數器等於0,返回1,證明此時阻塞可以解除了
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

上面的過程可以總結爲:當進入await方法後,如果此時計數器不爲0,則進入死循環一直檢查計數器的值,直到爲0退出,此時停止等待。

再來看countDown

public final boolean releaseShared(int arg) {
    // 嘗試計數器減1,只有減到0纔會返回true
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            // 等待狀態爲SIGNAL
            if (ws == Node.SIGNAL) {
                // 把當前節點的等待狀態從SIGNAL設置成0,如果設置失敗則繼續循環。
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                // 成功的話則卸載當前節點的所有後繼
                unparkSuccessor(h);
            }
            // 如果等待狀態爲0,則嘗試將狀態設置爲PROPAGATE,如果設置失敗則繼續循環。
            else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

countDown的過程可以總結爲:嘗試將計數器-1,直到爲0,爲0的時候通知等待線程。

CycleBarrier

欄柵的作用就是讓指定的一批任務能夠同時開始執行,比如

import java.util.*;
import java.util.concurrent.*;

public class Main {
    public static void main(String[] args) throws Exception {
        CyclicBarrier cyclicBarrier = new CyclicBarrier(3);


        ThreadPoolExecutor executor = new ThreadPoolExecutor(10, 15, 60L, TimeUnit.SECONDS, new ArrayBlockingQueue<>(5));
        Future<Integer>[] futures = new Future[3];
        for (int i = 0; i < 3; i++){
            futures[i] = executor.submit(() -> {
                System.out.println("await|" + Thread.currentThread().getName());
                cyclicBarrier.await();
                Random rand = new Random();
                int n = rand.nextInt(100);
                int result = 0;
                for (int j = 0; j < n; j++){
                    result += j;
                }
                System.out.println(result + "|" + Thread.currentThread().getName());
                return result;
            });
        }
    }

}

運行結果

源碼

進來之後首先發現的是成員變量

/** 用來保護柵欄入口的鎖 */
private final ReentrantLock lock = new ReentrantLock();
/** 等待條件,直到計數器爲0 */
private final Condition trip = lock.newCondition();
/** 參與線程的個數 */
private final int parties;
/* 計數器爲0時要運行的命令,由用戶定義 */
private final Runnable barrierCommand;
/** 當前等待的一代 */
private Generation generation = new Generation();
/**
 * parties數量的等待線程。每一代等待的數量從parties到0。當調用nextGeneration或者breakBarrier方法時重置。
 */
private int count;

從這裏可以看出,除了內部實現用的ReentrantLock,其工作過程無非:計數器不爲0的時候線程等待;當等待線程全部就緒,也就是計數器減爲0的時候重置計數器並通知所有線程繼續運行。

導致計數器重置原因有兩個:一個就是發生異常,將當前這一代標記爲無效(broken=true);另一個就是正常就緒,開啓下一代(new Generation)

核心方法dowait

// 情況一:timed=false,nanos=0L,代表一直阻塞
// 情況二:timed=true,nanos!=0L,代表在超時時間內阻塞
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        // 獲取當前這一代
        final Generation g = generation;

        // 如果當前這一代已經銷燬,拋異常
        if (g.broken)
            throw new BrokenBarrierException();
        // 測試當前線程是否被中斷
        if (Thread.interrupted()) {
            // 將broken設置爲true,代表這一代已經銷燬,重置count;然後通知所有等待線程
            breakBarrier();
            throw new InterruptedException();
        }
        // count 減1
        int index = --count;
        // 如果減1之後變成0,證明等待線程全部就緒。
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                // 如果用戶定義了額外的命令,則執行
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                // 開啓下一代(通知所有等待線程,重置count,new一個新的Generation)
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        // 如果減1之後不等於0,也就是還有其它線程沒有就緒,那麼進入此循環,直到就緒或者被銷燬,或者被中斷和超時
        for (;;) {
            try {
                if (!timed)
                    // 未定義超時,則一直阻塞
                    trip.await();
                else if (nanos > 0L)
                    // 等待指定的超時時間
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            // 超時,則銷燬這一代,通知所有等待線程並重置count
            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

總結

兩個工具實現思路都很簡單,唯一我思考的是,爲什麼CountDownLatch只能用一次?

CycleBarrier很明顯,它無論正常執行或者發生異常中斷都有重置count的邏輯。

而CountDownLatch則沒有重置的邏輯,那麼,到底是CountDownLatch不能重置還是僅僅因爲沒有重置的邏輯。爲此我把CountDownLatch的代碼照搬,然後加上了簡單的重置方法,如下:

import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;


public class MyCountDown {

    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        /**
         * 新加
         * @param count
         */
        void reset(int count){
            // 重新設置狀態
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    private final int count;

    public MyCountDown(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
        this.count = count;
    }


    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    public boolean await(long timeout, TimeUnit unit)
            throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }


    public void countDown() {
        sync.releaseShared(1);
    }


    public long getCount() {
        return sync.getCount();
    }

    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }

    /**
     * 新加
     */
    public void reset(){
        // 調用重置的方法
        this.sync.reset(count);
    }
}

測試:

import java.util.*;
import java.util.concurrent.*;

public class Main {
    public static void main(String[] args) throws Exception {

        MyCountDown myCountDown = new MyCountDown(3);
        ThreadPoolExecutor executor = new ThreadPoolExecutor(10, 15, 60L, TimeUnit.SECONDS, new ArrayBlockingQueue<>(5));
        Future<Integer>[] futures = new Future[3];
        for (int i = 0; i < 3; i++){
            futures[i] = executor.submit(() -> {
                Random rand = new Random();
                int n = rand.nextInt(100);
                int result = 0;
                for (int j = 0; j < n; j++){
                    result += j;
                }
                System.out.println(result + "|" + Thread.currentThread().getName());
                Thread.sleep(new Random().nextInt(2000));   // 模擬耗時
                myCountDown.countDown();
                return result;
            });
        }
        myCountDown.await();
        System.out.println("第一次:" + (futures[0].get() + futures[1].get() + futures[2].get()));
        myCountDown.reset();    // 重置

        for (int i = 0; i < 3; i++){
            futures[i] = executor.submit(() -> {
                Random rand = new Random();
                int n = rand.nextInt(100);
                int result = 0;
                for (int j = 0; j < n; j++){
                    result += j;
                }
                System.out.println(result + "|" + Thread.currentThread().getName());
                Thread.sleep(new Random().nextInt(2000));   // 模擬耗時
                myCountDown.countDown();
                return result;
            });
        }
        myCountDown.await();
        System.out.println("如果重置無效,則這個信息會先於任務信息輸出");
        System.out.println("第二次:" + (futures[0].get() + futures[1].get() + futures[2].get()));
    }

}

輸出

如果換成CountDownLatch

import java.util.*;
import java.util.concurrent.*;

public class Main {
    public static void main(String[] args) throws Exception {

        CountDownLatch latch = new CountDownLatch(3);
        ThreadPoolExecutor executor = new ThreadPoolExecutor(10, 15, 60L, TimeUnit.SECONDS, new ArrayBlockingQueue<>(5));
        Future<Integer>[] futures = new Future[3];
        for (int i = 0; i < 3; i++){
            futures[i] = executor.submit(() -> {
                Random rand = new Random();
                int n = rand.nextInt(100);
                int result = 0;
                for (int j = 0; j < n; j++){
                    result += j;
                }
                System.out.println(result + "|" + Thread.currentThread().getName());
                Thread.sleep(new Random().nextInt(2000));   // 模擬耗時
                latch.countDown();
                return result;
            });
        }
        latch.await();
        System.out.println("第一次:" + (futures[0].get() + futures[1].get() + futures[2].get()));

        for (int i = 0; i < 3; i++){
            futures[i] = executor.submit(() -> {
                Random rand = new Random();
                int n = rand.nextInt(100);
                int result = 0;
                for (int j = 0; j < n; j++){
                    result += j;
                }
                System.out.println(result + "|" + Thread.currentThread().getName());
                Thread.sleep(new Random().nextInt(2000));   // 模擬耗時
                latch.countDown();
                return result;
            });
        }
        latch.await();
        System.out.println("如果重置無效,則這個信息會先於任務信息輸出");
        System.out.println("第二次:" + (futures[0].get() + futures[1].get() + futures[2].get()));
    }

}

輸出

 

 所以可以得出結論,CountDownLatch不是沒有辦法重置,只不過沒有寫相關邏輯。當然這個問題如果我說錯了,望指正。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章