CAS與AQS源碼簡析

什麼是CAS?
CAS(Compare And Swap),顧名思義就是比較並交換。用於解決多線程使用鎖帶來的性能損耗的問題,是一種非阻塞算法,其交換原理如下圖:
 

 
 
CAS用法:
- 數據庫中的樂觀鎖:即表字段+version字段,然後每次更新時就比較當前version版本是否一致,一直才更新並且升級version=version+1。
 
 
- java中用到CAS的類如:java.util.concurrent.atomic.*
 

什麼是AQS?
AQS(AbstractQueuedSynchronizer),顧名思義就是抽象隊列同步器。由FIFO(先進先出)的阻塞隊列和相關同步器組成。這是在concurrent包(併發處理)下。
 

    AbstractQueuedSynchronizer爲鎖機制維護了一個隊列,需要獲取鎖的線程們排在隊列中,只有排在隊首的線程纔有資格獲取鎖。

 
 
首先看張圖,取自《Java併發編程的藝術》:
 

 
 
然後看如下,AbstractQueuedSynchronizer源碼及其分析如下:
/**
* 提供一個阻塞鎖和相關依賴FIFO等待隊列同步器的實現。
* 這個類支持排他共享模式。排他模式下當一個已獲取到了,其他線程嘗試獲取不可能成功。共享模式可以被多個線程獲取。通常子類實現僅支持其中一種,但是也有兩種的支持的如ReadWriteLock
* 這個類定義了一個實現了Condition的內部類ConditionObject,用於排他模式。    
* 使用一個基礎的同步器需要重新定義以下方法:
* <li> {@link #tryAcquire}
* <li> {@link #tryRelease}
* <li> {@link #tryAcquireShared}
* <li> {@link #tryReleaseShared}
* <li> {@link #isHeldExclusively}
* 以上的每個方法均默認拋出{UnsupportedOperationException}錯誤,所以以上的幾個方法沒有提供默認實現,需要子類重寫。
* 這個類提供了一個有效的、可伸縮的基礎給同步器如狀態、acquire獲取和的同步器釋放參數、內部FIFO等待隊。當這些不夠用時,可使用atomic、Queue、LockSupport。
*/
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
    private static final long serialVersionUID = 7373984972572414691L;
 
    protected AbstractQueuedSynchronizer() { }
 
    /**
     * 等待隊列的node class。Node作爲等待隊列的節點
     * 這個等待隊列是CLH的變體,CLH一般用於自旋鎖。使用其代替一般的同步器,但也用了相同的策略來控制。 
     */
    static final class Node {
        /** 共享模式 */
        static final Node SHARED = new Node();
        /** 排他模式 */
        static final Node EXCLUSIVE = null;
        /** 當前線程被取消 */
        static final int CANCELLED =  1;
        /** 當前節點的後繼節點包含的線程需要運行*/
        static final int SIGNAL    = -1;
        /** 當前結點在condition隊列中 */
        static final int CONDITION = -2;
        /** 當前場景下後續的acquireShared能夠得以執行  */
        static final int PROPAGATE = -3;
 
        /** 當前節點的狀態。*/
        volatile int waitStatus;
        /** 前驅結點 */
        volatile Node prev;
        /** 後繼節點 */
        volatile Node next;
        /** 入隊線程 */
        volatile Thread thread;
        /** 存儲condition隊列中的後繼節點 */
        Node nextWaiter;
 
        final boolean isShared() { return nextWaiter == SHARED;}
        /**
         * 返回前驅節點
         */
        final Node predecessor() throws NullPointerException {
            Node p = prev;
            if (p == null)
                throw new NullPointerException();
            else
                return p;
        }
        ..............................
    }
    /**
     * 僅用於初始化等待隊列的head。只能通過setHead修改,當這個head還存在時不能將waitStatus=>cancelled
     */
    private transient volatile Node head;
    /** Tail節點初始化,僅能通過enq追加新的wait node*/
    private transient volatile Node tail;
    /** synchronization state. */
    private volatile int state;
 
    /** CAS原子性的修改 synchronization state ,拉到代碼最下面可見其值的設置*/
    protected final boolean compareAndSetState(int expect, int update) {
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }
    static final long spinForTimeoutThreshold = 1000L;
 
    /**
     * 爲隊列追加node節點
     */
    private Node enq(final Node node) {
        for (;;) {//一直循環入隊,直到成功
            Node t = tail;
            if (t == null) { //同樣獲取尾節點,並且如果爲空就將尾節點初始化爲頭結點head一樣
                if (compareAndSetHead(new Node()))
                    tail = head;
            } else { //尾節點不爲空就執行addWaiter一樣的過程把新的node加到最後
                node.prev = t;
                if (compareAndSetTail(t, node)) {
                    t.next = node;
                    return t;
                }
            }
        }
    }
 
    /**
     * 新建Node併入隊
     */
    private Node addWaiter(Node mode) {
        //新建一個Node
        Node node = new Node(Thread.currentThread(), mode);
        // 存儲當前尾節點(當作舊的尾節點)
        Node pred = tail;
        if (pred != null) {  //如果當前尾節點不爲空
            node.prev = pred;  //將新建的節點的前驅節點執行舊的爲節點
            if (compareAndSetTail(pred, node)) {//CAS原子替換當前尾節點從舊的替換到新建node的位置
                pred.next = node;//將舊的尾節點位置的後置節點執行新建的節點
                return node;
            }
        }
        //如果上面入隊失敗則調用enq方法入隊
        enq(node);
        return node;
    }
    private void setHead(Node node) {
        head = node; //將頭結點指向node
        node.thread = null;  //線程置空
        node.prev = null;//因爲是頭節點了,不用需要前驅結點
    }
 
    /**
     * 喚醒後續節點
     */
    private void unparkSuccessor(Node node) {
        int ws = node.waitStatus;
        if (ws < 0) compareAndSetWaitStatus(node, ws, 0);
        Node s = node.next;
        //如果後置節點是尾節點或Cancelled狀態
        if (s == null || s.waitStatus > 0) {
            s = null;  //將當前後置節點置爲null
            for (Node t = tail; t != null && t != node; t = t.prev)
                if (t.waitStatus <= 0)
                    s = t;
        }
        if (s != null)
            LockSupport.unpark(s.thread);
    }
 
    /**
     * 共享模式下
     */
    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }
 
    private void setHeadAndPropagate(Node node, int propagate) {
        Node h = head; // Record old head for check below
        setHead(node);
        if (propagate > 0 || h == null || h.waitStatus < 0 ||
            (h = head) == null || h.waitStatus < 0) {
            Node s = node.next;
            if (s == null || s.isShared())
                doReleaseShared();
        }
    }
 
    /**
     * 取消正在嘗試獲取鎖的節點
     */
    private void cancelAcquire(Node node) {
        if (node == null) return;
        //cancel一個節點時會將當前節點thread置爲null
        node.thread = null;
        // 循環跳過已設置了cancelled狀態的節點
        Node pred = node.prev;
        while (pred.waitStatus > 0) node.prev = pred = pred.prev;
        //存儲上面得到的節點前驅節點
        Node predNext = pred.next;
        //將當前要cancel的節點狀態設置CANCELLED
        node.waitStatus = Node.CANCELLED;
 
        //1 如果當前節點node是尾節點。更新尾節點爲pred.next指向null,相當於刪除了node(和pred到node間爲cancel的節點)
        if (node == tail && compareAndSetTail(node, pred)) {
            compareAndSetNext(pred, predNext, null);
        } else {
            int ws;
            //2 當前既不是尾節點,也不是head後繼節點。設置node的前驅節點waitStatus爲SIGNAL,node前驅節點指向後繼節點,相當於刪除node
            if (pred != head &&
                ((ws = pred.waitStatus) == Node.SIGNAL ||
                 (ws <= 0 && compareAndSetWaitStatus(pred, ws, Node.SIGNAL))) &&
                pred.thread != null) {
                Node next = node.next;
                if (next != null && next.waitStatus <= 0)
                    compareAndSetNext(pred, predNext, next);
            } else {
                //3 如果node是head的後繼節點。則直接喚醒node的後繼節點。在head後面的節點有資格嘗試獲取鎖,但是當前node放齊了當前資格,所以會喚醒其後續的節點
                unparkSuccessor(node);
            }
 
            node.next = node; // help GC
        }
    }
 
    /**
     * 判斷當前節點是否掛起
     */
    private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        int ws = pred.waitStatus;
        if (ws == Node.SIGNAL)  //當前狀態下掛起
            return true;
        if (ws > 0) {
            do {//跳過已被設置了cancelled的前驅節點
                node.prev = pred = pred.prev;
            } while (pred.waitStatus > 0);
            pred.next = node;
        } else {
            /** 將上級的等待狀態設爲SIGNAL */
            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
        }
        return false;
    }
 
    static void selfInterrupt() {
        Thread.currentThread().interrupt();
    }
 
    private final boolean parkAndCheckInterrupt() {
        LockSupport.park(this);
        return Thread.interrupted();
    }
 
    /**
     * 嘗試獲取鎖
     */
    final boolean acquireQueued(final Node node, int arg) {
        boolean failed = true;
        try {
            boolean interrupted = false;
            for (;;) {  //進入循環後會不斷的嘗試獲取
                final Node p = node.predecessor();//獲取當前節點的頭結點!!只有head頭結點才持有鎖!!
                //如果當前的前驅節點是頭結點則嘗試獲取鎖。
                //如果嘗試成功則將當前node設爲頭結點,並將舊的head設爲null便於回收
                //獲取失敗看是否需要掛起,如果需要掛起則掛起線程等待下一次被喚醒時繼續嘗試獲取鎖。
                if (p == head && tryAcquire(arg)) {
                    setHead(node);
                    p.next = null; // 幫助GC
                    failed = false;
                    return interrupted;
                }
                //判斷是否掛起(根據Node的狀態=-3就會掛起),然後調用颳起的方法(裏面調了Thread.interrupted();)
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())  
                    interrupted = true;
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }
    。。。。。。。。。。。。。。。
    // Main exported methods
 
  /** 子類實現的嘗試獲取鎖的方法 */
    protected boolean tryAcquire(int arg) {
        throw new UnsupportedOperationException();
    }
 
    /** 子類實現嘗試釋放鎖的方法 */
    protected boolean tryRelease(int arg) {
        throw new UnsupportedOperationException();
    }
 
    /** 子類實現嘗試獲取共享鎖的方法  */
    protected int tryAcquireShared(int arg) {
        throw new UnsupportedOperationException();
    }
 
    /** 子類實現嘗試釋放共享鎖的方法  */
    protected boolean tryReleaseShared(int arg) {
        throw new UnsupportedOperationException();
    }
 
     /** 子類實現排他模式下狀態是否佔用  */
    protected boolean isHeldExclusively() {
        throw new UnsupportedOperationException();
    }
 
    /**
     * 排他模式,獲取互斥鎖
     */
    public final void acquire(int arg) {
        //嘗試獲取鎖(tryAcquire在此類中是拋異常的,應在子類實現),
        //如果嘗試獲取失敗就調用acquireQueued再次嘗試獲取鎖,addWaiter適用於新建一個新node
        if (!tryAcquire(arg) &&
            acquireQueued(addWaiter(Node.EXCLUSIVE), arg))
            selfInterrupt();
    }
 
    public final void acquireInterruptibly(int arg)
             throws InterruptedException {}
    public final boolean tryAcquireNanos(int arg, long nanosTimeout)
            throws InterruptedException {}
 
    public final boolean release(int arg) {
        if (tryRelease(arg)) {//嘗試釋放鎖成功
            Node h = head;//獲取當前被釋放了鎖的head頭節點
            //如果頭節點不爲空且當前節點狀態正常就喚醒當前節點的後續節點
            if (h != null && h.waitStatus != 0)
                unparkSuccessor(h);
            return true;
        }
        return false;
    }
 
    public final void acquireShared(int arg) {
        if (tryAcquireShared(arg) < 0)
            doAcquireShared(arg);
    }
 
    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {... }
 
    public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {........ }
 
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }
 
    // 以下是隊列檢查方法
    public final boolean hasQueuedThreads() {
        return head != tail;//隊列是否存在
    }
    public final boolean hasContended() {
        return head != null;  //頭結點是否爲空
    }
    public final Thread getFirstQueuedThread() {
        return (head == tail) ? null : fullGetFirstQueuedThread();
    }
     。。。。。。。。。。
    //工具和監控方法
    public final int getQueueLength() {
        int n = 0;//拿到隊列長度
        for (Node p = tail; p != null; p = p.prev) {
            if (p.thread != null)
                ++n;
        }
        return n;
    }
    //獲取當前node queue的所有線程
    public final Collection<Thread> getQueuedThreads() {
        ArrayList<Thread> list = new ArrayList<Thread>();
        for (Node p = tail; p != null; p = p.prev) {
            Thread t = p.thread;
            if (t != null)
                list.add(t);
        }
        return list;
    }
    ......的工具類..............................
 
    /**
     * condition queue, 單向隊列。線程拿到鎖,但條件不足時,會放到這個隊列等待被喚醒
     */
    public class ConditionObject implements Condition, java.io.Serializable {
        private static final long serialVersionUID = 1173984872572414699L;
        private transient Node firstWaiter; //頭結點
        private transient Node lastWaiter;//尾節點
 
        public ConditionObject() { }
 
        private Node addConditionWaiter() {
            Node t = lastWaiter;
            // If lastWaiter is cancelled, clean out.
            if (t != null && t.waitStatus != Node.CONDITION) {
                unlinkCancelledWaiters();
                t = lastWaiter;
            }
            Node node = new Node(Thread.currentThread(), Node.CONDITION);
            if (t == null)
                firstWaiter = node;
            else
                t.nextWaiter = node;
            lastWaiter = node;
            return node;
        }
 
        private void doSignal(Node first) {
            do {
                if ( (firstWaiter = first.nextWaiter) == null)
                    lastWaiter = null;
                first.nextWaiter = null;
            } while (!transferForSignal(first) &&
                     (first = firstWaiter) != null);
        }
 
        private void doSignalAll(Node first) {
            lastWaiter = firstWaiter = null;
            do {
                Node next = first.nextWaiter;
                first.nextWaiter = null;
                transferForSignal(first);
                first = next;
            } while (first != null);
        }
 
        private void unlinkCancelledWaiters() {
            Node t = firstWaiter;
            Node trail = null;
            while (t != null) {
                Node next = t.nextWaiter;
                if (t.waitStatus != Node.CONDITION) {
                    t.nextWaiter = null;
                    if (trail == null)
                        firstWaiter = next;
                    else
                        trail.nextWaiter = next;
                    if (next == null)
                        lastWaiter = trail;
                }
                else
                    trail = t;
                t = next;
            }
        }
 
        public final void signal() {
            if (!isHeldExclusively())
                throw new IllegalMonitorStateException();
            Node first = firstWaiter;
            if (first != null)
                doSignal(first);
        }
 
        public final void signalAll() {
            if (!isHeldExclusively())
                throw new IllegalMonitorStateException();
            Node first = firstWaiter;
            if (first != null)
                doSignalAll(first);
        }
 
        public final void awaitUninterruptibly() {
            Node node = addConditionWaiter();
            int savedState = fullyRelease(node);
            boolean interrupted = false;
            while (!isOnSyncQueue(node)) {
                LockSupport.park(this);
                if (Thread.interrupted())
                    interrupted = true;
            }
            if (acquireQueued(node, savedState) || interrupted)
                selfInterrupt();
        }
 
        private static final int REINTERRUPT =  1;
        private static final int THROW_IE    = -1;
 
        private int checkInterruptWhileWaiting(Node node) {
            return Thread.interrupted() ?
                (transferAfterCancelledWait(node) ? THROW_IE : REINTERRUPT) :
                0;
        }
 
        private void reportInterruptAfterWait(int interruptMode)
            throws InterruptedException {
            if (interruptMode == THROW_IE)
                throw new InterruptedException();
            else if (interruptMode == REINTERRUPT)
                selfInterrupt();
        }
 
        public final void await() throws InterruptedException {
            if (Thread.interrupted())
                throw new InterruptedException();
            Node node = addConditionWaiter();
            int savedState = fullyRelease(node);
            int interruptMode = 0;
            while (!isOnSyncQueue(node)) {
                LockSupport.park(this);
                if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
                    break;
            }
            if (acquireQueued(node, savedState) && interruptMode != THROW_IE)
                interruptMode = REINTERRUPT;
            if (node.nextWaiter != null) // clean up if cancelled
                unlinkCancelledWaiters();
            if (interruptMode != 0)
                reportInterruptAfterWait(interruptMode);
        }
 
        public final long awaitNanos(long nanosTimeout)
                throws InterruptedException {
            if (Thread.interrupted())
                throw new InterruptedException();
            Node node = addConditionWaiter();
            int savedState = fullyRelease(node);
            final long deadline = System.nanoTime() + nanosTimeout;
            int interruptMode = 0;
            while (!isOnSyncQueue(node)) {
                if (nanosTimeout <= 0L) {
                    transferAfterCancelledWait(node);
                    break;
                }
                if (nanosTimeout >= spinForTimeoutThreshold)
                    LockSupport.parkNanos(this, nanosTimeout);
                if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
                    break;
                nanosTimeout = deadline - System.nanoTime();
            }
            if (acquireQueued(node, savedState) && interruptMode != THROW_IE)
                interruptMode = REINTERRUPT;
            if (node.nextWaiter != null)
                unlinkCancelledWaiters();
            if (interruptMode != 0)
                reportInterruptAfterWait(interruptMode);
            return deadline - System.nanoTime();
        }
 
        public final boolean awaitUntil(Date deadline)
                throws InterruptedException {
            long abstime = deadline.getTime();
            if (Thread.interrupted())
                throw new InterruptedException();
            Node node = addConditionWaiter();
            int savedState = fullyRelease(node);
            boolean timedout = false;
            int interruptMode = 0;
            while (!isOnSyncQueue(node)) {
                if (System.currentTimeMillis() > abstime) {
                    timedout = transferAfterCancelledWait(node);
                    break;
                }
                LockSupport.parkUntil(this, abstime);
                if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
                    break;
            }
            if (acquireQueued(node, savedState) && interruptMode != THROW_IE)
                interruptMode = REINTERRUPT;
            if (node.nextWaiter != null)
                unlinkCancelledWaiters();
            if (interruptMode != 0)
                reportInterruptAfterWait(interruptMode);
            return !timedout;
        }
 
        public final boolean await(long time, TimeUnit unit)
                throws InterruptedException {
            long nanosTimeout = unit.toNanos(time);
            if (Thread.interrupted())
                throw new InterruptedException();
            Node node = addConditionWaiter();
            int savedState = fullyRelease(node);
            final long deadline = System.nanoTime() + nanosTimeout;
            boolean timedout = false;
            int interruptMode = 0;
            while (!isOnSyncQueue(node)) {
                if (nanosTimeout <= 0L) {
                    timedout = transferAfterCancelledWait(node);
                    break;
                }
                if (nanosTimeout >= spinForTimeoutThreshold)
                    LockSupport.parkNanos(this, nanosTimeout);
                if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
                    break;
                nanosTimeout = deadline - System.nanoTime();
            }
            if (acquireQueued(node, savedState) && interruptMode != THROW_IE)
                interruptMode = REINTERRUPT;
            if (node.nextWaiter != null)
                unlinkCancelledWaiters();
            if (interruptMode != 0)
                reportInterruptAfterWait(interruptMode);
            return !timedout;
        }
    。。。。。。。。。。之類的工具類。。。。。。。。
    }
 
    private static final Unsafe unsafe = Unsafe.getUnsafe();
    private static final long stateOffset;
    private static final long headOffset;
    private static final long tailOffset;
    private static final long waitStatusOffset;
    private static final long nextOffset;
 
    static {
        try {
            stateOffset = unsafe.objectFieldOffset
                (AbstractQueuedSynchronizer.class.getDeclaredField("state"));
            headOffset = unsafe.objectFieldOffset
                (AbstractQueuedSynchronizer.class.getDeclaredField("head"));
            tailOffset = unsafe.objectFieldOffset
                (AbstractQueuedSynchronizer.class.getDeclaredField("tail"));
            waitStatusOffset = unsafe.objectFieldOffset
                (Node.class.getDeclaredField("waitStatus"));
            nextOffset = unsafe.objectFieldOffset
                (Node.class.getDeclaredField("next"));
 
        } catch (Exception ex) { throw new Error(ex); }
    }
 
    private final boolean compareAndSetHead(Node update) {
        return unsafe.compareAndSwapObject(this, headOffset, null, update);
    }
    。。。。。一堆CAS方法。。。。。。。。
}
 
 
 
 
 
 
 
 
 
 
 
 
參考:
以下四篇:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章