手動實現簡單公平鎖,理解AQS

鎖源碼

package com.test.zwj;

import sun.misc.Unsafe;

import java.lang.reflect.Field;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.LockSupport;

public class MyFairLock {

    //鎖狀態
    private int state;

    //持鎖線程
    private Thread LockHolder;

    //等待隊列(線程安全)
    private ConcurrentLinkedQueue<Thread> waiters = new ConcurrentLinkedQueue<>();

    //Unsafe
    private static final Unsafe unsafe = getUnsafe();

    //鎖狀態偏移量
    private static long stateOffset;

    static {
        try {
            stateOffset = unsafe.objectFieldOffset(MyFairLock.class.getDeclaredField("state"));
        } catch (NoSuchFieldException e) {
           throw new Error(e);
        }
    }

    /**
     * 加鎖
     * 加鎖成功後方法返回
     */
    public void lock() {
        //先加鎖,如成功則返回,
        if (acquire()){
            System.out.println(Thread.currentThread().getName() + "加鎖成功");
            return;
        };
        //加鎖失敗,則進入等待隊列,等待被喚醒,
        waiters.offer(Thread.currentThread());
        //被喚醒後重新嘗試加鎖,因此這裏使用循環,直到加鎖成功,方法才能返回
        for (;;){
            if (acquire()){
                //加鎖成功,出隊列
                System.out.println(Thread.currentThread().getName() + " 加鎖成功");
                waiters.remove(Thread.currentThread());
                return;
            }
            //睡眠一會,測試先unpark後park的情況
            /*try {
                TimeUnit.SECONDS.sleep(10);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }*/
            System.out.println(Thread.currentThread().getName() + " park");
            //阻塞當前線程,等待被喚醒
            LockSupport.park();
            //喚醒後,繼續循環
        }
    }

    /**
     * 嘗試加鎖
     * @return 是否加鎖成功
     */
    private boolean acquire(){
        if (!shouldPark() && compareAndSwapInt(0,1)){
            LockHolder = Thread.currentThread();
            return true;
        }
        //支持可重入
        if (Thread.currentThread() == LockHolder){
            state++;
            return true;
        }
        return false;
    }

    /**
     * 是否應該阻塞
     *
     * 當已存在等待線程且自己不會首位線程,則需要等待
     */
    private boolean shouldPark(){
        return waiters.size() != 0 && waiters.peek() != Thread.currentThread();
    }

    /**
     * 解鎖
     */
    public void unlock() {
        //檢驗是否是持鎖線程
        if (Thread.currentThread() != LockHolder){
            throw new RuntimeException("LockHolder is not current thread, currentThead is "
                    + Thread.currentThread().getName() + ", LockHolder thread is " + LockHolder);
        }
        //只有持鎖線程能繼續執行
        //清空持鎖線程
        LockHolder = null;
        //恢復鎖狀態,必須先清空持鎖線程再回復鎖狀態,否則可能發生:新線程設置持鎖線程後,在這兒這裏清空了
        state--;
        System.out.println(Thread.currentThread().getName() + " 釋放鎖");
        if (state == 0 ){
            //喚醒等待隊列中第一個線程
            Thread first = waiters.peek();
            if (first != null){
                //可能有人覺得這裏有bug:如果線程a進入隊列後即將park但是還沒有park,所以喚醒將無效
                //其實先unpark後park也沒有關係,此時park()將直接返回
                LockSupport.unpark(first);
                System.out.println(Thread.currentThread().getName() + " 喚醒 " + first.getName());
            }
        }
    }

    /**
     * CAS更新鎖狀態
     * @param expect 期望的當前值
     * @param update 更新後的值
     * @return 是否更新成功
     */
    private boolean compareAndSwapInt(int expect, int update){
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }

    private static Unsafe getUnsafe(){
        try {
            Field field = Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            return (Unsafe) field.get(null);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }
}


測試用例

package com.test.zwj;

import java.util.concurrent.CountDownLatch;

public class Main {
    private static int counter;
    private static MyFairLock lock = new MyFairLock();
    private  static int threadNum = 10000;
    private static CountDownLatch latch = new CountDownLatch(threadNum);
    public static void main(String[] args) throws Exception {
        for (int i = 0; i < threadNum; i++){
            Thread thread = new Thread(new Task());
            thread.setName("Thread-" + i);
            thread.start();
        }
        latch.await();
        System.out.println("counter is "+counter);
    }

    private static class Task implements Runnable{

        @Override
        public void run() {
           fun1();
           fun2();
        }
    }

    private static void fun1(){
        lock.lock();
        try{
            counter++;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
    }

    private static void fun2(){
        lock.lock();
        try{
            counter++;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            latch.countDown();
            lock.unlock();
        }
    }
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章