Java編程拾遺『ThreadLocal』

本篇文章,我們來介紹一下,Java多線程編程中一個比較常用的工具,線程本地變量——ThreadLocal。ThreadLocal簡單的來講,就是每個線程都有同一個成員變量的獨有拷貝。由於每個線程都有成員變量獨立的拷貝,所以不存在多線程訪問同一成員變量的問題,也就解決了線程安全問題。之前我們介紹的鎖,是解決線程安全問題的一個途徑,那麼本篇文章介紹的線程本地變量是另一種解決線程安全問題的思想,就是通過額外每個線程都分配成員變量的拷貝,來換取加鎖帶來的併發效率問題,可以講是一種空間換時間的思想。本篇文章我們就來介紹一下ThreadLocal的用法及實現原理。

1. 基本概念

ThreadLocal是一個泛型類,接受一個類型參數T,它只有一個空的構造方法,有兩個主要的public方法:

set就是設置值,get就是獲取值,如果沒有值,返回null。看上去,ThreadLocal就是一個單一對象的容器,比如:

public static void main(String[] args) {
    ThreadLocal<Integer> local = new ThreadLocal<>();
    local.set(100);
    System.out.println(local.get());
}

輸出100,沒什麼好講的。ThreadLocal方便的地方主要在多線程的場景,如下:

public class ThreadLocalTest {
    private static ThreadLocal<Integer> local = new ThreadLocal<>();

    public static void main(String[] args) throws InterruptedException {
        Thread child = new Thread(() -> {
            System.out.println("child thread initial: " + local.get());
            local.set(200);
            System.out.println("child thread final: " + local.get());
        });
        
        local.set(100);
        child.start();
        child.join();
        System.out.println("main thread final: " + local.get());
    }
}

運行結果:

child thread initial: null
child thread final: 200
main thread final: 100

說明,main線程對local變量的設置對child線程不起作用,child線程對local變量的改變也不會影響main線程,它們訪問的雖然是同一個變量local,但每個線程都有自己的獨立的值,這就是線程本地變量的含義

ThreadLocal中除了上述方法,還有一個比較重要的protected方法initialValue,如下:

protected T initialValue()
public void remove()

initialValue用於提供初始值,它是一個protected方法,可以通過匿名內部類的方式提供,當調用get方法時,如果之前沒有設置過,會調用initialValue方法獲取初始值,默認實現是返回null。remove刪掉當前線程對應的值,如果刪掉後,再次調用get,會再調用initialValue獲取初始值。

public class ThreadLocalTest1 {
    private static ThreadLocal<Integer> local = new ThreadLocal<Integer>(){

        @Override
        protected Integer initialValue() {
            return 100;
        }
    };

    public static void main(String[] args) {
        System.out.println(local.get());
        local.set(200);
        local.remove();
        System.out.println(local.get());
    }
}

運行結果:

100
100

說明當調用get方法時,如果之前沒有設置過,會調用initialValue方法獲取初始值,默認實現是返回null。remove刪掉當前線程對應的值,如果刪掉後,再次調用get,會再調用initialValue獲取初始值。

Java8之後,ThreadLocal新增了一個靜態方法withInitial,可以用來指定ThreadLocal的初始值,達到跟initialValue相同的效果,聲明如下:

public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
    return new SuppliedThreadLocal<>(supplier);
}

withInitial參數supplier就是用來指定初始值的,所以上述ThreadLocal聲明可以替換爲:

private static ThreadLocal<Integer> local = ThreadLocal.withInitial(() -> 100);

2. ThreadLocal應用場景

2.1 實現線程安全

ThreadLocal是實現線程安全的一種方案,比如對於DateFormat/SimpleDateFormat,這兩個類是非線程安全類,Java8及以後提供了線程安全的解決方案:DateTimeFormat。而如果要以線程安全的方式使用DateFormat/SimpleDateFormat,可以又以下集中選擇:

  • 使用鎖
  • 不共享,每次使用DateFormat/SimpleDateFormat都新建一個對象
  • 使用ThreadLocal

