鎖源碼
package com.test.zwj;
import sun.misc.Unsafe;
import java.lang.reflect.Field;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.LockSupport;
public class MyFairLock {
private int state;
private Thread LockHolder;
private ConcurrentLinkedQueue<Thread> waiters = new ConcurrentLinkedQueue<>();
private static final Unsafe unsafe = getUnsafe();
private static long stateOffset;
static {
try {
stateOffset = unsafe.objectFieldOffset(MyFairLock.class.getDeclaredField("state"));
} catch (NoSuchFieldException e) {
throw new Error(e);
}
}
public void lock() {
if (acquire()){
System.out.println(Thread.currentThread().getName() + "加鎖成功");
return;
};
waiters.offer(Thread.currentThread());
for (;;){
if (acquire()){
System.out.println(Thread.currentThread().getName() + " 加鎖成功");
waiters.remove(Thread.currentThread());
return;
}
System.out.println(Thread.currentThread().getName() + " park");
LockSupport.park();
}
}
private boolean acquire(){
if (!shouldPark() && compareAndSwapInt(0,1)){
LockHolder = Thread.currentThread();
return true;
}
if (Thread.currentThread() == LockHolder){
state++;
return true;
}
return false;
}
private boolean shouldPark(){
return waiters.size() != 0 && waiters.peek() != Thread.currentThread();
}
public void unlock() {
if (Thread.currentThread() != LockHolder){
throw new RuntimeException("LockHolder is not current thread, currentThead is "
+ Thread.currentThread().getName() + ", LockHolder thread is " + LockHolder);
}
LockHolder = null;
state--;
System.out.println(Thread.currentThread().getName() + " 釋放鎖");
if (state == 0 ){
Thread first = waiters.peek();
if (first != null){
LockSupport.unpark(first);
System.out.println(Thread.currentThread().getName() + " 喚醒 " + first.getName());
}
}
}
private boolean compareAndSwapInt(int expect, int update){
return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
}
private static Unsafe getUnsafe(){
try {
Field field = Unsafe.class.getDeclaredField("theUnsafe");
field.setAccessible(true);
return (Unsafe) field.get(null);
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
}
測試用例
package com.test.zwj;
import java.util.concurrent.CountDownLatch;
/**
 * Stress test for {@link MyFairLock}.
 *
 * <p>Starts {@value #THREAD_NUM} threads. Each one increments a shared counter twice, taking
 * and releasing the lock separately around each increment. Once every thread has finished, the
 * final counter is printed; when every increment succeeds it equals {@code 2 * THREAD_NUM}.
 */
public class Main {
    /** Shared counter; only modified while {@link #LOCK} is held. */
    private static int counter;
    private static final MyFairLock LOCK = new MyFairLock();
    private static final int THREAD_NUM = 10000;
    /** Counted down once per finished task so {@code main} knows when all work is done. */
    private static final CountDownLatch LATCH = new CountDownLatch(THREAD_NUM);

    public static void main(String[] args) throws Exception {
        for (int i = 0; i < THREAD_NUM; i++) {
            Thread thread = new Thread(new Task());
            thread.setName("Thread-" + i);
            thread.start();
        }
        LATCH.await();
        // Each worker's countDown() happens-before await() returns, so its increments are visible.
        System.out.println("counter is " + counter);
    }

    /** One worker: two separately locked increments, then signal completion. */
    private static class Task implements Runnable {
        @Override
        public void run() {
            try {
                incrementUnderLock();
                incrementUnderLock();
            } finally {
                // Count down only after both locks are released, and always, even if a lock
                // call throws, so main() can never wait forever.
                LATCH.countDown();
            }
        }
    }

    /** Increments {@link #counter} while holding {@link #LOCK}. */
    private static void incrementUnderLock() {
        LOCK.lock();
        try {
            counter++;
        } finally {
            LOCK.unlock();
        }
    }
}