基於AQS分析CountDownLatch的原理

示例代碼

先貼一段示例代碼

public class TestCountDownLatch {

    public static CountDownLatch countDownLatch = new CountDownLatch(3);

    public static void main(String[] args) throws InterruptedException {
        for (int i=0;i<3;i++){
            new Thread(new Task()).start();
        }
        System.out.println("線程啓動結束,主線程進入等待狀態");
        countDownLatch.await();
        System.out.println("主線程結束");
    }
}

class Task implements Runnable{
    @Override
    public void run() {
        try {
            Thread.sleep(1000);
            System.out.println(Thread.currentThread().getName()+" task ");
            TestCountDownLatch.countDownLatch.countDown();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

主線程會在countDownLatch.await()這裏阻塞,三個Task線程執行完各自的TestCountDownLatch.countDownLatch.countDown()後,主線程繼續向下執行。
這個是主線程阻塞等待子線程的例子,當前這個阻塞等待並不拘泥於主線程,可以讓任意一個線程進行await,當CountDownLatch計數器歸零時線程纔會繼續

原理

new CountDownLatch(3)
public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }


private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

可以看到CountDownLatch是通過一個繼承AQS的內部鎖實現的,構造器設定了鎖的狀態值。

await
 public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }

protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }
  1. 主線程調用await時會嘗試去獲取share鎖,如果此時state!=0,就會進入doAcquireSharedInterruptibly中,這個方法如果看過AQS的實現不會陌生,就是AQS中構造同步隊列節點的方法。
  2. AQS將當前線程(主線程)構造成一個節點(如果是第一個節點則會構建一個空的頭節點),然後主線程會自旋再去嘗試獲得一次share鎖
  3. 還是沒有獲取到,這時候將節點阻塞

countDown

public void countDown() {
        sync.releaseShared(1);
    }

public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }

 protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }

private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

private void unparkSuccessor(Node node) {
        int ws = node.waitStatus;
        if (ws < 0)
            compareAndSetWaitStatus(node, ws, 0);
        Node s = node.next;
        if (s == null || s.waitStatus > 0) {
            s = null;
            for (Node t = tail; t != null && t != node; t = t.prev)
                if (t.waitStatus <= 0)
                    s = t;
        }
        if (s != null)
            LockSupport.unpark(s.thread);
    }
  1. 工作線程每次調用countDown,都會調用releaseShared(1);,即釋放了一把共享鎖。該方法會CAS更改state值,只有當state減到0時纔會返回true,即示例代碼中的前兩個線程調用該方法只會讓state自減1,當第三個線程將state減爲0了纔會發生下面的動作
  2. 調用doReleaseShared,將頭節點的WaitStatus更改後進入unparkSuccessor(h),這個函數會刪掉所有取消獲取鎖的線程節點,同時調用LockSupport.unpark方法將頭節點的下個節點(主線程)喚醒
  3. 主線程喚醒後又回到了doAcquireSharedInterruptibly中的for(;;)自旋中,這時候由於state=0因此是可以獲取到共享鎖的,然後會進入到if條件中將自己設置爲頭節點,並嘗試喚醒後面的節點
  4. 這時候主線程會返回,也就是從countDownLatch.await();返回可以繼續往下執行了

總結

總結下來就是,創建CountDownLatch時會讓其自己持有數量n的共享鎖,每次countDown就是在釋放這個共享鎖,await的線程要等到這個共享鎖完全被釋放了纔會返回

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章