下面我們就來介紹以下爲什麼DateFormat/SimpleDateFormat非線程安全,以及上述三種解決方案。首先來看個例子:

public class ThreadLocalTest2 {
    private static SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");

    private static ExecutorService pool = Executors.newFixedThreadPool(200);

    private static CountDownLatch countDownLatch = new CountDownLatch(100);
    
    public static void main(String[] args) throws InterruptedException {
        Set<String> dates = Collections.synchronizedSet(new HashSet<String>());
        for (int i = 0; i < 100; i++) {
            //獲取當前時間
            Calendar calendar = Calendar.getInstance();
            int finalI = i;
            pool.execute(() -> {
                //時間增加
                calendar.add(Calendar.DATE, finalI);
                //通過simpleDateFormat把時間轉換成字符串
                String dateString = simpleDateFormat.format(calendar.getTime());
                //把字符串放入Set中
                dates.add(dateString);
                //countDown
                countDownLatch.countDown();
            });
        }
        //阻塞,直到countDown數量爲0
        countDownLatch.await();
        //輸出去重後的時間個數
        System.out.println(dates.size());

    }
}

上述示例,啓動一個線程池,for循環中循環提交100個任務,每個任務的職責是將當前日期加上每次循環的天數,並通過simpleDateFormat格式化,將格式化的結果添加到Set中。那麼理論上循環100次,肯定能得到100個不同的結果,set中最終會有100個元素。但是上述代碼的執行結果一般不是100,說明SimpleDateFormat類是存在線程安全問題的。

我們來看一下SimpleDateFormat類的format方法實現,就能指導爲什麼SimpleDateFormat非線程安全了:

SimpleDateFormat中的format方法在執行過程中,會使用一個成員變量calendar來保存時間。我們在聲明SimpleDateFormat的時候,使用的是static定義的。那麼這個SimpleDateFormat就是一個共享變量,所以,SimpleDateFormat中的calendar也就可以被多個線程訪問到。

假設線程1剛剛執行完calendar.setTime把時間設置成2019-10-10,還沒等執行完,線程2又執行了calendar.setTime把時間改成了2019-10-11。這時候線程1繼續往下執行,拿到的calendar.getTime得到的時間就是線程2改過之後的。

除了format方法以外,SimpleDateFormat的parse方法也有同樣的問題,DateFormat也一樣。所以DateFormat/SimpleDateFormat是非線程安全類,不要把SimpleDateFormat作爲一個共享變量使用。如果要線程安全地使用DateFormat/SimpleDateFormat,可以通過如下方式:

2.1.1 使用鎖

如果不存在多個線程同時訪問SimpleDateFormat對象,那麼就不存在SimpleDateFormat內部的Calendar相互影響的問題。基於這種考慮,最簡單的實現就是使用同步鎖,保證同一時刻只有一個線程可以訪問SimpleDateFormat對象,如下:

for (int i = 0; i < 100; i++) {
    //獲取當前時間
    Calendar calendar = Calendar.getInstance();
    int finalI = i;
    pool.execute(() -> {
        //加鎖
        synchronized (simpleDateFormat) {
            //時間增加
            calendar.add(Calendar.DATE, finalI);
            //通過simpleDateFormat把時間轉換成字符串
            String dateString = simpleDateFormat.format(calendar.getTime());
            //把字符串放入Set中
            dates.add(dateString);
            //countDown
            countDownLatch.countDown();
        }
    });
}

2.1.2 使用局部變量代替成員變量

如果SimpleDateFormat對象爲局部變量,那麼也不會存在多個線程同時訪問SimpleDateFormat對象的情況,如下:

for (int i = 0; i < 100; i++) {
    //獲取當前時間
    Calendar calendar = Calendar.getInstance();
    int finalI = i;
    pool.execute(() -> {
        // SimpleDateFormat聲明成局部變量
        SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");

        calendar.add(Calendar.DATE, finalI);
        String dateString = simpleDateFormat.format(calendar.getTime());
        dates.add(dateString);
        countDownLatch.countDown();
    });
}

