Implementing a critical section in Java: a review of classic concurrency control

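Given some memory-ordering guarantees, can we implement a mutually exclusive critical section without hardware-supported atomic operations such as compare-and-swap (CAS)? The answer is yes.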

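The computing pioneer Edsger Wybe Dijkstra proposed a solution in this classic paper more than 50 years ago, and it opened up the study of how to implement concurrency control in general-purpose programming languages.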


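Here we assume N threads, numbered 1 to N, and a shared variable k that helps indicate which thread currently holds the critical section. The assumed memory model is that an operation executed earlier is visible to later ones, and that accesses to the same memory location happen one after another.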

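Initially every element of the arrays b[N] and c[N] is true, and k may start at any value from 1 to N. The variable i denotes the current unit of execution (thread).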

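For each thread i, b[i] and c[i] both express thread i's intent to compete for the critical section: b[i] == false means thread i is about to compete, and c[i] == false means thread i is actively contending for it. On leaving the critical section, a thread sets both b[i] and c[i] back to true. Other threads can then inspect b[k] and c[k] to judge whether thread k still holds the critical section. Because the threads' execution order is nondeterministic, this judgment is only approximate.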

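Several threads may read b[k] at the same time and each set k to its own id, so that all of them reach the point just before the critical section. Even so, each thread must first check the c[j] values of all other threads before entering, so at most one thread enters and the rest fall back to Li1. It is also possible that no thread gets in at all, in which case they all return to Li1 and compete again.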

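Note: the statement c[i] := true at Li2 in the paper causes many redundant writes, because c[i] is usually already true. It only matters when control arrives via the goto Li1 inside Li4, so we can move c[i] := true to just before that goto Li1, which preserves the program's semantics and avoids the wasted work.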
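Before the full Unsafe-based listing, here is a minimal sketch of the entry and exit protocol just described. It is not the code used in this article: it assumes AtomicIntegerArray, AtomicInteger and AtomicLong from java.util.concurrent.atomic, calling only get() and set() (volatile reads and writes, never CAS), and it keeps c[i] := true at Li2 as in the paper. The class name DijkstraMutexSketch and the helpers lock, unlock and run are hypothetical names chosen for illustration.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLong;

// Sketch only: Dijkstra's protocol with volatile get()/set(), no CAS.
class DijkstraMutexSketch {
    private final int n;
    private final AtomicIntegerArray b;   // b[i] == 0: thread i wants to enter
    private final AtomicIntegerArray c;   // c[i] == 0: thread i is in the final check or inside
    private final AtomicInteger k = new AtomicInteger(1);
    final AtomicLong counter = new AtomicLong();   // updated only inside the critical section

    DijkstraMutexSketch(int n) {
        this.n = n;
        b = new AtomicIntegerArray(n + 1);          // index 0 is unused
        c = new AtomicIntegerArray(n + 1);
        for (int i = 1; i <= n; i++) { b.set(i, 1); c.set(i, 1); }
    }

    void lock(int i) {
        b.set(i, 0);                                // Li0: announce intent
        retry:
        for (;;) {
            int owner = k.get();
            if (owner != i) {                       // Li1
                c.set(i, 1);                        // Li2
                if (b.get(owner) == 1) {            // Li3: owner is not competing
                    k.set(i);
                }
                Thread.yield();                     // be polite while spinning (assumption, not in the paper)
                continue;
            }
            c.set(i, 0);                            // Li4
            for (int j = 1; j <= n; j++) {
                if (j != i && c.get(j) == 0) {
                    continue retry;                 // another thread is also at Li4
                }
            }
            return;                                 // we are alone: enter the critical section
        }
    }

    void unlock(int i) {
        c.set(i, 1);
        b.set(i, 1);
    }

    // Runs `threads` workers, each entering the critical section `times` times.
    static long run(int threads, int times) {
        DijkstraMutexSketch m = new DijkstraMutexSketch(threads);
        Thread[] ts = new Thread[threads + 1];
        for (int t = 1; t <= threads; t++) {
            final int id = t;
            ts[t] = new Thread(() -> {
                for (int r = 0; r < times; r++) {
                    m.lock(id);
                    m.counter.set(m.counter.get() + 1);   // deliberately a non-atomic read-then-write
                    m.unlock(id);
                }
            });
            ts[t].start();
        }
        try {
            for (int t = 1; t <= threads; t++) ts[t].join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return m.counter.get();
    }

    public static void main(String[] args) {
        System.out.println(run(4, 10_000));
    }
}
```

If mutual exclusion holds, no increment is lost and run(threads, times) returns threads * times.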

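Let's implement this scheme in Java and verify it with 10 threads, each entering the critical section 10 million times and adding 1 each time. The executable code is as follows: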

package com.psly.testatomic;

import sun.misc.Unsafe;

public class TestVolatile {
	//For memory ordering guarantees: putXXVolatile/getXXVolatile
	private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
	private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);
	private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
	//N: number of threads; TIMES: how many times each thread must enter the critical section.
	private final static int N = 10;
	private final static int TIMES = 10000000;
	
	private final static int[] B = new int[N+1];
	private final static int[] C = new int[N+1];
	//Each thread does ++count inside the critical section; in the end count == N * TIMES
	private static long count;
	//countObj: the object that holds the count field (effectively its base address)
	private final static Object countObj;
	//countOffset: the offset of the count field within that object
	private final static long countOffset;
	//k is handled the same way as the count field above
	private static int k = 1;
	private final static Object kObj; 
	private final static long kOffset;
	static{
		for(int i = 1; i <= N; ++i){
			B[i] = 1;
			C[i] = 1;
		}
		try {
			countObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("count"));
			countOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("count"));
			kObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("k"));
			kOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("k"));
		} catch (Exception e) {
			throw new Error(e);
		} 
	}

	final static void dijkstrasConcurMethod(int pM){
	    int times = TIMES;
        int i = pM;
    L0: for(;;){
    	B[i] = 0;
    L1: 	for(;;){
    			if( k != i ) {
    				//C[i] = 1;
    				if(B[_unsafe.getIntVolatile(kObj, kOffset)] == 1)
    					_unsafe.putIntVolatile(kObj, kOffset, i);//k = i;
    				continue L1;
    			} else{       
    				_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
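    				//The updated C[i] is guaranteed to be visible here, which fundamentally ensures mutual exclusion: at most one thread is in the critical section.
    				for(int j = 1; j <= N; ++j ) 
    					if(j != i &&  _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
    						//restore C[i]; doing it here is more efficient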
    						_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
    						continue L1;
    					}
    			}
    			break L1;
        	}
    	
    		//critical section begins
        	long val = _unsafe.getLongVolatile(countObj, countOffset);
        	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
        	//critical section ends
        	
        	_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
        	B[i]=1;
        	if( --times != 0){
        		continue L0; //goto L0;
        	}
        	return;
        }
	}
	
	public static void main(String[] args) throws InterruptedException
	{
		//start time
		long start = System.currentTimeMillis();
		//print the counter's initial value
	    System.out.println( count + " initial\n");
	    Thread handle[] = new Thread[N+1];
        //create the threads
	    for (int i = 1; i <= N; ++i){
	    	int j = i;
		    handle[i] = new Thread(new Runnable(){
		    	@Override
		    	public void run(){
		    		dijkstrasConcurMethod(j);
		    	}
		    });
	    }
	    //start the threads
	    for (int i = 1; i <= N; ++i)
	        handle[i].start();
	    //the main thread waits for the worker threads to finish
	    for (int i = 1; i <= N; ++i)
	        handle[i].join();
	    //print the accumulated value, == N * TIMES
	    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
	  //print the elapsed time
	    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); 
	}
}

Running it once prints:

0 initial

100000000
12.936 seconds

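10 threads, each entering the critical section 10 million times, add up to 100 million in 12.936 seconds. So this example at least appears to be correct.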

Next, let's focus on the dijkstrasConcurMethod method:

	final static void dijkstrasConcurMethod(int pM){
	    int times = TIMES;
        int i = pM;
    L0: for(;;){
    	B[i] = 0;
    L1: 	for(;;){
    			if( k != i ) {
    				//C[i] = 1;
    				if(B[_unsafe.getIntVolatile(kObj, kOffset)] == 1)
    					_unsafe.putIntVolatile(kObj, kOffset, i);//k = i;
    				continue L1;
    			} else{       
    				_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
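    				//The updated C[i] is guaranteed to be visible here, which fundamentally ensures mutual exclusion: at most one thread is in the critical section.
    				for(int j = 1; j <= N; ++j ) 
    					if(j != i &&  _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
    						//restore C[i]; doing it here is more efficient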
    						_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
    						continue L1;
    					}
    			}
    			break L1;
        	}
    	
    		//critical section begins
        	long val = _unsafe.getLongVolatile(countObj, countOffset);
        	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
        	//critical section ends
        	
        	_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
        	B[i]=1;
        	if( --times != 0){
        		continue L0; //goto L0;
        	}
        	return;
        }
	}

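We replace the paper's true/false with 1/0. Java has no goto statement, so we use labeled for(;;) loops to get the same control flow. Here pM is the thread's own index and TIMES is the number of times it must enter the critical section.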

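Strictly speaking, this program is not exactly equivalent to the example in Dijkstra's paper. The paper requires shared memory to be strongly consistent: every write to B[i], C[i], or k is immediately visible to the other threads.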


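The paper was published in 1965, and that may well have been the working assumption about memory models and hardware at the time. As computer architecture evolved, though, multi-level caches and out-of-order execution were introduced to make programs run faster, and as a result the memory models of most programming languages no longer satisfy that assumption.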

Even so, once we add operations with volatile semantics, our Java program remains correct, because two things are guaranteed:


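  1. _unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
        				//The updated C[i] is guaranteed to be visible here, which fundamentally ensures mutual exclusion: at most one thread is in the critical section.
        				for(int j = 1; j <= N; ++j ) 
        					if(j != i &&  _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
        						//restore C[i]; doing it here is more efficient
        						_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
        						continue L1;
        					}
    This guarantees that the new value written to C[i] is visible before the thread starts probing the whole C array.
  2.     		//critical section begins
            	long val = _unsafe.getLongVolatile(countObj, countOffset);
            	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
            	//critical section ends
            	
            	_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
    This guarantees that C[i] is set back to 1 only after the thread has left the critical section, so the 1 cannot leak out early and mislead the probing loops of other threads.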
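
sun.misc.Unsafe is an internal API. On Java 9 and later, the same volatile reads and writes on array elements and static fields can be expressed with java.lang.invoke.VarHandle. The sketch below shows that alternative; it is not what the listings in this article use, and the class name VolatileAccessSketch and its helper methods are hypothetical.

```java
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

// Sketch only: volatile access without sun.misc.Unsafe (requires Java 9+).
class VolatileAccessSketch {
    // Gives volatile access to the elements of any int[]
    private static final VarHandle INT_ARRAY =
            MethodHandles.arrayElementVarHandle(int[].class);
    // Gives volatile access to the plain static field k below
    private static final VarHandle K;
    private static int k = 1;

    static {
        try {
            K = MethodHandles.lookup()
                    .findStaticVarHandle(VolatileAccessSketch.class, "k", int.class);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    // Counterpart of _unsafe.putIntVolatile(C, _Obase + i * _Oscale, v)
    static void setC(int[] c, int i, int v) {
        INT_ARRAY.setVolatile(c, i, v);
    }

    // Counterpart of _unsafe.getIntVolatile(C, _Obase + i * _Oscale)
    static int getC(int[] c, int i) {
        return (int) INT_ARRAY.getVolatile(c, i);
    }

    // Counterpart of _unsafe.putIntVolatile(kObj, kOffset, v)
    static void setK(int v) {
        K.setVolatile(v);
    }

    // Counterpart of _unsafe.getIntVolatile(kObj, kOffset)
    static int getK() {
        return (int) K.getVolatile();
    }

    public static void main(String[] args) {
        int[] c = {1, 1, 1};
        setC(c, 2, 0);
        setK(2);
        System.out.println(getC(c, 2) + " " + getK());
    }
}
```

The VarHandle takes the array index directly, so the manual base-plus-scale offset arithmetic disappears.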
Now let's look at the second paper. It is short, so it can be shown in full:

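It merely reduces the original N units of execution to 2, which makes it easier to follow. The algorithm in this paper is incorrect; you can work through it yourself.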

Next comes the third paper, by another Turing Award winner, the renowned computer scientist Donald Ervin Knuth.
The algorithm is as follows:

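His idea is to use a single array control (initialized to 0), indexed by thread id and treated as a ring, plus a variable k that indicates which id has priority for the critical section. The steps are:
  • First, scan from k toward our own id (i). If some control(j) != 0 is found, a thread ahead of us is already competing, so we goto back and rescan. If every control value from k up to the id just before ours is 0, we move on to the second step.
  • Second, set our control value to 2 to show that we have advanced further in the competition. Several threads may still arrive here together, so we next use a probe-and-back-off check similar to Dijkstra's, which lets at most one thread proceed to the next step.
  • Third, set k to the current id and enter the critical section.
  • Fourth, after leaving the critical section, set k to the id immediately to the right of the current one, which tends to produce a ring-shaped execution order. Finally, set control[i] to 0.
  • Finally, return. Note that the assignment k := if i = 1 then N else i - 1; is made without contention. Its purpose is to favor the thread to the right, but in extreme cases another thread may still acquire the lock, which is why the line L3: k := i; is still needed.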
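
The steps above can be sketched compactly with the standard atomic classes. This is an illustration under assumptions, not the article's code: it uses AtomicIntegerArray, AtomicInteger and AtomicLong with only get() and set() (no CAS), starts k at 1, and adds Thread.yield() while spinning. The names KnuthMutexSketch, lock, unlock and run are hypothetical.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLong;

// Sketch only: Knuth's ring-based protocol with volatile get()/set(), no CAS.
class KnuthMutexSketch {
    private final int n;
    private final AtomicIntegerArray control;              // 0 = idle, 1 = waiting, 2 = final check or inside
    private final AtomicInteger k = new AtomicInteger(1);  // id that currently has priority
    final AtomicLong counter = new AtomicLong();           // updated only inside the critical section

    KnuthMutexSketch(int n) {
        this.n = n;
        this.control = new AtomicIntegerArray(n + 1);      // index 0 unused, all zeros
    }

    void lock(int i) {
        restart:
        for (;;) {
            control.set(i, 1);                             // L0
            scan:
            for (;;) {                                     // L1: walk the ring from k down to i
                for (int j = k.get(); ; j = (j == 1) ? n : j - 1) {
                    if (j == i) break scan;                // nobody ahead of us is competing
                    if (control.get(j) != 0) {
                        Thread.yield();
                        continue scan;                     // someone ahead is competing: rescan
                    }
                }
            }
            control.set(i, 2);                             // L2
            for (int j = n; j >= 1; j--) {
                if (j != i && control.get(j) == 2) {
                    continue restart;                      // another thread also reached step two
                }
            }
            k.set(i);                                      // L3
            return;
        }
    }

    void unlock(int i) {
        k.set(i == 1 ? n : i - 1);                         // favor the neighbor to the right
        control.set(i, 0);                                 // L4
    }

    // Runs `threads` workers, each entering the critical section `times` times.
    static long run(int threads, int times) {
        KnuthMutexSketch m = new KnuthMutexSketch(threads);
        Thread[] ts = new Thread[threads + 1];
        for (int t = 1; t <= threads; t++) {
            final int id = t;
            ts[t] = new Thread(() -> {
                for (int r = 0; r < times; r++) {
                    m.lock(id);
                    m.counter.set(m.counter.get() + 1);    // deliberately a non-atomic read-then-write
                    m.unlock(id);
                }
            });
            ts[t].start();
        }
        try {
            for (int t = 1; t <= threads; t++) ts[t].join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return m.counter.get();
    }

    public static void main(String[] args) {
        System.out.println(run(4, 5_000));
    }
}
```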
We implement this algorithm in Java as well.
The executable code is as follows:
package com.psly.testatomic;

import java.util.Random;

import com.psly.locksupprot.LockSupport;

import sun.misc.Unsafe;

public class TestVolatileKnuthMethod {
	private final static Random random = new Random();

	//For memory ordering guarantees: putXXVolatile/getXXVolatile
	private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
	private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);
	private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
	//N: number of threads; TIMES: how many times each thread must enter the critical section.
	private final static int N = 5;
	private final static int TIMES = 1000;
	
	private final static int[] B = new int[N+1];
	private final static int[] C = new int[N+1];
	//knuth's method
	private final static int[] control = new int[N+1];
	
	//Each thread does ++count inside the critical section; in the end count == N * TIMES
	private static long count;
	//countObj: the object that holds the count field (effectively its base address)
	private final static Object countObj;
	//countOffset: the offset of the count field within that object
	private final static long countOffset;
	//k is handled the same way as the count field above
	private static int k = 1;
	private final static Object kObj; 
	private final static long kOffset;
	
	static{
		for(int i = 1; i <= N; ++i){
			B[i] = 1;
			C[i] = 1;
		}
		try {
			countObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("count"));
			countOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("count"));
			kObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("k"));
			kOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("k"));
		} catch (Exception e) {
			throw new Error(e);
		} 
	}
	private static Object obj = new Object();

	final static void knuthConcurMethod(int pM){
	    int times = TIMES;
        int i = pM;
        
    L0: for(;;){
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
    L1:		for(;;){
	    		for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			continue L1;
		    		}
	    		}
	    		for(int j = N; j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			continue L1;
		    		}
	    		}
    		}
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
    		for(int j = N; j >= 1; --j){
    			if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
    				continue L0;
    			}
    		}
    		_unsafe.putIntVolatile(kObj, kOffset, i);

    		long val = _unsafe.getLongVolatile(countObj, countOffset);
        	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
    		_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1);
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
            
    		if( --times != 0)
    			continue L0;
    		return ;
    	}
	}
	private static Thread[] handle = new Thread[N+1];
	public static void main(String[] args) throws InterruptedException
	{
		//start time
		long start = System.currentTimeMillis();
		//print the counter's initial value
	    System.out.println( count + " initial\n");
        //create the threads
	    for (int i = 1; i <= N; ++i){
	    	int j = i;
		    handle[i] = new Thread(new Runnable(){
		    	@Override
		    	public void run(){
		    		knuthConcurMethod(j);
		    	}
		    });
	    }
	    //start the threads
	    for (int i = 1; i <= N; ++i)
	        handle[i].start();
	    //the main thread waits for the worker threads to finish
	    for (int i = 1; i <= N; ++i)
	        handle[i].join();
	    //print the accumulated value, == N * TIMES
	    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
	  //print the elapsed time
	    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); 
	}

}

The output is:
0 initial

5000
7.464 seconds

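5 threads, 1000 times each, 7.46 seconds. Fairness is ensured, but efficiency is poor because the waiting threads in the ring keep burning CPU. Knuth's paper also suggests using something like a queue to improve efficiency.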


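Here we take a different approach: put a waiting thread to sleep and wake it only when necessary. Java happens to provide the park/unpark interface, and our thread array is fixed, so we can use it directly.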
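
As a small illustration of the park/unpark contract, here is a sketch that uses the JDK's own java.util.concurrent.locks.LockSupport rather than the copied com.psly.locksupprot.LockSupport class from this article. park may return spuriously, so the waiter re-checks its condition in a loop; and because the flag is set before unpark is called, the wake-up cannot be lost. The names ParkUnparkSketch and handOff are hypothetical.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;

// Sketch only: one thread sleeps with park() until another sets a flag and calls unpark().
class ParkUnparkSketch {
    static boolean handOff() {
        AtomicBoolean ready = new AtomicBoolean(false);
        AtomicBoolean done = new AtomicBoolean(false);

        Thread waiter = new Thread(() -> {
            // Always re-check the condition: park() may return for no reason.
            while (!ready.get()) {
                LockSupport.park(ready);
            }
            done.set(true);
        });
        waiter.start();

        ready.set(true);              // publish the condition first
        LockSupport.unpark(waiter);   // then hand the waiter its permit

        try {
            waiter.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return done.get();
    }

    public static void main(String[] args) {
        System.out.println(handOff());
    }
}
```

If unpark runs before the waiter parks, the permit is remembered and the next park returns immediately, which is why the order "set flag, then unpark" is safe.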

Add the following code to the example above.
Waking:
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
    		
            int j = (i == 1)? N : i -1;
            for(int m = 0; m < N - 1; ++m){
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
            		break;
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
            		LockSupport.unpark(handle[j]);
            		break;
            	}
            	j = (j == 1)? N : j -1;
            }
            
    		if( --times != 0)

Sleeping:
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}

That gives the blocking version. The executable code is:
package com.psly.testatomic;

import com.psly.locksupprot.LockSupport;

import sun.misc.Unsafe;

public class TestVolatileKnuthMethod {

	//For memory ordering guarantees: putXXVolatile/getXXVolatile
	private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
	private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);
	private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
	//N: number of threads; TIMES: how many times each thread must enter the critical section.
	private final static int N = 5;
	private final static int TIMES = 1000;
	
	private final static int[] B = new int[N+1];
	private final static int[] C = new int[N+1];
	//knuth's method
	private final static int[] control = new int[N+1];
	
	//Each thread does ++count inside the critical section; in the end count == N * TIMES
	private static long count;
	//countObj: the object that holds the count field (effectively its base address)
	private final static Object countObj;
	//countOffset: the offset of the count field within that object
	private final static long countOffset;
	//k is handled the same way as the count field above
	private static int k = 1;
	private final static Object kObj; 
	private final static long kOffset;
	
	static{
		for(int i = 1; i <= N; ++i){
			B[i] = 1;
			C[i] = 1;
		}
		try {
			countObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("count"));
			countOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("count"));
			kObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("k"));
			kOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("k"));
		} catch (Exception e) {
			throw new Error(e);
		} 
	}
	private static Object obj = new Object();

	final static void knuthConcurMethod(int pM){
	    int times = TIMES;
        int i = pM;
        
    L0: for(;;){
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
    L1:		for(;;){
	    		for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
	    		for(int j = N; j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
    		}
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
    		for(int j = N; j >= 1; --j){
    			if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
    				continue L0;
    			}
    		}
                //critical section begins
    		_unsafe.putIntVolatile(kObj, kOffset, i);
               
    		long val = _unsafe.getLongVolatile(countObj, countOffset);
        	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
    		_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1); //critical section ends
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
    		
            int j = (i == 1)? N : i -1;
            for(int m = 0; m < N - 1; ++m){
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
            		break;
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
            		LockSupport.unpark(handle[j]);
            		break;
            	}
            	j = (j == 1)? N : j -1;
            }
            
    		if( --times != 0)
    			continue L0;
    		return ;
    	}
	}
	private static Thread[] handle = new Thread[N+1];
	public static void main(String[] args) throws InterruptedException
	{
		//start time
		long start = System.currentTimeMillis();
		//print the counter's initial value
	    System.out.println( count + " initial\n");
        //create the threads
	    for (int i = 1; i <= N; ++i){
	    	int j = i;
		    handle[i] = new Thread(new Runnable(){
		    	@Override
		    	public void run(){
		    		knuthConcurMethod(j);
		    	}
		    });
	    }
	    //start the threads
	    for (int i = 1; i <= N; ++i)
	        handle[i].start();
	    //the main thread waits for the worker threads to finish
	    for (int i = 1; i <= N; ++i)
	        handle[i].join();
	    //print the accumulated value, == N * TIMES
	    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
	  //print the elapsed time
	    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); 
	}

}

Output:
0 initial

5000
0.043 seconds


5 threads, 1000 times each, 0.043 seconds.

Let's try a heavier load, N=100 and TIMES=5000.
Output:
0 initial

500000
2.938 seconds

100 threads, each entering the critical section 5000 times, finish in 2.938 seconds in total, which is far better than the polling version.

Let's look again at the main logic of our Java code:
	final static void knuthConcurMethod(int pM){
	    int times = TIMES;
        int i = pM;
        
    L0: for(;;){
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
    L1:		for(;;){
	    		for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
	    		for(int j = N; j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
    		}
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
    		for(int j = N; j >= 1; --j){
    			if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
    				continue L0;
    			}
    		}
    		_unsafe.putIntVolatile(kObj, kOffset, i);

    		long val = _unsafe.getLongVolatile(countObj, countOffset);
        	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
    		_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1);
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
    		
            int j = (i == 1)? N : i -1;
            for(int m = 0; m < N - 1; ++m){
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
            		break;
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
            		LockSupport.unpark(handle[j]);
            		break;
            	}
            	j = (j == 1)? N : j -1;
            }
            
    		if( --times != 0)
    			continue L0;
    		return ;
    	}
	}


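The key part is the wake-up logic:
While traversing to the right, waking the one necessary thread is enough, and sometimes no wake-up is needed at all. We try to do as little work as possible.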
       int j = (i == 1)? N : i -1;  
           for(int m = 0; m < N - 1; ++m){  
            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)  
                break;  
            if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){  
                LockSupport.unpark(handle[j]);  
                break;  
            }  
            j = (j == 1)? N : j -1;  
           }  
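
To make the selection rule easier to test in isolation, the traversal can be written as a pure function over a snapshot of the control array. This is a sketch, not part of the article's program: it reads a plain int[] instead of doing volatile reads, returns 0 when nobody needs waking, and the names WakeTargetSketch, pred and wakeTarget are hypothetical.

```java
// Sketch only: which thread, if any, should be unparked after thread i leaves?
class WakeTargetSketch {
    // Right-hand neighbor on the ring 1..n
    static int pred(int j, int n) {
        return (j == 1) ? n : j - 1;
    }

    // control[0] is unused; control[1..n] holds 0 (idle), 1 (waiting) or 2 (final check).
    static int wakeTarget(int[] control, int i) {
        int n = control.length - 1;
        int j = pred(i, n);
        for (int m = 0; m < n - 1; m++) {
            if (control[j] == 2) return 0;   // that thread is already running toward the lock
            if (control[j] == 1) return j;   // first waiting thread to the right: wake it
            j = pred(j, n);
        }
        return 0;                            // nobody is waiting
    }

    public static void main(String[] args) {
        int[] control = {0, 0, 1, 0, 0, 1};  // threads 2 and 5 are waiting
        System.out.println(wakeTarget(control, 4));
    }
}
```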


Next, let's look at one more related paper:



N. G. de Bruijn made a further slight modification to Knuth's method.

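The advantage is that the execution logic is easier to see. One subtle change is that k is no longer set when a thread enters the critical section. Reportedly this lowers the theoretical upper bound on how many turns a thread may have to wait before it runs, although I could not work out what premises that bound assumes.
Our corresponding Java change is as follows: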




This method is said to give each thread a maximum bound on the number of turns before its next entry into the critical section.
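
The change in how k is updated can be isolated as a small pure function. The sketch below follows the exit rule used in this article's Java port (advance k to its right-hand neighbor only if control[k] is 0 or k is the exiting thread itself, otherwise leave k alone); it is not taken from the paper's own listing. It works on a plain int[] snapshot, and the names DeBruijnExitSketch and nextK are hypothetical.

```java
// Sketch only: the value k should have after thread i leaves the critical section.
class DeBruijnExitSketch {
    // control[0] is unused; control[1..n] holds 0, 1 or 2.
    static int nextK(int[] control, int k, int i) {
        int n = control.length - 1;
        if (control[k] == 0 || k == i) {
            return (k == 1) ? n : k - 1;   // move priority one step to the right
        }
        return k;                          // thread k is still competing: keep its priority
    }

    public static void main(String[] args) {
        int[] control = {0, 0, 1, 2, 0};   // thread 2 waits, thread 3 is inside
        System.out.println(nextK(control, 3, 3));
    }
}
```

Compared with Knuth's version, the exiting thread no longer overwrites k unconditionally, so a waiting thread that already holds priority keeps it.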

The executable code (blocking version) is as follows:
package com.psly.testatomic;

import java.text.SimpleDateFormat;
import java.util.Date;

import com.psly.locksupprot.LockSupport;

import sun.misc.Unsafe;

public class TestVolatileBruijnMethod {

	//For memory ordering guarantees: putXXVolatile/getXXVolatile
	private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
	private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);
	private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
	//N: number of threads; TIMES: how many times each thread must enter the critical section.
	private final static int N = 100;
	private final static int TIMES = 5000;
	
	private final static int[] B = new int[N+1];
	private final static int[] C = new int[N+1];
	//knuth's method
	private final static int[] control = new int[N+1];
	
	//Each thread does ++count inside the critical section; in the end count == N * TIMES
	private static long count;
	//countObj: the object that holds the count field (effectively its base address)
	private final static Object countObj;
	//countOffset: the offset of the count field within that object
	private final static long countOffset;
	//k is handled the same way as the count field above
	private static int k = 1;
	private final static Object kObj; 
	private final static long kOffset;
	
	static{
		for(int i = 1; i <= N; ++i){
			B[i] = 1;
			C[i] = 1;
		}
		try {
			countObj = _unsafe.staticFieldBase(TestVolatileBruijnMethod.class.getDeclaredField("count"));
			countOffset = _unsafe.staticFieldOffset(TestVolatileBruijnMethod.class.getDeclaredField("count"));
			kObj = _unsafe.staticFieldBase(TestVolatileBruijnMethod.class.getDeclaredField("k"));
			kOffset = _unsafe.staticFieldOffset(TestVolatileBruijnMethod.class.getDeclaredField("k"));
		} catch (Exception e) {
			throw new Error(e);
		} 
	}
	private static Object obj = new Object();

	final static void knuthConcurMethod(int pM){
	    int times = TIMES;
        int i = pM;
        
    L0: for(;;){
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
    L1:		for(;;){
	    		for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
	    		for(int j = N; j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
    		}
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
    		for(int j = N; j >= 1; --j){
    			if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
    				continue L0;
    			}
    		}
//    		_unsafe.putIntVolatile(kObj, kOffset, i);
    		int kLocal = _unsafe.getIntVolatile(kObj, kOffset);
    		int kNew = kLocal;
    		long val = _unsafe.getLongVolatile(countObj, countOffset);
        	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
        	if(_unsafe.getIntVolatile(control, _Obase + kLocal * _Oscale) == 0 || kLocal == i)
    			_unsafe.putIntVolatile(kObj, kOffset, kNew = ((kLocal == 1)? N: kLocal - 1));
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
    		
            int j = kNew;
            for(int m = 0; m < N; ++m){
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
            		break;
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
            		LockSupport.unpark(handle[j]);
            		break;
            	}
            	j = (j == 1)? N : j -1;
            }
            
    		if( --times != 0)
    			continue L0;
    		return ;
    	}
	}
	private static Thread[] handle = new Thread[N+1];
	public static void main(String[] args) throws InterruptedException
	{
		System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
				.format(new Date()));
		//start time
		long start = System.currentTimeMillis();
		//print the counter's initial value
	    System.out.println( count + " initial\n");
        //create the threads
	    for (int i = 1; i <= N; ++i){
	    	int j = i;
		    handle[i] = new Thread(new Runnable(){
		    	@Override
		    	public void run(){
		    		knuthConcurMethod(j);
		    	}
		    });
	    }
	    //start the threads
	    for (int i = 1; i <= N; ++i)
	        handle[i].start();
	    //the main thread waits for the worker threads to finish
	    for (int i = 1; i <= N; ++i)
	        handle[i].join();
	    //print the accumulated value, == N * TIMES
	    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
	  //print the elapsed time
	    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); 
	}

}



Later, in 1972, Murray A. Eisenberg and Michael R. McGuire of The MITRE Corporation proposed an algorithm that is said to coordinate better, shown below:


You will notice that this algorithm has even more goto statements and is more complex.

Again we give executable Java code:
package com.psly.testatomic;

import java.text.SimpleDateFormat;
import java.util.Date;

import com.psly.locksupprot.LockSupport;

import sun.misc.Unsafe;

public class TestVolatileEisenbergMethod {

	//For memory ordering guarantees: putXXVolatile/getXXVolatile
	private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
	private static final int _Obase  = _unsafe.arrayBaseOffset(int[].class);
	private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
	//N: number of threads; TIMES: how many times each thread must enter the critical section.
	private final static int N = 100;
	private final static int TIMES = 5000;
	
	private final static int[] B = new int[N+1];
	private final static int[] C = new int[N+1];
	//knuth's method
	private final static int[] control = new int[N+1];
	
	//Each thread does ++count inside the critical section; in the end count == N * TIMES
	private static long count;
	//countObj: the object that holds the count field (effectively its base address)
	private final static Object countObj;
	//countOffset: the offset of the count field within that object
	private final static long countOffset;
	//k is handled the same way as the count field above
	private static int k = 1;
	private final static Object kObj; 
	private final static long kOffset;
	
	static{
		for(int i = 1; i <= N; ++i){
			B[i] = 1;
			C[i] = 1;
		}
		try {
			countObj = _unsafe.staticFieldBase(TestVolatileEisenbergMethod.class.getDeclaredField("count"));
			countOffset = _unsafe.staticFieldOffset(TestVolatileEisenbergMethod.class.getDeclaredField("count"));
			kObj = _unsafe.staticFieldBase(TestVolatileEisenbergMethod.class.getDeclaredField("k"));
			kOffset = _unsafe.staticFieldOffset(TestVolatileEisenbergMethod.class.getDeclaredField("k"));
		} catch (Exception e) {
			throw new Error(e);
		} 
	}
	private static Object obj = new Object();

	final static void EisenbergConcurMethod(int pM){
	    int times = TIMES;
        int i = pM;
        
    L0: for(;;){
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
    L1:		for(;;){
    			int kLocal;
	    		for(int j = (kLocal = _unsafe.getIntVolatile(kObj, kOffset)); j <= N; ++j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
	    		for(int j = 1; j <= kLocal - 1; ++j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		    			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
    		}
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
    		for(int j = 1; j <= N; ++j){
    			if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
    				continue L0;
    			}
    		}
    		int kLocal;
    		if(_unsafe.getIntVolatile(control, _Obase + (kLocal = _unsafe.getIntVolatile(kObj, kOffset)) *_Oscale ) != 0 
    				&& kLocal != i)
    			continue L0;
    		
    		_unsafe.putIntVolatile(kObj, kOffset, i);
    		long val = _unsafe.getLongVolatile(countObj, countOffset);
        	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
 //       	System.out.println(Thread.currentThread().getName());
        	
        	int kNew = i;
     L2:  	for(;;){
	        	for(int j = i + 1; j <= N; ++j){
	        		if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
	        			_unsafe.putIntVolatile(kObj, kOffset, j);
	        //			LockSupport.unpark(handle[j]);
	        			kNew = j;
	        			break L2;
	        		}
	        	}
	    		for(int j = 1; j <= i - 1; ++j){
	        		if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
	        			_unsafe.putIntVolatile(kObj, kOffset, j);
	        //			LockSupport.unpark(handle[j]);
	        			kNew = j;
	        			break L2;
	        		}
	    		}
	    		break;
        	}
    		
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
    		
            int j = kNew;
            for(int m = 0; m < N; ++m){
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
            		break;
            	if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
            		LockSupport.unpark(handle[j]);
            		break;
            	}
            	j = (j == N)? 1 : j + 1;
            }
            
    		if( --times != 0)
    			continue L0;
    		return ;
    	}
	}
	private static Thread[] handle = new Thread[N+1];
	public static void main(String[] args) throws InterruptedException
	{
		System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
				.format(new Date()));
		//start time
		long start = System.currentTimeMillis();
		//print the counter's initial value
	    System.out.println( count + " initial\n");
        //create the threads
	    for (int i = 1; i <= N; ++i){
	    	int j = i;
		    handle[i] = new Thread(new Runnable(){
		    	@Override
		    	public void run(){
		    		EisenbergConcurMethod(j);
		    	}
		    });
	    }
	    //start the threads
	    for (int i = 1; i <= N; ++i)
	        handle[i].start();
	    //the main thread waits for the worker threads to finish
	    for (int i = 1; i <= N; ++i)
	        handle[i].join();
	    //print the accumulated value, == N * TIMES
	    System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
	  //print the elapsed time
	    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); 
	}

}
Likewise, their running times do not differ much,
but the ideas behind them are all different.
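
One visible difference in the Eisenberg and McGuire listing above is the exit step: instead of always moving k to a fixed neighbor, the leaving thread scans the ring in ascending order and hands k to the first thread whose control value is non-zero, keeping k unchanged if there is none. The sketch below isolates that rule as a pure function over a plain int[] snapshot; the names EisenbergSuccessorSketch and successor are hypothetical.

```java
// Sketch only: the id that receives priority after thread i leaves the critical section.
class EisenbergSuccessorSketch {
    // control[0] is unused; control[1..n] holds 0, 1 or 2.
    static int successor(int[] control, int i) {
        int n = control.length - 1;
        for (int step = 1; step < n; step++) {
            int j = (i - 1 + step) % n + 1;   // visits i+1, ..., n, 1, ..., i-1
            if (control[j] != 0) {
                return j;                     // first competing thread after i
            }
        }
        return i;                             // nobody else is competing
    }

    public static void main(String[] args) {
        int[] control = {0, 1, 0, 0, 0, 1};   // threads 1 and 5 are competing
        System.out.println(successor(control, 3));
    }
}
```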

Finally, let's go over why the last three algorithms achieve a critical section:

    L1:		for(;;){ //the two loops below decide whether the current thread is eligible to compete for the critical section
	    		for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		   // 			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		}
	    		for(int j = N; j >= 1; --j){
	    			if(j == i)
	    				break L1;
		    		if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){ 
		   // 			LockSupport.park(obj);
		    			continue L1;
		    		}
	    		} //the two loops above decide whether the current thread is eligible to compete for the critical section
    		}
    		//the code below ensures that at most one thread enters the critical section
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
    		for(int j = N; j >= 1; --j){
    			if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
    				continue L0;
    			}
    		}
        	//the code above ensures that at most one thread enters the critical section
    		_unsafe.putIntVolatile(kObj, kOffset, i);
    		//critical section start
    		long val = _unsafe.getLongVolatile(countObj, countOffset);
        	_unsafe.putLongVolatile(countObj, countOffset, val + 1);
    		_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i - 1);
    		//critical section end
    		_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
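  • First, two loops decide whether the current thread is eligible to compete for the lock. If it is, it breaks out of L1; otherwise it keeps looping.
  • Next, the following loop probes the other threads' control values. If none of them is 2, the loop ends and the thread acquires the lock; otherwise it jumps back to L0 and repeats the earlier checks. These semantics ensure that at most one thread enters the critical section, though in the extreme case every thread fails to acquire the lock and jumps back to L0.
  • After the critical section, control[i] is set to 0, replacing its value of 2, so that other threads get a chance to acquire the lock (by the contention check, a thread that sees any other thread at 2 cannot acquire the lock).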
Over

Appendix:

package com.psly.testatomic;

import java.lang.reflect.Field;

import sun.misc.Unsafe;

public class UtilUnsafe {
  private UtilUnsafe() { } // dummy private constructor
  /** Fetch the Unsafe.  Use With Caution. */
  public static Unsafe getUnsafe() {
    // Not on bootclasspath
    if( UtilUnsafe.class.getClassLoader() == null )
      return Unsafe.getUnsafe();
    try {
      final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
      fld.setAccessible(true);
      return (Unsafe) fld.get(UtilUnsafe.class);
    } catch (Exception e) {
      throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
    }
  }
}
package com.psly.locksupprot;

import com.psly.testatomic.UtilUnsafe;

public class LockSupport {
    private LockSupport() {} // Cannot be instantiated.

    private static void setBlocker(Thread t, Object arg) {
        // Even though volatile, hotspot doesn't need a write barrier here.
        UNSAFE.putObject(t, parkBlockerOffset, arg);
    }

    /**
     * Makes available the permit for the given thread, if it
     * was not already available.  If the thread was blocked on
     * {@code park} then it will unblock.  Otherwise, its next call
     * to {@code park} is guaranteed not to block. This operation
     * is not guaranteed to have any effect at all if the given
     * thread has not been started.
     *
     * @param thread the thread to unpark, or {@code null}, in which case
     *        this operation has no effect
     */
    public static void unpark(Thread thread) {
        if (thread != null)
            UNSAFE.unpark(thread);
    }

    /**
     * Disables the current thread for thread scheduling purposes unless the
     * permit is available.
     *
     * <p>If the permit is available then it is consumed and the call returns
     * immediately; otherwise
     * the current thread becomes disabled for thread scheduling
     * purposes and lies dormant until one of three things happens:
     *
     * <ul>
     * <li>Some other thread invokes {@link #unpark unpark} with the
     * current thread as the target; or
     *
     * <li>Some other thread {@linkplain Thread#interrupt interrupts}
     * the current thread; or
     *
     * <li>The call spuriously (that is, for no reason) returns.
     * </ul>
     *
     * <p>This method does <em>not</em> report which of these caused the
     * method to return. Callers should re-check the conditions which caused
     * the thread to park in the first place. Callers may also determine,
     * for example, the interrupt status of the thread upon return.
     *
     * @param blocker the synchronization object responsible for this
     *        thread parking
     * @since 1.6
     */
    public static void park(Object blocker) {
        Thread t = Thread.currentThread();
        setBlocker(t, blocker);
        UNSAFE.park(false, 0L);
        setBlocker(t, null);
    }

    /**
     * Disables the current thread for thread scheduling purposes, for up to
     * the specified waiting time, unless the permit is available.
     *
     * <p>If the permit is available then it is consumed and the call
     * returns immediately; otherwise the current thread becomes disabled
     * for thread scheduling purposes and lies dormant until one of four
     * things happens:
     *
     * <ul>
     * <li>Some other thread invokes {@link #unpark unpark} with the
     * current thread as the target; or
     *
     * <li>Some other thread {@linkplain Thread#interrupt interrupts}
     * the current thread; or
     *
     * <li>The specified waiting time elapses; or
     *
     * <li>The call spuriously (that is, for no reason) returns.
     * </ul>
     *
     * <p>This method does <em>not</em> report which of these caused the
     * method to return. Callers should re-check the conditions which caused
     * the thread to park in the first place. Callers may also determine,
     * for example, the interrupt status of the thread, or the elapsed time
     * upon return.
     *
     * @param blocker the synchronization object responsible for this
     *        thread parking
     * @param nanos the maximum number of nanoseconds to wait
     * @since 1.6
     */
    public static void parkNanos(Object blocker, long nanos) {
        if (nanos > 0) {
            Thread t = Thread.currentThread();
            setBlocker(t, blocker);
            UNSAFE.park(false, nanos);
            setBlocker(t, null);
        }
    }

    /**
     * Disables the current thread for thread scheduling purposes, until
     * the specified deadline, unless the permit is available.
     *
     * <p>If the permit is available then it is consumed and the call
     * returns immediately; otherwise the current thread becomes disabled
     * for thread scheduling purposes and lies dormant until one of four
     * things happens:
     *
     * <ul>
     * <li>Some other thread invokes {@link #unpark unpark} with the
     * current thread as the target; or
     *
     * <li>Some other thread {@linkplain Thread#interrupt interrupts} the
     * current thread; or
     *
     * <li>The specified deadline passes; or
     *
     * <li>The call spuriously (that is, for no reason) returns.
     * </ul>
     *
     * <p>This method does <em>not</em> report which of these caused the
     * method to return. Callers should re-check the conditions which caused
     * the thread to park in the first place. Callers may also determine,
     * for example, the interrupt status of the thread, or the current time
     * upon return.
     *
     * @param blocker the synchronization object responsible for this
     *        thread parking
     * @param deadline the absolute time, in milliseconds from the Epoch,
     *        to wait until
     * @since 1.6
     */
    public static void parkUntil(Object blocker, long deadline) {
        Thread t = Thread.currentThread();
        setBlocker(t, blocker);
        UNSAFE.park(true, deadline);
        setBlocker(t, null);
    }

    /**
     * Returns the blocker object supplied to the most recent
     * invocation of a park method that has not yet unblocked, or null
     * if not blocked.  The value returned is just a momentary
     * snapshot -- the thread may have since unblocked or blocked on a
     * different blocker object.
     *
     * @param t the thread
     * @return the blocker
     * @throws NullPointerException if argument is null
     * @since 1.6
     */
    public static Object getBlocker(Thread t) {
        if (t == null)
            throw new NullPointerException();
        return UNSAFE.getObjectVolatile(t, parkBlockerOffset);
    }

    /**
     * Disables the current thread for thread scheduling purposes unless the
     * permit is available.
     *
     * <p>If the permit is available then it is consumed and the call
     * returns immediately; otherwise the current thread becomes disabled
     * for thread scheduling purposes and lies dormant until one of three
     * things happens:
     *
     * <ul>
     *
     * <li>Some other thread invokes {@link #unpark unpark} with the
     * current thread as the target; or
     *
     * <li>Some other thread {@linkplain Thread#interrupt interrupts}
     * the current thread; or
     *
     * <li>The call spuriously (that is, for no reason) returns.
     * </ul>
     *
     * <p>This method does <em>not</em> report which of these caused the
     * method to return. Callers should re-check the conditions which caused
     * the thread to park in the first place. Callers may also determine,
     * for example, the interrupt status of the thread upon return.
     */
    public static void park() {
        UNSAFE.park(false, 0L);
    }

    /**
     * Disables the current thread for thread scheduling purposes, for up to
     * the specified waiting time, unless the permit is available.
     *
     * <p>If the permit is available then it is consumed and the call
     * returns immediately; otherwise the current thread becomes disabled
     * for thread scheduling purposes and lies dormant until one of four
     * things happens:
     *
     * <ul>
     * <li>Some other thread invokes {@link #unpark unpark} with the
     * current thread as the target; or
     *
     * <li>Some other thread {@linkplain Thread#interrupt interrupts}
     * the current thread; or
     *
     * <li>The specified waiting time elapses; or
     *
     * <li>The call spuriously (that is, for no reason) returns.
     * </ul>
     *
     * <p>This method does <em>not</em> report which of these caused the
     * method to return. Callers should re-check the conditions which caused
     * the thread to park in the first place. Callers may also determine,
     * for example, the interrupt status of the thread, or the elapsed time
     * upon return.
     *
     * @param nanos the maximum number of nanoseconds to wait
     */
    public static void parkNanos(long nanos) {
        if (nanos > 0)
            UNSAFE.park(false, nanos);
    }

    /**
     * Disables the current thread for thread scheduling purposes, until
     * the specified deadline, unless the permit is available.
     *
     * <p>If the permit is available then it is consumed and the call
     * returns immediately; otherwise the current thread becomes disabled
     * for thread scheduling purposes and lies dormant until one of four
     * things happens:
     *
     * <ul>
     * <li>Some other thread invokes {@link #unpark unpark} with the
     * current thread as the target; or
     *
     * <li>Some other thread {@linkplain Thread#interrupt interrupts}
     * the current thread; or
     *
     * <li>The specified deadline passes; or
     *
     * <li>The call spuriously (that is, for no reason) returns.
     * </ul>
     *
     * <p>This method does <em>not</em> report which of these caused the
     * method to return. Callers should re-check the conditions which caused
     * the thread to park in the first place. Callers may also determine,
     * for example, the interrupt status of the thread, or the current time
     * upon return.
     *
     * @param deadline the absolute time, in milliseconds from the Epoch,
     *        to wait until
     */
    public static void parkUntil(long deadline) {
        UNSAFE.park(true, deadline);
    }

    /**
     * Returns the pseudo-randomly initialized or updated secondary seed.
     * Copied from ThreadLocalRandom due to package access restrictions.
     */
    static final int nextSecondarySeed() {
        int r;
        Thread t = Thread.currentThread();
        if ((r = UNSAFE.getInt(t, SECONDARY)) != 0) {
            r ^= r << 13;   // xorshift
            r ^= r >>> 17;
            r ^= r << 5;
        }
        else if ((r = java.util.concurrent.ThreadLocalRandom.current().nextInt()) == 0)
            r = 1; // avoid zero
        UNSAFE.putInt(t, SECONDARY, r);
        return r;
    }

    // Hotspot implementation via intrinsics API
    private static final sun.misc.Unsafe UNSAFE;
    private static final long parkBlockerOffset;
    private static final long SEED;
    private static final long PROBE;
    private static final long SECONDARY;
    static {
        try {
            UNSAFE = UtilUnsafe.getUnsafe();
            Class<?> tk = Thread.class;
            parkBlockerOffset = UNSAFE.objectFieldOffset
                (tk.getDeclaredField("parkBlocker"));
            SEED = UNSAFE.objectFieldOffset
                (tk.getDeclaredField("threadLocalRandomSeed"));
            PROBE = UNSAFE.objectFieldOffset
                (tk.getDeclaredField("threadLocalRandomProbe"));
            SECONDARY = UNSAFE.objectFieldOffset
                (tk.getDeclaredField("threadLocalRandomSecondarySeed"));
        } catch (Exception ex) { throw new Error(ex); }
    }

}
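
The Javadoc above repeats one rule several times: `park` may return for no reason, so callers should re-check the condition they were waiting for. A minimal sketch of that pattern follows. It uses the JDK's own `java.util.concurrent.locks.LockSupport`, which the copy above appears to mirror; the class name `ParkDemo`, the `ready` flag and the `waitAndWake` helper are assumptions made up for this illustration.

```java
import java.util.concurrent.locks.LockSupport;

public class ParkDemo {
    // Hypothetical condition flag; volatile so the waiter sees the update.
    static volatile boolean ready = false;

    /** Returns true if the waiter observed the flag and terminated. */
    static boolean waitAndWake() {
        ready = false;
        Thread waiter = new Thread(() -> {
            // park() can return spuriously, so loop on the real condition.
            while (!ready) {
                LockSupport.park();
            }
        });
        waiter.start();
        ready = true;               // publish the condition first
        LockSupport.unpark(waiter); // then make the permit available
        try {
            waiter.join(5000);      // bounded wait so the demo cannot hang forever
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        return !waiter.isAlive();
    }

    public static void main(String[] args) {
        System.out.println(waitAndWake() ? "waiter finished" : "waiter still parked");
    }
}
```

Because `unpark` leaves a permit behind when the target has not parked yet, the order "set the flag, then unpark" works whether the waiter reaches `park` before or after the wake-up.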


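I recently found that Leslie Lamport, the 2013 Turing Award recipient, proposed another algorithm for this mutual-exclusion problem in 1974; it is now widely known as the bakery algorithm. (Screenshot of the paper.)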


Proof: (screenshot of the proof.)



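The distinguishing feature of this algorithm is that it has no central control: there is no shared variable like k that every thread writes, and each thread writes only its own entries, choosing[i] and number[i].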
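
Since there is no central variable, the order of entry is decided purely by comparing tickets. The rule can be sketched as a small predicate, written to match the wait condition used in the Java implementation that follows. The class name `TicketOrder` and the method name `goesFirst` are assumptions for this illustration, not names from the paper.

```java
public class TicketOrder {
    /**
     * True when thread j must be served before thread i: j holds a ticket
     * (non-zero number) and the pair (numJ, j) is lexicographically smaller
     * than (numI, i).
     */
    static boolean goesFirst(long numJ, int j, long numI, int i) {
        return numJ != 0 && (numJ < numI || (numJ == numI && j < i));
    }

    public static void main(String[] args) {
        System.out.println(goesFirst(0, 1, 5, 2)); // false: thread 1 is not competing
        System.out.println(goesFirst(3, 1, 5, 2)); // true: smaller ticket
        System.out.println(goesFirst(5, 1, 5, 2)); // true: equal tickets, lower index wins
        System.out.println(goesFirst(5, 3, 5, 2)); // false: equal tickets, higher index
    }
}
```

Two threads can draw the same ticket because reading the maximum and writing the new number are not one atomic step, which is why the thread index is needed as a tie-breaker.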

Let's implement it in Java:

package com.psly.testatomic;

import sun.misc.Unsafe;

import com.psly.locksupprot.LockSupport;

public class TestVolatile {
	// Provides the memory-ordering guarantees: putXXVolatile/getXXVolatile
	private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
	private static final int _Obase  = _unsafe.arrayBaseOffset(long[].class);
	private static final int _Oscale = _unsafe.arrayIndexScale(long[].class);
	// N: number of threads; TIMES: how many times each thread enters the critical section.
	private final static int N = 2000;
	private final static int TIMES = 1000;
	
	private final static long[] choosing = new long[N+1];
	private final static long[] number = new long[N+1];
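	// Each thread does ++count inside the critical section; at the end count == N * TIMES
	private static long count;
	// mainObj: the base object that holds the static field count (in effect, its address)
	private final static Object mainObj;
	// countOffset: the offset of the count field within that base object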
	private final static long countOffset;
	private static Object obj = new Object(); 
	
//	private static Queue<Thread> queues = new ConcurrentLinkedQueue();

	static{
		for(int i = 1; i <= N; ++i){
			choosing[i] = 0;
			number[i] = 0;
		}
		try {
			mainObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("count"));
			countOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("count"));
    //        waitersOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("waiters"));
		} catch (Exception e) {
			throw new Error(e);
		} 
	}

    
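	// Entry and exit protocol of Lamport's algorithm for thread pM
	// (the method name still refers to Dijkstra).
	final static void dijkstrasConcurMethod(int pM){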
	    int times = TIMES;
        int i = pM;
    L0: for(;;){
    		_unsafe.putLongVolatile(choosing, _Obase + i * _Oscale, 1);
    		// Take a ticket: read the largest number currently held and add 1.
    		long maxNum = _unsafe.getLongVolatile(number, _Obase + _Oscale), midNum;
    		for(int j = 2; j <= N; ++j)
    			if(maxNum < (midNum = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)))
    				maxNum = midNum;
    		_unsafe.putLongVolatile(number, _Obase + i * _Oscale, 1 + maxNum);
    		_unsafe.putLongVolatile(choosing, _Obase + i * _Oscale, 0);
   /* 		for(int j = 1; j <i; ++j)
    			LockSupport.unpark(handle[j]);
    		for(int j = i+1; j <= N; ++j)
    			LockSupport.unpark(handle[j]);*/
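    		// For every thread j: first spin until j has finished choosing (L1),
    		// then wait while j holds a ticket that has priority over ours (L2).
    		long jNumber, iNumber;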
    		for(int j = 1; j <= N; ++j){
    	L1:		for(;;){
    				for(int k = 0 ; k < 100; ++k)
		    			if(!(_unsafe.getLongVolatile(choosing, _Obase + j * _Oscale) != 0))
		    				 break L1;
    		//		LockSupport.park(obj);
    			}
    	L2:		for(;;){
    				for(int k = 0; k < 1000; ++k)
		    			if(!(_unsafe.getLongVolatile(number, _Obase + j * _Oscale) != 0 
		    					&& ((jNumber=_unsafe.getLongVolatile(number, _Obase + j * _Oscale)) 
		    							< (iNumber=_unsafe.getLongVolatile(number, _Obase + i * _Oscale)) 
		    							|| (jNumber == iNumber && j < i))))
		    				break L2;
    				LockSupport.park(obj);
    			}
    		}
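    		// Critical section begins
            long val = _unsafe.getLongVolatile(mainObj, countOffset);  
            _unsafe.putLongVolatile(mainObj, countOffset, val + 1); 
            // Critical section ends
            
            // Give up the ticket
            _unsafe.putLongVolatile(number, _Obase + i * _Oscale, 0);
            // Wake the waiting thread that holds the smallest non-zero ticket
            // (target defaults to this thread if no other ticket is found)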
            Thread target = handle[i];
            long numMax = Long.MAX_VALUE, arg;
            
      		for(int j = 1; j <i; ++j)
      			if((arg = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)) != 0 && arg < numMax)
      				{ target = handle[j]; numMax = arg;}
    		for(int j = i+1; j <= N; ++j)
    			if((arg = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)) != 0 && arg < numMax)
  					{ target = handle[j]; numMax = arg;}
    		LockSupport.unpark(target);
            /*for(int j = 1; j <= N; ++j)
            	LockSupport.unpark(handle[j]);*/
    		// Count down the remaining rounds
        	if( --times != 0){
        		continue L0; //goto L0;
        	}
        	return;
        }
	}
	
	private static Thread[] handle = new Thread[N+1]; 
	public static void main(String[] args) throws InterruptedException
	{
		// Start time
		long start = System.currentTimeMillis();
		// Print the initial value of the counter
	    System.out.println( count + " initial\n");
	  //  Thread handle[] = new Thread[N+1];
        // Create the threads
	    for (int i = 1; i <= N; ++i){
	    	int j = i;
		    handle[i] = new Thread(new Runnable(){
		    	@Override
		    	public void run(){
		    		dijkstrasConcurMethod(j);
		    	}
		    });
	    }
	    // Start the threads
	    for (int i = 1; i <= N; ++i)
	        handle[i].start();
	    // The main thread waits for the worker threads to finish
	    for (int i = 1; i <= N; ++i)
	        handle[i].join();
	    // Print the accumulated value; it should equal N * TIMES
	    System.out.println(_unsafe.getLongVolatile(mainObj, countOffset));
	    // Print the elapsed time
	    System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds"); 
	}
}
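
For readers who cannot use `sun.misc.Unsafe`, the same entry and exit protocol can be sketched with `java.util.concurrent.atomic.AtomicLongArray`, whose `get` and `set` have volatile memory semantics. This is a simplified sketch under stated assumptions, not a drop-in replacement for the code above: the class name `BakerySketch` and the `run` helper are made up here, thread indices are 0-based, and waiting threads simply spin with `Thread.yield()` instead of parking.

```java
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongArray;

public class BakerySketch {
    private final int n;
    private final AtomicLongArray choosing; // 1 while thread i is picking a ticket
    private final AtomicLongArray number;   // 0 = not competing, otherwise the ticket
    // Updated with a plain get-then-set, so it is only correct under mutual exclusion.
    private final AtomicLong count = new AtomicLong();

    BakerySketch(int n) {
        this.n = n;
        this.choosing = new AtomicLongArray(n);
        this.number = new AtomicLongArray(n);
    }

    void lock(int i) {
        choosing.set(i, 1);
        long max = 0;
        for (int j = 0; j < n; j++) {
            max = Math.max(max, number.get(j));
        }
        number.set(i, max + 1);
        choosing.set(i, 0);
        for (int j = 0; j < n; j++) {
            if (j == i) continue;
            // Wait until thread j has finished choosing its ticket.
            while (choosing.get(j) != 0) Thread.yield();
            // Wait while j holds a ticket with priority over ours.
            for (;;) {
                long nj = number.get(j);
                long ni = number.get(i);
                if (nj == 0 || nj > ni || (nj == ni && j > i)) break;
                Thread.yield();
            }
        }
    }

    void unlock(int i) {
        number.set(i, 0); // give up the ticket
    }

    /** Runs `threads` workers, each entering the critical section `times` times. */
    static long run(int threads, int times) {
        BakerySketch lock = new BakerySketch(threads);
        Thread[] workers = new Thread[threads];
        for (int t = 0; t < threads; t++) {
            final int id = t;
            workers[t] = new Thread(() -> {
                for (int r = 0; r < times; r++) {
                    lock.lock(id);
                    try {
                        lock.count.set(lock.count.get() + 1); // critical section
                    } finally {
                        lock.unlock(id);
                    }
                }
            });
            workers[t].start();
        }
        for (Thread w : workers) {
            try {
                w.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return lock.count.get();
    }

    public static void main(String[] args) {
        int threads = 4, times = 2000;
        System.out.println(run(threads, times) + " (expected " + (threads * times) + ")");
    }
}
```

Spinning keeps the sketch short but burns CPU when there are more threads than cores, which is presumably why the version above parks waiting threads and wakes the holder of the smallest ticket on exit.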

