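Given certain memory-ordering guarantees, can we implement a mutually exclusive critical section without compare-and-swap (CAS) or other atomic operations that need hardware support? The answer is yes.
The computing pioneer Edsger Wybe Dijkstra proposed a solution in a classic paper more than 50 years ago. That paper opened up research into how concurrency control can be implemented in ordinary programming languages.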
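Assume there are N threads, numbered 1 to N, and a shared variable k that helps indicate which thread currently holds the critical section. The assumed memory model is that earlier operations are visible to later ones, and that accesses to the same memory location happen one after another.
The arrays b[N] and c[N] are initially all true. The initial value of k is arbitrary (within 1..N). The variable i denotes the current unit of execution (thread).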
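For each thread i, b[i] and c[i] both express thread i's intent to compete for the critical section: b[i]==false means thread i is about to compete, and c[i]==false means thread i is actively competing. When a thread leaves the critical section it sets both b[i] and c[i] back to true, so other threads can inspect b[k] and c[k] to judge whether thread k still holds the critical section. This judgment is only approximate, because the order in which the threads execute is nondeterministic.
Several threads may inspect b[k] and each set k to its own id, so that all of them reach the point just before the critical section. Even so, each of them must check the other threads' c[j] values before entering, so at most one thread enters and the rest fall back to Li1. It can also happen that no thread gets in at all; then they all return to Li1 and compete again.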
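Note: the statement c[i] := true at Li2 in the paper causes many redundant operations, because c[i] is already true. It only matters when jumping from Li4 back to Li1, so moving c[i] := true to just before that goto Li1 preserves the program's semantics and avoids the wasted work.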
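Let's implement this scheme in Java and verify it with 10 threads, each entering the critical section 10 million times and adding 1 each time. The runnable code is: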
package com.psly.testatomic;
import sun.misc.Unsafe;
public class TestVolatile {
// Used for memory-ordering guarantees: putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
// N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 10;
private final static int TIMES = 10000000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
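// Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
// countObj: the object that holds the count field (effectively its base address)
private final static Object countObj;
// countOffset: the offset of the count field within that object
private final static long countOffset;
// k is handled the same way as the count field above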
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
final static void dijkstrasConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
B[i] = 0;
L1: for(;;){
if( k != i ) {
//C[i] = 1;
if(B[_unsafe.getIntVolatile(kObj, kOffset)] == 1)
_unsafe.putIntVolatile(kObj, kOffset, i);//k = i;
continue L1;
} else{
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
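// The updated C[i] is guaranteed to be visible here, which is what fundamentally ensures mutual exclusion: at most one thread in the critical section.
for(int j = 1; j <= N; ++j )
if(j != i && _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
// Restore C[i] to 1 here; doing it at this point is more efficient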
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
continue L1;
}
}
break L1;
}
// critical section begins
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
// critical section ends
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
B[i]=1;
if( --times != 0){
continue L0; //goto L0;
}
return;
}
}
public static void main(String[] args) throws InterruptedException
{
// start time
long start = System.currentTimeMillis();
// print the initial value of the accumulator
System.out.println( count + " initial\n");
Thread handle[] = new Thread[N+1];
// create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
dijkstrasConcurMethod(j);
}
});
}
// start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
// the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
// print the accumulated value, expected == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
// print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
One run prints:
0 initial
100000000
12.936 seconds
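10 threads, each entering the critical section 10 million times, add up to 100 million in 12.936 seconds. So this example at least appears to be correct.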
Next, let's focus on the dijkstrasConcurMethod method:
final static void dijkstrasConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
B[i] = 0;
L1: for(;;){
if( k != i ) {
//C[i] = 1;
if(B[_unsafe.getIntVolatile(kObj, kOffset)] == 1)
_unsafe.putIntVolatile(kObj, kOffset, i);//k = i;
continue L1;
} else{
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
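// The updated C[i] is guaranteed to be visible here, which is what fundamentally ensures mutual exclusion: at most one thread in the critical section.
for(int j = 1; j <= N; ++j )
if(j != i && _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
// Restore C[i] to 1 here; doing it at this point is more efficient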
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
continue L1;
}
}
break L1;
}
// critical section begins
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
// critical section ends
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
B[i]=1;
if( --times != 0){
continue L0; //goto L0;
}
return;
}
}
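We replace the paper's true/false with 1/0. Java has no goto statement, so we use labeled for(;;) loops to get the same effect. Here pM is the thread's own index, and TIMES is the number of times the critical section must be executed.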
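As a small illustration of how a labeled loop stands in for goto, here is a minimal sketch. The class `LabeledLoopDemo` and the method `attemptsUntil` are made-up names for this example only; the inner check is a stand-in for "some other thread is competing".

```java
// Sketch: "continue LABEL" restarts the outer loop, much like "goto LABEL" in the paper.
final class LabeledLoopDemo {
    // Retries until the simulated check passes; returns how many attempts were needed.
    static int attemptsUntil(int succeedOn) {
        int attempts = 0;
        RETRY:
        for (;;) {
            ++attempts;
            for (int j = 1; j <= 3; ++j) {
                // Hypothetical conflict: before attempt number succeedOn, slot 2 is "busy".
                if (attempts < succeedOn && j == 2) {
                    continue RETRY; // plays the role of "goto RETRY"
                }
            }
            break RETRY; // the scan finished without a conflict
        }
        return attempts;
    }

    public static void main(String[] args) {
        System.out.println(attemptsUntil(4));
    }
}
```

The same shape appears in dijkstrasConcurMethod: `continue L1` restarts the contention loop, and `break L1` falls through to the critical section.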
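Strictly speaking, this program is not fully equivalent to the example in Dijkstra's paper. The paper requires strongly consistent shared memory: any write to B[i], C[i], or k is immediately visible to the other threads.
The paper was published in 1965, when such assumptions about the memory model and hardware were plausible. With the evolution of modern computer architecture, in particular the multi-level caches and out-of-order execution introduced to make programs run faster, the memory models of most programming languages no longer satisfy that assumption.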
Even so, once the operations with volatile semantics are added, our Java program is still correct, because two things are guaranteed:
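First, the value written to C[i] is guaranteed to be visible before the scan over the whole C array starts:
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
// The updated C[i] is guaranteed to be visible here, which is what fundamentally ensures mutual exclusion: at most one thread in the critical section.
for(int j = 1; j <= N; ++j )
if(j != i && _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
// Restore C[i] to 1 here; doing it at this point is more efficient
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
continue L1;
}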
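Second, C[i] is set back to 1 only after the critical section has been left, so that the 1 cannot leak out too early and mislead the scanning loop shown earlier:
// critical section begins
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
// critical section ends
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);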
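The same two guarantees can be sketched without Unsafe. The following is a minimal sketch, assuming `AtomicIntegerArray` and `AtomicInteger` are used only through `get` and `set` (volatile read and write, never CAS). The class `DijkstraMutexSketch` and its `lock`, `unlock` and `run` methods are illustrative names, not part of the original code, and `Thread.yield()` is added only to be polite while spinning.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;

// Sketch only: Dijkstra's algorithm with volatile reads/writes (get/set), never CAS.
final class DijkstraMutexSketch {
    private final int n;
    private final AtomicIntegerArray b; // b[i] == 0: thread i wants the critical section
    private final AtomicIntegerArray c; // c[i] == 0: thread i is in the final contention stage
    private final AtomicInteger k = new AtomicInteger(1);
    private volatile long counter;      // shared value protected by the lock

    DijkstraMutexSketch(int n) {
        this.n = n;
        b = new AtomicIntegerArray(n + 1);
        c = new AtomicIntegerArray(n + 1);
        for (int i = 1; i <= n; ++i) {
            b.set(i, 1);
            c.set(i, 1);
        }
    }

    void lock(int i) {
        b.set(i, 0);
        RETRY:
        for (;;) {
            if (k.get() != i) {
                if (b.get(k.get()) == 1) {
                    k.set(i);
                }
                Thread.yield();
                continue RETRY;
            }
            c.set(i, 0); // guarantee 1: this volatile write precedes the scan below
            for (int j = 1; j <= n; ++j) {
                if (j != i && c.get(j) == 0) {
                    c.set(i, 1); // back off, then retry
                    Thread.yield();
                    continue RETRY;
                }
            }
            return; // no other thread had c[j] == 0, so we own the critical section
        }
    }

    void unlock(int i) {
        c.set(i, 1); // guarantee 2: published only after the critical section is done
        b.set(i, 1);
    }

    // Runs `threads` workers, each doing `times` non-atomic increments under the lock.
    static long run(int threads, int times) {
        DijkstraMutexSketch m = new DijkstraMutexSketch(threads);
        Thread[] ts = new Thread[threads + 1];
        for (int t = 1; t <= threads; ++t) {
            final int id = t;
            ts[t] = new Thread(() -> {
                for (int r = 0; r < times; ++r) {
                    m.lock(id);
                    m.counter = m.counter + 1; // deliberately a separate read and write
                    m.unlock(id);
                }
            });
        }
        for (int t = 1; t <= threads; ++t) ts[t].start();
        try {
            for (int t = 1; t <= threads; ++t) ts[t].join();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        return m.counter;
    }

    public static void main(String[] args) {
        System.out.println(run(4, 2000));
    }
}
```

If mutual exclusion holds, no increment is lost and `run(threads, times)` returns `threads * times`. The thread and iteration counts here are kept small on purpose, since every waiter spins.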
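Knuth's method, implemented in the code that follows, proceeds in these steps:
- First, scan from k towards your own id (i). If some control[j] != 0 is found, a thread ahead of us is already competing, so we goto back. Otherwise every control value from k up to the id just before ours is 0, and we move on to step two.
- Step two sets control[i] to 2, which marks a further stage of the competition. Several threads may still arrive here, so we then use a probe-and-exclude scan similar to Dijkstra's, which lets at most one thread through to the next step.
- Step three sets k to the current id and enters the critical section.
- Step four, after leaving the critical section, sets k to the id to the right of the current one, which tends to produce a round-robin execution order. Finally control[i] is set to 0.
- Then return. Note that the assignment k := if i = 1 then N else i - 1 is made without contention. It tries to let the thread on the right run next, but in extreme cases another thread may still take the lock, so the line L3: k := i; is still needed.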
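The scan order of the first step can be made concrete with a small helper. This is only a sketch: `KnuthScanOrder` and `examinedBefore` are hypothetical names, and the method mirrors the two descending loops of knuthConcurMethod without touching any shared state.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: lists the ids whose control value thread i inspects before reaching itself.
final class KnuthScanOrder {
    static List<Integer> examinedBefore(int k, int i, int n) {
        List<Integer> seen = new ArrayList<>();
        // First pass: from k down to 1.
        for (int j = k; j >= 1; --j) {
            if (j == i) return seen;
            seen.add(j);
        }
        // Second pass: wrap around, from n downwards.
        for (int j = n; j >= 1; --j) {
            if (j == i) return seen;
            seen.add(j);
        }
        return seen;
    }

    public static void main(String[] args) {
        // With n = 5 and k = 2, thread 4 inspects ids 2, 1 and 5 before reaching itself.
        System.out.println(examinedBefore(2, 4, 5));
    }
}
```

Thread i may proceed to step two only if every id in that list has control equal to 0, so the thread whose id equals k has an empty list and is favored.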
package com.psly.testatomic;
import java.util.Random;
import com.psly.locksupprot.LockSupport;
import sun.misc.Unsafe;
public class TestVolatileKnuthMethod {
private final static Random random = new Random();
// Used for memory-ordering guarantees: putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
// N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 5;
private final static int TIMES = 1000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//knuth's method
private final static int[] control = new int[N+1];
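// Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
// countObj: the object that holds the count field (effectively its base address)
private final static Object countObj;
// countOffset: the offset of the count field within that object
private final static long countOffset;
// k is handled the same way as the count field above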
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
private static Object obj = new Object();
final static void knuthConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
_unsafe.putIntVolatile(kObj, kOffset, i);
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1);
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
if( --times != 0)
continue L0;
return ;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
// start time
long start = System.currentTimeMillis();
// print the initial value of the accumulator
System.out.println( count + " initial\n");
// create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
knuthConcurMethod(j);
}
});
}
// start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
// the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
// print the accumulated value, expected == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
// print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
The output is:
0 initial
5000
7.464 seconds
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = (i == 1)? N : i -1;
for(int m = 0; m < N - 1; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
if( --times != 0)
Parking (putting the thread to sleep):
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
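The parking snippet jumps back to L1 and re-checks control[j] after park returns, and it relies on unpark leaving a permit behind if it arrives first. A minimal sketch of both points follows. It uses the standard `java.util.concurrent.locks.LockSupport` rather than the copy in com.psly.locksupprot, and the class `ParkPermitDemo`, the method `waitForFlag` and the `ready` flag are assumptions made for this example.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;

// Sketch: park() may return spuriously, so the waiter re-checks its condition in a loop;
// an unpark() that arrives before park() leaves a permit, so the wake-up is not lost.
final class ParkPermitDemo {
    static boolean waitForFlag() {
        AtomicBoolean ready = new AtomicBoolean(false); // only get/set are used
        Thread waiter = new Thread(() -> {
            while (!ready.get()) {   // always re-check the condition after waking up
                LockSupport.park();
            }
        });
        waiter.start();
        ready.set(true);             // publish the condition first
        LockSupport.unpark(waiter);  // then wake the waiter (or hand it a permit)
        try {
            waiter.join(5000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return !waiter.isAlive();    // true if the waiter finished
    }

    public static void main(String[] args) {
        System.out.println(waitForFlag());
    }
}
```

Setting the flag before calling unpark matters: whichever order the two threads run in, the waiter either sees the flag or finds a permit waiting.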
package com.psly.testatomic;
import com.psly.locksupprot.LockSupport;
import sun.misc.Unsafe;
public class TestVolatileKnuthMethod {
// Used for memory-ordering guarantees: putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
// N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 5;
private final static int TIMES = 1000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//knuth's method
private final static int[] control = new int[N+1];
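// Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
// countObj: the object that holds the count field (effectively its base address)
private final static Object countObj;
// countOffset: the offset of the count field within that object
private final static long countOffset;
// k is handled the same way as the count field above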
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
private static Object obj = new Object();
final static void knuthConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
// critical section begins
_unsafe.putIntVolatile(kObj, kOffset, i);
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1); // critical section ends
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = (i == 1)? N : i -1;
for(int m = 0; m < N - 1; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
if( --times != 0)
continue L0;
return ;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
// start time
long start = System.currentTimeMillis();
// print the initial value of the accumulator
System.out.println( count + " initial\n");
// create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
knuthConcurMethod(j);
}
});
}
// start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
// the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
// print the accumulated value, expected == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
// print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
0 initial
5000
0.043 seconds
0 initial
500000
2.938 seconds
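With 100 threads each entering the critical section 5000 times, the whole run takes 2.938 seconds, which is far better than the polling version.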
final static void knuthConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
_unsafe.putIntVolatile(kObj, kOffset, i);
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1);
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = (i == 1)? N : i -1;
for(int m = 0; m < N - 1; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
if( --times != 0)
continue L0;
return ;
}
}
int j = (i == 1)? N : i -1;
for(int m = 0; m < N - 1; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
package com.psly.testatomic;
import java.text.SimpleDateFormat;
import java.util.Date;
import com.psly.locksupprot.LockSupport;
import sun.misc.Unsafe;
public class TestVolatileBruijnMethod {
// Used for memory-ordering guarantees: putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
// N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 100;
private final static int TIMES = 5000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//knuth's method
private final static int[] control = new int[N+1];
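// Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
// countObj: the object that holds the count field (effectively its base address)
private final static Object countObj;
// countOffset: the offset of the count field within that object
private final static long countOffset;
// k is handled the same way as the count field above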
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatileBruijnMethod.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatileBruijnMethod.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatileBruijnMethod.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatileBruijnMethod.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
private static Object obj = new Object();
final static void knuthConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
// _unsafe.putIntVolatile(kObj, kOffset, i);
int kLocal = _unsafe.getIntVolatile(kObj, kOffset);
int kNew = kLocal;
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
if(_unsafe.getIntVolatile(control, _Obase + kLocal * _Oscale) == 0 || kLocal == i)
_unsafe.putIntVolatile(kObj, kOffset, kNew = ((kLocal == 1)? N: kLocal - 1));
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = kNew;
for(int m = 0; m < N; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
if( --times != 0)
continue L0;
return ;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
.format(new Date()));
// start time
long start = System.currentTimeMillis();
// print the initial value of the accumulator
System.out.println( count + " initial\n");
// create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
knuthConcurMethod(j);
}
});
}
// start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
// the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
// print the accumulated value, expected == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
// print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
package com.psly.testatomic;
import java.text.SimpleDateFormat;
import java.util.Date;
import com.psly.locksupprot.LockSupport;
import sun.misc.Unsafe;
public class TestVolatileEisenbergMethod {
// Used for memory-ordering guarantees: putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
// N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 100;
private final static int TIMES = 5000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//knuth's method
private final static int[] control = new int[N+1];
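// Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
// countObj: the object that holds the count field (effectively its base address)
private final static Object countObj;
// countOffset: the offset of the count field within that object
private final static long countOffset;
// k is handled the same way as the count field above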
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatileEisenbergMethod.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatileEisenbergMethod.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatileEisenbergMethod.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatileEisenbergMethod.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
private static Object obj = new Object();
final static void EisenbergConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
int kLocal;
for(int j = (kLocal = _unsafe.getIntVolatile(kObj, kOffset)); j <= N; ++j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
for(int j = 1; j <= kLocal - 1; ++j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = 1; j <= N; ++j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
int kLocal;
if(_unsafe.getIntVolatile(control, _Obase + (kLocal = _unsafe.getIntVolatile(kObj, kOffset)) *_Oscale ) != 0
&& kLocal != i)
continue L0;
_unsafe.putIntVolatile(kObj, kOffset, i);
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
// System.out.println(Thread.currentThread().getName());
int kNew = i;
L2: for(;;){
for(int j = i + 1; j <= N; ++j){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
_unsafe.putIntVolatile(kObj, kOffset, j);
// LockSupport.unpark(handle[j]);
kNew = j;
break L2;
}
}
for(int j = 1; j <= i - 1; ++j){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
_unsafe.putIntVolatile(kObj, kOffset, j);
// LockSupport.unpark(handle[j]);
kNew = j;
break L2;
}
}
break;
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = kNew;
for(int m = 0; m < N; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == N)? 1 : j + 1;
}
if( --times != 0)
continue L0;
return ;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
.format(new Date()));
// start time
long start = System.currentTimeMillis();
// print the initial value of the accumulator
System.out.println( count + " initial\n");
// create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
EisenbergConcurMethod(j);
}
});
}
// start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
// the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
// print the accumulated value, expected == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
// print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
L1: for(;;){ // The two loops below decide whether the current thread is eligible to compete for the critical section
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
// LockSupport.park(obj);
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
// LockSupport.park(obj);
continue L1;
}
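} // The two loops above decide whether the current thread is eligible to compete for the critical section
}
// The code below ensures that at most one thread enters the critical section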
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
// The code above ensures that at most one thread enters the critical section
_unsafe.putIntVolatile(kObj, kOffset, i);
// critical section start
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i - 1);
// critical section end
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
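- First, two loops decide whether the current thread is eligible to compete for the lock. If it is, it breaks out of L1; otherwise it keeps looping.
- Next, a further loop probes the other threads' control values. If none of them equals 2, the loop ends and the lock is acquired; otherwise the thread jumps back to L0 and repeats the earlier checks. These semantics ensure that at most one thread enters the critical section. In the extreme case no thread gets the lock and they all jump back to L0.
- At the end of the critical section, control[i] is set to 0, replacing its value 2, so that other threads then have a chance to get the lock (by the competition rule, a thread that sees any other thread's value equal to 2 cannot acquire the lock).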
Appendix (helper classes):
package com.psly.testatomic;
import java.lang.reflect.Field;
import sun.misc.Unsafe;
public class UtilUnsafe {
private UtilUnsafe() { } // dummy private constructor
/** Fetch the Unsafe. Use With Caution. */
public static Unsafe getUnsafe() {
// Not on bootclasspath
if( UtilUnsafe.class.getClassLoader() == null )
return Unsafe.getUnsafe();
try {
final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
fld.setAccessible(true);
return (Unsafe) fld.get(UtilUnsafe.class);
} catch (Exception e) {
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
}
}
}
package com.psly.locksupprot;
import com.psly.testatomic.UtilUnsafe;
public class LockSupport {
private LockSupport() {} // Cannot be instantiated.
private static void setBlocker(Thread t, Object arg) {
// Even though volatile, hotspot doesn't need a write barrier here.
UNSAFE.putObject(t, parkBlockerOffset, arg);
}
/**
* Makes available the permit for the given thread, if it
* was not already available. If the thread was blocked on
* {@code park} then it will unblock. Otherwise, its next call
* to {@code park} is guaranteed not to block. This operation
* is not guaranteed to have any effect at all if the given
* thread has not been started.
*
* @param thread the thread to unpark, or {@code null}, in which case
* this operation has no effect
*/
public static void unpark(Thread thread) {
if (thread != null)
UNSAFE.unpark(thread);
}
/**
* Disables the current thread for thread scheduling purposes unless the
* permit is available.
*
* <p>If the permit is available then it is consumed and the call returns
* immediately; otherwise
* the current thread becomes disabled for thread scheduling
* purposes and lies dormant until one of three things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread upon return.
*
* @param blocker the synchronization object responsible for this
* thread parking
* @since 1.6
*/
public static void park(Object blocker) {
Thread t = Thread.currentThread();
setBlocker(t, blocker);
UNSAFE.park(false, 0L);
setBlocker(t, null);
}
/**
* Disables the current thread for thread scheduling purposes, for up to
* the specified waiting time, unless the permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of four
* things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The specified waiting time elapses; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread, or the elapsed time
* upon return.
*
* @param blocker the synchronization object responsible for this
* thread parking
* @param nanos the maximum number of nanoseconds to wait
* @since 1.6
*/
public static void parkNanos(Object blocker, long nanos) {
if (nanos > 0) {
Thread t = Thread.currentThread();
setBlocker(t, blocker);
UNSAFE.park(false, nanos);
setBlocker(t, null);
}
}
/**
* Disables the current thread for thread scheduling purposes, until
* the specified deadline, unless the permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of four
* things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts} the
* current thread; or
*
* <li>The specified deadline passes; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread, or the current time
* upon return.
*
* @param blocker the synchronization object responsible for this
* thread parking
* @param deadline the absolute time, in milliseconds from the Epoch,
* to wait until
* @since 1.6
*/
public static void parkUntil(Object blocker, long deadline) {
Thread t = Thread.currentThread();
setBlocker(t, blocker);
UNSAFE.park(true, deadline);
setBlocker(t, null);
}
/**
* Returns the blocker object supplied to the most recent
* invocation of a park method that has not yet unblocked, or null
* if not blocked. The value returned is just a momentary
* snapshot -- the thread may have since unblocked or blocked on a
* different blocker object.
*
* @param t the thread
* @return the blocker
* @throws NullPointerException if argument is null
* @since 1.6
*/
public static Object getBlocker(Thread t) {
if (t == null)
throw new NullPointerException();
return UNSAFE.getObjectVolatile(t, parkBlockerOffset);
}
/**
* Disables the current thread for thread scheduling purposes unless the
* permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of three
* things happens:
*
* <ul>
*
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread upon return.
*/
public static void park() {
UNSAFE.park(false, 0L);
}
/**
* Disables the current thread for thread scheduling purposes, for up to
* the specified waiting time, unless the permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of four
* things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The specified waiting time elapses; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread, or the elapsed time
* upon return.
*
* @param nanos the maximum number of nanoseconds to wait
*/
public static void parkNanos(long nanos) {
if (nanos > 0)
UNSAFE.park(false, nanos);
}
/**
* Disables the current thread for thread scheduling purposes, until
* the specified deadline, unless the permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of four
* things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The specified deadline passes; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread, or the current time
* upon return.
*
* @param deadline the absolute time, in milliseconds from the Epoch,
* to wait until
*/
public static void parkUntil(long deadline) {
UNSAFE.park(true, deadline);
}
/**
* Returns the pseudo-randomly initialized or updated secondary seed.
* Copied from ThreadLocalRandom due to package access restrictions.
*/
static final int nextSecondarySeed() {
int r;
Thread t = Thread.currentThread();
if ((r = UNSAFE.getInt(t, SECONDARY)) != 0) {
r ^= r << 13; // xorshift
r ^= r >>> 17;
r ^= r << 5;
}
else if ((r = java.util.concurrent.ThreadLocalRandom.current().nextInt()) == 0)
r = 1; // avoid zero
UNSAFE.putInt(t, SECONDARY, r);
return r;
}
// Hotspot implementation via intrinsics API
private static final sun.misc.Unsafe UNSAFE;
private static final long parkBlockerOffset;
private static final long SEED;
private static final long PROBE;
private static final long SECONDARY;
static {
try {
UNSAFE = UtilUnsafe.getUnsafe();
Class<?> tk = Thread.class;
parkBlockerOffset = UNSAFE.objectFieldOffset
(tk.getDeclaredField("parkBlocker"));
SEED = UNSAFE.objectFieldOffset
(tk.getDeclaredField("threadLocalRandomSeed"));
PROBE = UNSAFE.objectFieldOffset
(tk.getDeclaredField("threadLocalRandomProbe"));
SECONDARY = UNSAFE.objectFieldOffset
(tk.getDeclaredField("threadLocalRandomSecondarySeed"));
} catch (Exception ex) { throw new Error(ex); }
}
}
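The Javadoc above says that `park` may return for no reason and that callers should re-check the condition that made them park. A minimal sketch of that pattern follows. It assumes the JDK's own `java.util.concurrent.locks.LockSupport` rather than the copied class above, and `ParkDemo`, `ready`, `awaitReady` and `runDemo` are hypothetical names introduced only for illustration.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;

class ParkDemo {
    // Hypothetical flag standing in for "the condition which caused the thread to park".
    static final AtomicBoolean ready = new AtomicBoolean(false);

    // Parks until ready becomes true. The loop re-checks the flag after every
    // return from park(), because park() does not report why it returned.
    static void awaitReady() {
        while (!ready.get()) {
            LockSupport.park();
        }
    }

    // Starts a waiter, publishes the condition, then wakes the waiter.
    // Returns true if the waiter terminated within the timeout.
    static boolean runDemo() {
        ready.set(false);
        Thread waiter = new Thread(ParkDemo::awaitReady);
        waiter.start();
        try {
            Thread.sleep(100);          // give the waiter time to park
            ready.set(true);            // set the condition first ...
            LockSupport.unpark(waiter); // ... then make the permit available
            waiter.join(5000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return !waiter.isAlive();
    }

    public static void main(String[] args) {
        System.out.println("waiter finished: " + runDemo());
    }
}
```

Setting the flag before calling `unpark` matters: if the waiter wakes up first, it still sees `ready == true` on its next check and leaves the loop instead of parking again.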
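I recently found that Leslie Lamport, the 2013 Turing Award winner, proposed another algorithm for this mutual-exclusion problem in 1974. A screenshot of the paper follows:
The proof:
The distinguishing feature of this algorithm is that it has no central control.
Let's implement it in Java: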
package com.psly.testatomic;
import sun.misc.Unsafe;
public class TestVolatile {
    // Used for memory-ordering guarantees: putXXVolatile/getXXVolatile
    private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
    private static final int _Obase = _unsafe.arrayBaseOffset(long[].class);
    private static final int _Oscale = _unsafe.arrayIndexScale(long[].class);
    // N: number of threads; TIMES: how many times each thread enters the critical section.
    private final static int N = 2000;
    private final static int TIMES = 1000;
    // Index 0 is unused; threads are numbered 1..N.
    private final static long[] choosing = new long[N+1];
    private final static long[] number = new long[N+1];
    // Each thread does ++count inside the critical section, so in the end count == N * TIMES
    private static long count;
    // mainObj: the object that holds the static field count (in effect, its base address)
    private final static Object mainObj;
    // countOffset: the offset of the count field from that base
    private final static long countOffset;
    // Blocker object passed to LockSupport.park
    private static Object obj = new Object();
    // private static Queue<Thread> queues = new ConcurrentLinkedQueue();
    static {
        // Java already zero-initializes long arrays; the loop only makes the initial state explicit.
        for (int i = 1; i <= N; ++i) {
            choosing[i] = 0;
            number[i] = 0;
        }
        try {
            mainObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("count"));
            countOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("count"));
            // waitersOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("waiters"));
        } catch (Exception e) {
            throw new Error(e);
        }
    }
    // Despite its name, this method implements Lamport's algorithm described above.
    final static void dijkstrasConcurMethod(int pM) {
        int times = TIMES;
        int i = pM;
        L0: for (;;) {
            // Announce that thread i is choosing a number.
            _unsafe.putLongVolatile(choosing, _Obase + i * _Oscale, 1);
            // Find the largest number currently held and take that value + 1.
            long maxNum = _unsafe.getLongVolatile(number, _Obase + _Oscale), midNum;
            for (int j = 2; j <= N; ++j)
                if (maxNum < (midNum = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)))
                    maxNum = midNum;
            _unsafe.putLongVolatile(number, _Obase + i * _Oscale, 1 + maxNum);
            _unsafe.putLongVolatile(choosing, _Obase + i * _Oscale, 0);
            /* for(int j = 1; j <i; ++j)
                   LockSupport.unpark(handle[j]);
               for(int j = i+1; j <= N; ++j)
                   LockSupport.unpark(handle[j]);*/
            long jNumber, iNumber;
            for (int j = 1; j <= N; ++j) {
                // Spin until thread j has finished choosing its number.
                L1: for (;;) {
                    for (int k = 0; k < 100; ++k)
                        if (_unsafe.getLongVolatile(choosing, _Obase + j * _Oscale) == 0)
                            break L1;
                    // LockSupport.park(obj);
                }
                // Wait while thread j is competing and (number[j], j) < (number[i], i).
                // Spin up to 1000 times, then park until some thread wakes us.
                L2: for (;;) {
                    for (int k = 0; k < 1000; ++k)
                        if (!(_unsafe.getLongVolatile(number, _Obase + j * _Oscale) != 0
                                && ((jNumber = _unsafe.getLongVolatile(number, _Obase + j * _Oscale))
                                        < (iNumber = _unsafe.getLongVolatile(number, _Obase + i * _Oscale))
                                    || (jNumber == iNumber && j < i))))
                            break L2;
                    LockSupport.park(obj);
                }
            }
            // critical section
            // Critical section begins
            long val = _unsafe.getLongVolatile(mainObj, countOffset);
            _unsafe.putLongVolatile(mainObj, countOffset, val + 1);
            // Critical section ends
            // Reset the marker: thread i is no longer competing
            _unsafe.putLongVolatile(number, _Obase + i * _Oscale, 0);
            // Wake the thread that needs it: the waiter holding the smallest non-zero number
            // (if nobody is waiting, target stays the current thread)
            Thread target = handle[i];
            long numMax = Long.MAX_VALUE, arg;
            for (int j = 1; j < i; ++j)
                if ((arg = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)) != 0 && arg < numMax)
                    { target = handle[j]; numMax = arg; }
            for (int j = i + 1; j <= N; ++j)
                if ((arg = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)) != 0 && arg < numMax)
                    { target = handle[j]; numMax = arg; }
            LockSupport.unpark(target);
            /*for(int j = 1; j <= N; ++j)
                LockSupport.unpark(handle[j]);*/
            // Count down the remaining iterations
            if (--times != 0) {
                continue L0; // goto L0;
            }
            return;
        }
    }
    // Thread handles, indexed 1..N, so that an exiting thread can unpark a waiter
    private static Thread[] handle = new Thread[N+1];
    public static void main(String[] args) throws InterruptedException
    {
        // Start time
        long start = System.currentTimeMillis();
        // Print the initial value of the counter
        System.out.println( count + " initial\n");
        // Thread handle[] = new Thread[N+1];
        // Create the threads
        for (int i = 1; i <= N; ++i) {
            int j = i;
            handle[i] = new Thread(new Runnable() {
                @Override
                public void run() {
                    dijkstrasConcurMethod(j);
                }
            });
        }
        // Start the threads
        for (int i = 1; i <= N; ++i)
            handle[i].start();
        // The main thread waits for the worker threads to finish
        for (int i = 1; i <= N; ++i)
            handle[i].join();
        // Print the accumulated value; it should equal N * TIMES
        System.out.println(_unsafe.getLongVolatile(mainObj, countOffset));
        // Print the program's running time
        System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
    }
}
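For readers who want to try the idea without `sun.misc.Unsafe` or parking, the same algorithm can be sketched more compactly. This is only an illustration under stated assumptions, not the listing above: it uses `AtomicIntegerArray` and `AtomicLongArray` (whose `get`/`set` have volatile semantics) in place of `getLongVolatile`/`putLongVolatile`, waits with `Thread.yield()` instead of `LockSupport.park`, and numbers threads from 0. The names `BakeryLock`, `lock`, `unlock` and `run` are hypothetical.

```java
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLongArray;

class BakeryLock {
    private final int n;
    private final AtomicIntegerArray choosing; // 1 while thread i is picking a number
    private final AtomicLongArray number;      // 0 means "not competing"

    BakeryLock(int n) {
        this.n = n;
        this.choosing = new AtomicIntegerArray(n);
        this.number = new AtomicLongArray(n);
    }

    void lock(int i) {
        // Take a number one larger than every number currently visible.
        choosing.set(i, 1);
        long max = 0;
        for (int j = 0; j < n; ++j) max = Math.max(max, number.get(j));
        number.set(i, max + 1);
        choosing.set(i, 0);
        for (int j = 0; j < n; ++j) {
            if (j == i) continue;
            // Wait until thread j has finished choosing.
            while (choosing.get(j) != 0) Thread.yield();
            // Wait while thread j competes and (number[j], j) < (number[i], i).
            for (;;) {
                long nj = number.get(j);
                long ni = number.get(i);
                if (nj == 0 || nj > ni || (nj == ni && j > i)) break;
                Thread.yield();
            }
        }
    }

    void unlock(int i) {
        number.set(i, 0);
    }

    static long count; // plain field, protected only by the lock

    // Runs `threads` workers that each increment count `times` times.
    static long run(int threads, int times) {
        BakeryLock lock = new BakeryLock(threads);
        count = 0;
        Thread[] workers = new Thread[threads];
        for (int t = 0; t < threads; ++t) {
            final int id = t;
            workers[t] = new Thread(() -> {
                for (int k = 0; k < times; ++k) {
                    lock.lock(id);
                    ++count; // critical section
                    lock.unlock(id);
                }
            });
            workers[t].start();
        }
        for (Thread w : workers) {
            try { w.join(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(run(4, 10_000)); // expected: threads * times
    }
}
```

As in the full listing, no thread writes another thread's `choosing` or `number` slot, which is what "no central control" means here. The thread and iteration counts are kept small because every `lock` call scans all `n` slots.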