2.1.3 使用ThreadLocal

ThreadLocal可以確保每個線程都可以得到單獨的一個SimpleDateFormat對象,那麼自然也就不存在競爭問題了。如下:

private static ThreadLocal<SimpleDateFormat> simpleDateFormatThreadLocal = ThreadLocal.withInitial(() ->
            new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));

2.2 保存上下文信息

ThreadLocal的另一個典型用途是保存上下文信息,比如在一個Web服務器中,一個線程執行用戶的請求,在執行過程中,很多代碼都會訪問一些共同的信息,比如請求信息、用戶身份信息、數據庫連接、當前事務等,它們是線程執行過程中的全局信息,如果作爲參數在不同代碼間傳遞,代碼會很囉嗦,這時,使用ThreadLocal就很方便,所以它被用於各種框架如Spring中,如下:

public class RequestContext {
    public static class Request { //...
    };

    private static ThreadLocal<String> localUserId = new ThreadLocal<>();
    private static ThreadLocal<Request> localRequest = new ThreadLocal<>();

    public static String getCurrentUserId() {
        return localUserId.get();
    }

    public static void setCurrentUserId(String userId) {
        localUserId.set(userId);
    }

    public static Request getCurrentRequest() {
        return localRequest.get();
    }

    public static void setCurrentRequest(Request request) {
        localRequest.set(request);
    }

    public void clear() {
        localUserId.clear();
        localRequest.clear();
    }
}

使用上,可以通過切面,在切面中調用set方法如setCurrentRequest/setCurrentUserId進行設置,然後就可以在代碼的任意其他地方調用get相關方法進行獲取了。

但同時要注意的是,在RequestContext使用結束,要注意調用clear清理ThreadLocal,否則有可能會影響其它線程的執行(因爲一般web容器都會使用線程池,線程會複用,如果某個線程使用結束不清理的話,ThreadLocal中的變量會影響後續的使用)。這點會在下面單獨介紹。

3. ThreadLocal的底層實現

接下來我們來看一下ThreadLocal的底層實現原理,爲什麼對同一個ThreadLocal對象的get/set,每個線程都能有自己獨立的值。先從set方法看起:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

首先獲取當前線程,並調用getMap方法獲取ThreadLocalMap對象。那麼來看一下getMap:

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

返回線程的實例對象的threadLocals成員,它的初始值爲null。在null時,ThreadLocal類的set方法調用createMap初始化:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

可以看出,每個線程都有一個Map,類型爲ThreadLocalMap,調用set實際上是在線程自己的Map裏設置了一個條目,鍵爲當前的ThreadLocal對象,值爲value。所以能做到不同線程之間不相互影響。

get方法的代碼爲:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

通過線程訪問到Map,以ThreadLocal對象爲鍵從Map中獲取到條目,取其value,如果Map中沒有,調用setInitialValue,如下:

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

initialValue()就是之前提到的提供初始值的方法,默認實現就是返回null。Java8提供的withInitial方法返回爲SuppliedThreadLocal對象,該類是ThreadLocal的子類,並在該類中重寫了initialValue方法,如下:

public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
    return new SuppliedThreadLocal<>(supplier);
}
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

    private final Supplier<? extends T> supplier;

    SuppliedThreadLocal(Supplier<? extends T> supplier) {
        this.supplier = Objects.requireNonNull(supplier);
    }

    @Override
    protected T initialValue() {
        return supplier.get();
    }
}

所以withInitial方法的函數式參數,就是ThreadLocal的初始值。

最後來看一下remove方法:

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

就是將該ThreadLocal對象在線程map中條目刪除。

最後我們來總結一下,爲什麼ThreadLocal可以實現每個線程都有自己的獨立拷貝。因爲每個線程都有一個Map,對於每個ThreadLocal對象,調用其get/set方法,實際上就是以ThreadLocal對象爲鍵讀寫當前線程的Map

4. ThreadLocal與線程池

上面我們講到,ThreadLocal對象可以實現各個子線程隔離,原理是每個線程都有一個用於存儲ThreadLocal對象及ThreadLocal對象內部value的map。之前講的線程池的概念,我們知道線程池中工作線程是複用的,也就是講同一個線程會處理多個任務。但在使用者看來,多個任務應該就是多個線程,所以多個每個線程的ThreadLocal應該有獨立的拷貝。但線程池是通過工作線程來執行任務的,如果存在線程複用的情況,那麼就不能保證每個提交的任務都有一個ThreadLocal值的獨立拷貝了,任務之間就會相互影響,來看個例子:

public class ThreadLocalTest3 {
    private static ThreadLocal<AtomicInteger> sequencer = ThreadLocal.withInitial(() -> new AtomicInteger(0));

    static class Task implements Runnable {

        @Override
        public void run() {
            AtomicInteger s = sequencer.get();
            int initial = s.getAndIncrement();
            // 期望初始爲0
            System.out.println(initial);
        }
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        executor.execute(new Task());
        executor.execute(new Task());
        executor.execute(new Task());
        executor.shutdown();
    }
}

對於異步任務Task而言,它期望的初始值應該總是0,但運行程序,結果卻爲:

0
0
1

可以發現,第三次執行異步任務,結果就不對了,爲什麼呢?因爲線程池corePoolSize和maximumPoolSize都是2,當前兩個任務提交時都是創建工作線程執行任務的,但是第三個任務提交時,因爲已經到達corePoolSize,所以會進入等待隊列,等工作線程空閒時執行,這時第三個任務就複用了工作線程。同時線程池中的線程在執行完一個任務,執行下一個任務時,其中的ThreadLocal對象並不會被清空,修改後的值帶到了下一個異步任務。那怎麼辦呢?有幾種思路:

  1. 第一次使用ThreadLocal對象時,總是先調用set設置初始值,或者如果ThreaLocal重寫了initialValue方法,先調用remove
  2. 使用完ThreadLocal對象後,總是調用其remove方法
  3. 使用自定義的線程池

我們分別來看下,對於第一種,在Task的run方法開始處,添加set或remove代碼,如下所示:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

對於第二種,將Task的run方法包裹在try/finally中,並在finally語句中調用remove,如下所示:

static class Task implements Runnable {

    @Override
    public void run() {
        sequencer.set(new AtomicInteger(0));
        //或者 sequencer.remove();
        
        AtomicInteger s = sequencer.get();
        //...
    }
}

以上兩種方法都比較麻煩,需要更改所有異步任務的代碼,另一種方法是擴展線程池ThreadPoolExecutor,它有一個可以擴展的方法beforeExecute,在之前介紹線程池時也介紹過,我們可以通過重寫該方法實現在任務執行前的一些前置操作,我們創建一個自定義的線程池MyThreadPool,如下:

static class MyThreadPool extends ThreadPoolExecutor {
    public MyThreadPool(int corePoolSize, int maximumPoolSize,
            long keepAliveTime, TimeUnit unit,
            BlockingQueue<Runnable> workQueue) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    }

    @Override
    protected void beforeExecute(Thread t, Runnable r) {
        try {
            //使用反射清空所有ThreadLocal
            Field f = t.getClass().getDeclaredField("threadLocals");
            f.setAccessible(true);
            f.set(t, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
        super.beforeExecute(t, r);
    }
}

這裏,使用反射,找到線程中存儲ThreadLocal對象的Map變量threadLocals,重置爲null。使用MyThreadPool的示例代碼如下:

public static void main(String[] args) {
    ExecutorService executor = new MyThreadPool(2, 2, 0,
            TimeUnit.MINUTES, new LinkedBlockingQueue<Runnable>());
    executor.execute(new Task());
    executor.execute(new Task());
    executor.execute(new Task());
    executor.shutdown();
}

參考鏈接:

1. Java API

2. 《Java編程的邏輯》

發佈了117 篇原創文章 · 獲贊 33 · 訪問量 8萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章