The ThreadLocal Series: TransmittableThreadLocal

Limitations of InheritableThreadLocal

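The previous article covered a caveat of ThreadLocal: it does not work when execution moves to another thread. It proposed one way around this and also analyzed the JDK's answer, InheritableThreadLocal. InheritableThreadLocal has a serious limitation of its own, though: it copies the parent thread's values only when a Thread is initialized. In real projects we almost never create threads with new; we let a thread pool create them, and in that case InheritableThreadLocal no longer helps. Consider this example: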

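import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.IntStream;

public class InheritableThreadLocalTest2 {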

    public static final InheritableThreadLocal<Integer> HOLDER = new InheritableThreadLocal<>();

    private static final ExecutorService FIXED_EXECUTOR = Executors.newFixedThreadPool(2);

    public static void main(String[] args) throws InterruptedException {
        FIXED_EXECUTOR.submit(() -> System.out.println("ready..."));
        Thread.sleep(20);
        HOLDER.set(1);
        print();
        Runnable task = InheritableThreadLocalTest2::print;
        IntStream.range(0, 5).forEach(i -> {
            new Thread(task, "simple task").start();
            FIXED_EXECUTOR.submit(task);
        });
        FIXED_EXECUTOR.shutdown();
    }

    private static void print() {
        System.out.println(currentName() + ":" + HOLDER.get());
    }

    private static String currentName() {
        return Thread.currentThread().getName();
    }

}

Output:

ready...
main:1
simple task:1
pool-1-thread-2:1
simple task:1
pool-1-thread-1:null
pool-1-thread-1:null
simple task:1
pool-1-thread-1:null
simple task:1
pool-1-thread-2:1
simple task:1

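Threads created with new can read the main thread's value, but only some of the pool threads can. pool-1-thread-1 was created by the warm-up submit before HOLDER.set(1), so it inherited nothing; pool-1-thread-2 was created afterwards and inherited 1. Behavior that depends on when a pool thread happened to be created is a serious problem in a real project.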
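The timing dependence can be isolated with JDK classes only. The following is a minimal sketch, not part of the original example; the class and method names (InheritAtCreationDemo, seenByExistingWorker, seenByNewWorker) are made up for illustration. It compares a pool worker that already existed before the value was set with one that is created afterwards.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class InheritAtCreationDemo {
    static final InheritableThreadLocal<Integer> HOLDER = new InheritableThreadLocal<>();

    /** What a pool worker that already existed before HOLDER.set(1) sees afterwards. */
    static Integer seenByExistingWorker() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            pool.submit(() -> { }).get();      // forces the worker thread to be created now
            HOLDER.set(1);                     // the worker's inherited values were copied before this
            Callable<Integer> read = HOLDER::get;
            return pool.submit(read).get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            HOLDER.remove();
            pool.shutdown();
        }
    }

    /** What a pool worker created after HOLDER.set(1) sees. */
    static Integer seenByNewWorker() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            HOLDER.set(1);
            Callable<Integer> read = HOLDER::get;
            return pool.submit(read).get();    // the worker is created here, by the submitting thread
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            HOLDER.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("existing worker: " + seenByExistingWorker()); // prints null
        System.out.println("new worker:      " + seenByNewWorker());      // prints 1
    }
}
```

The only difference between the two methods is whether the worker thread is created before or after the call to set, which is exactly the variable an application cannot control when it uses a shared pool.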

TransmittableThreadLocal

TransmittableThreadLocal is a utility in TTL, a Java library open-sourced by Alibaba. The official introduction reads:

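The JDK's InheritableThreadLocal class can pass values from a parent thread to a child thread. But when an execution component that pools and reuses threads, such as a thread pool, is in use, the threads are created by the pool and reused over and over. Passing ThreadLocal values along the parent-child relationship is then meaningless; what the application actually needs is to pass the ThreadLocal values at the time the task is submitted to the thread pool through to the time the task is executed.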

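The final clause states the main design idea: "what the application actually needs is to pass the ThreadLocal values at the time the task is submitted to the thread pool through to the time the task is executed." This is very close to the approach proposed in the previous article. To see TransmittableThreadLocal in action, replace InheritableThreadLocal with TransmittableThreadLocal in the example above: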

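import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.IntStream;

public class InheritableThreadLocalTest2 {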

    //public static final InheritableThreadLocal<Integer> HOLDER = new InheritableThreadLocal<>();
    public static final TransmittableThreadLocal<Integer> HOLDER = new TransmittableThreadLocal<>();

    private static final ExecutorService FIXED_EXECUTOR = Executors.newFixedThreadPool(2);

    public static void main(String[] args) throws InterruptedException {
        FIXED_EXECUTOR.submit(() -> System.out.println("ready..."));
        Thread.sleep(20);
        HOLDER.set(1);
        print();
        Runnable task = InheritableThreadLocalTest2::print;
        // Extra step: wrap the task in a decorated TtlRunnable
        TtlRunnable ttlRunnable = TtlRunnable.get(task);
        IntStream.range(0, 5).forEach(i -> {
            new Thread(ttlRunnable, "simple task").start();
            FIXED_EXECUTOR.submit(ttlRunnable);
        });
        FIXED_EXECUTOR.shutdown();
    }

    private static void print() {
        System.out.println(currentName() + ":" + HOLDER.get());
    }

    private static String currentName() {
        return Thread.currentThread().getName();
    }

}

Output:

ready...
main:1
pool-1-thread-2:1
simple task:1
pool-1-thread-1:1
pool-1-thread-2:1
simple task:1
simple task:1
pool-1-thread-1:1
simple task:1
pool-1-thread-2:1
simple task:1

This is exactly what we want: every child thread, including pool-1-thread-1, can read the parent thread's data.

How It Works

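At this point we can roughly guess how it works: store the parent thread's data, wrap the Runnable, and when the task runs, have the child thread install the parent's data into its own context.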
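Before reading the real source, the guess above can be sketched with JDK classes only. This is a simplified illustration under stated assumptions, not TTL's implementation: SubmitTimeCaptureDemo, CapturingRunnable and CONTEXT are hypothetical names, and only a single plain ThreadLocal is handled.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class SubmitTimeCaptureDemo {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    /** Hypothetical wrapper: reads CONTEXT when constructed, installs it when run. */
    static final class CapturingRunnable implements Runnable {
        private final String captured = CONTEXT.get(); // evaluated in the submitting thread
        private final Runnable delegate;

        CapturingRunnable(Runnable delegate) {
            this.delegate = delegate;
        }

        @Override
        public void run() {
            String backup = CONTEXT.get();   // whatever the worker thread had before
            CONTEXT.set(captured);           // install the submitter's value
            try {
                delegate.run();
            } finally {
                if (backup == null) CONTEXT.remove(); else CONTEXT.set(backup);
            }
        }
    }

    /** Submits one wrapped task to an already running worker and returns what it saw. */
    static String valueSeenInPool(String submitted) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        String[] seen = new String[1];
        try {
            pool.submit(() -> { }).get();    // the worker exists before the value is set
            CONTEXT.set(submitted);
            pool.submit(new CapturingRunnable(() -> seen[0] = CONTEXT.get())).get();
            return seen[0];
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(valueSeenInPool("request-42")); // prints request-42
    }
}
```

Because the value is read in the wrapper's constructor, it travels with the task object itself, so it no longer matters when the worker thread was created.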

First, how is the parent thread's data stored? Look at TransmittableThreadLocal#set:

    @Override
    public final void set(T value) {
        if (!disableIgnoreNullValueSemantics && null == value) {
            // may set null to remove value
            remove();
        } else {
            super.set(value);
            addThisToHolder();
        }
    }

TransmittableThreadLocal extends InheritableThreadLocal, so the value is first stored in the current (parent) thread's inheritableThreadLocals. A quick test:

import com.alibaba.ttl.TransmittableThreadLocal;

public class TransmittableThreadLocalTest1 {

    public static final TransmittableThreadLocal<Integer> TTL_HOLDER = new TransmittableThreadLocal<>();
    public static final InheritableThreadLocal<Integer> ITL_HOLDER = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        TTL_HOLDER.set(1);
        System.out.println(ITL_HOLDER.get());
    }
}

Output:

null

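A null here may look odd at first, but it makes sense. All values ultimately live in a ThreadLocalMap, and the ThreadLocal instance itself (this) is the key that determines which slot of the map's array is used. Here the two instances are different, so they map to different entries. This is by design: it is what keeps the values of different ThreadLocal instances on the same thread from getting mixed up. Stepping through in a debugger confirms this.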

At this point the main thread's inheritableThreadLocals is empty. Stepping further:

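That is, once the TransmittableThreadLocal value has been set in the main thread, the inherited InheritableThreadLocal logic has populated the current thread's inheritableThreadLocals. InheritableThreadLocal#get on ITL_HOLDER still returns null, because its this is a different instance and therefore a different slot.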
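The same effect can be reproduced without TTL. In this small sketch (DistinctKeysDemo, A and B are illustrative names), two InheritableThreadLocal instances share the current thread's map, yet a value set through one is invisible through the other.

```java
class DistinctKeysDemo {
    static final InheritableThreadLocal<Integer> A = new InheritableThreadLocal<>();
    static final InheritableThreadLocal<Integer> B = new InheritableThreadLocal<>();

    /** Sets a value through A only, then reads through both instances. */
    static String readBoth() {
        A.set(1);
        try {
            // Both lookups hit the same per-thread map, but with different keys (A vs. B).
            return A.get() + "," + B.get();
        } finally {
            A.remove();
            B.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(readBoth()); // prints 1,null
    }
}
```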

That was a short detour. Now look at addThisToHolder:

    private void addThisToHolder() {
        if (!holder.get().containsKey(this)) {
            holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
        }
    }

It simply registers the current TransmittableThreadLocal instance in holder, which is defined as follows:

    // Note about holder:
    // 1. holder self is a InheritableThreadLocal(a *ThreadLocal*).
    // 2. The type of value in holder is WeakHashMap<TransmittableThreadLocal<Object>, ?>.
    //    2.1 but the WeakHashMap is used as a *Set*:
    //        - the value of WeakHashMap is *always null,
    //        - and never be used.
    //    2.2 WeakHashMap support *null* value.
    private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
                }
                // Overrides childValue, which is called from Thread#init: when a child thread is initialized, it receives a copy of the parent thread's holder map
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };

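holder is itself an InheritableThreadLocal, so a child thread inherits the contents of the parent's holder. In other words, the TransmittableThreadLocal class acts as a central registry: through holder it tracks every TransmittableThreadLocal instance that has a value set in the current thread.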
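The effect of overriding childValue with a copy can be observed with a small JDK-only sketch. HolderCopyDemo and its HOLDER field are hypothetical stand-ins that imitate the shape of TTL's holder; they are not the library's code.

```java
import java.util.WeakHashMap;

class HolderCopyDemo {
    // Hypothetical stand-in for TTL's holder: a per-thread WeakHashMap used as a set.
    static final InheritableThreadLocal<WeakHashMap<Object, Object>> HOLDER =
            new InheritableThreadLocal<WeakHashMap<Object, Object>>() {
                @Override
                protected WeakHashMap<Object, Object> initialValue() {
                    return new WeakHashMap<>();
                }

                @Override
                protected WeakHashMap<Object, Object> childValue(WeakHashMap<Object, Object> parent) {
                    return new WeakHashMap<>(parent); // copy, so the two threads can diverge
                }
            };

    /** Returns { child saw the parent's key, parent saw the child's key }. */
    static boolean[] inheritAndDiverge() {
        Object parentKey = new Object();
        Object childKey = new Object();
        boolean[] result = new boolean[2];
        HOLDER.get().put(parentKey, null);          // registered before the child is created
        Thread child = new Thread(() -> {
            result[0] = HOLDER.get().containsKey(parentKey);
            HOLDER.get().put(childKey, null);       // only changes the child's copy
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        result[1] = HOLDER.get().containsKey(childKey);
        HOLDER.remove();
        return result;
    }

    public static void main(String[] args) {
        boolean[] r = inheritAndDiverge();
        System.out.println(r[0] + " " + r[1]); // prints true false
    }
}
```

The child starts with everything the parent had registered at thread creation, but later registrations on either side stay local to that thread because each thread owns its own map.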

Next, TtlRunnable#get. Presumably it wraps the Runnable and stores the parent thread's data:

    @Nullable
    public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
        if (null == runnable) return null;

        if (runnable instanceof TtlEnhanced) {
            // avoid redundant decoration, and ensure idempotency
            if (idempotent) return (TtlRunnable) runnable;
            else throw new IllegalStateException("Already TtlRunnable!");
        }
        return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
    }

    private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        this.capturedRef = new AtomicReference<Object>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

TtlRunnable is indeed a wrapper around Runnable, and the captured data is stored in capturedRef. Now the capture method:

/**
         * Capture all {@link TransmittableThreadLocal} and registered {@link ThreadLocal} values in the current thread.
         *
         * @return the captured {@link TransmittableThreadLocal} values
         * @since 2.3.0
         */
        @NonNull
        public static Object capture() {
            return new Snapshot(captureTtlValues(), captureThreadLocalValues());
        }

Then captureTtlValues:

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
            WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
            for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                // copyValue is an extension point that decides how the value is stored in the snapshot
                ttl2Value.put(threadLocal, threadLocal.copyValue());
            }
            return ttl2Value;
        }

Put simply, this copies the value of every TransmittableThreadLocal registered in holder.

Now TtlRunnable#run:

    @Override
    public void run() {
        // Fetch the data captured from the parent thread
        Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        // Install the captured data; returns what the current thread had before
        Object backup = replay(captured);
        try {
            runnable.run();
        } finally {
            // Restore the current thread's original data
            restore(backup);
        }
    }
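The guard at the top of run deserves a closer look: when releaseTtlValueReferenceAfterRun is true, the compareAndSet clears the captured reference on the first run, so a second run of the same wrapper fails. The sketch below mirrors only that condition with JDK classes; OneShotSnapshotDemo and OneShotTask are hypothetical names, and the replay, task execution and restore steps are left out.

```java
import java.util.concurrent.atomic.AtomicReference;

class OneShotSnapshotDemo {
    /** Hypothetical wrapper that keeps only the release logic of the guard. */
    static final class OneShotTask implements Runnable {
        private final AtomicReference<Object> capturedRef;
        private final boolean releaseAfterRun;

        OneShotTask(Object snapshot, boolean releaseAfterRun) {
            this.capturedRef = new AtomicReference<>(snapshot);
            this.releaseAfterRun = releaseAfterRun;
        }

        @Override
        public void run() {
            Object captured = capturedRef.get();
            // Same shape as the condition in the listing above: fail if the snapshot is gone,
            // or if release is requested and another run cleared it first.
            if (captured == null || releaseAfterRun && !capturedRef.compareAndSet(captured, null)) {
                throw new IllegalStateException("snapshot already released");
            }
            // A real wrapper would replay the snapshot, run the task, then restore the backup.
        }
    }

    /** Runs the same wrapper twice and reports whether the second run was rejected. */
    static boolean secondRunFails(boolean releaseAfterRun) {
        OneShotTask task = new OneShotTask("snapshot", releaseAfterRun);
        task.run();
        try {
            task.run();
            return false;
        } catch (IllegalStateException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(secondRunFails(true));  // prints true
        System.out.println(secondRunFails(false)); // prints false
    }
}
```

With release disabled the snapshot stays referenced, which is why the earlier examples can submit the same ttlRunnable several times.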

The method is straightforward, and the overall flow matches the earlier guess. First, replay:

public static Object replay(@NonNull Object captured) {
            final Snapshot capturedSnapshot = (Snapshot) captured;
            // Replay the parent thread's TransmittableThreadLocal data (wrapped in a Snapshot; ttl2Value is the part to focus on)
            return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
        }

@NonNull
        private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
            WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();

            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();

                // backup
                backup.put(threadLocal, threadLocal.get());
								
                // clear the TTL values that is not in captured
                // avoid the extra TTL values after replay when run task
                if (!captured.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // set TTL values to captured
            setTtlValuesTo(captured);

            // call beforeExecute callback
            doExecuteCallback(true);

            return backup;
        }

This code first backs up the current thread's existing values and then sets the captured values on the current thread:

private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
            for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
                TransmittableThreadLocal<Object> threadLocal = entry.getKey();
                // Set the value on the current thread
                threadLocal.set(entry.getValue());
            }
        }

Now restore:

public static void restore(@NonNull Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}

        private static void restoreTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> backup) {
            // call afterExecute callback
            doExecuteCallback(false);

            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();

                // clear the TTL values that is not in backup
                // avoid the extra TTL values after restore
                if (!backup.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // restore TTL values
            setTtlValuesTo(backup);
        }

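This resets the current thread's TransmittableThreadLocal data to what it was before the task ran. It matters because pool threads are reused: without it, context changes made by one task would leak into whatever runs next on the same thread. A simple example: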

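import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;

import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class Test4 {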

    public static final TransmittableThreadLocal<String> TTL_HOLDER = new TransmittableThreadLocal<>();

    private static final ExecutorService FIXED_EXECUTOR = Executors.newSingleThreadExecutor();

    public static void main(String[] args) {
        print();
        Runnable task = ()->{
            print();
            TTL_HOLDER.set(UUID.randomUUID().toString());
            print();
        };
        // Extra step: wrap the task in a decorated TtlRunnable
        TtlRunnable ttlRunnable = TtlRunnable.get(task);
        new Thread(()-> FIXED_EXECUTOR.submit(ttlRunnable),"T1").start();
        sleep20();
        new Thread(()-> FIXED_EXECUTOR.submit(ttlRunnable),"T2").start();
        sleep20();
        print();
    }

    private static void sleep20() {
        try {
            Thread.sleep(2000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    private static void print() {
        System.out.println(Thread.currentThread().getName() + ":" + TTL_HOLDER.get());
    }
}

Output:

main:null
pool-1-thread-1:null
pool-1-thread-1:b07df1f9-f085-4204-aa0a-1f16b53f6c65
pool-1-thread-1:null
pool-1-thread-1:580ee3aa-0fe9-4f9d-9214-3b5988c4c477
main:null
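For contrast, here is what happens on a reused worker when nothing restores the previous state. This is a JDK-only sketch under simplifying assumptions: RestoreDemo, CONTEXT and withRestore are hypothetical names, and a plain ThreadLocal stands in for TransmittableThreadLocal.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class RestoreDemo {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    /** Hypothetical helper: puts CONTEXT back to its previous state after the task. */
    static Runnable withRestore(Runnable task) {
        return () -> {
            String backup = CONTEXT.get();
            try {
                task.run();
            } finally {
                if (backup == null) CONTEXT.remove(); else CONTEXT.set(backup);
            }
        };
    }

    /** Runs a task that sets CONTEXT, then returns what the next task on the same worker sees. */
    static String seenByNextTask(boolean restore) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Runnable dirty = () -> CONTEXT.set("left over");
            pool.submit(restore ? withRestore(dirty) : dirty).get();
            Callable<String> read = CONTEXT::get;
            return pool.submit(read).get();   // same worker thread, next task
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(seenByNextTask(false)); // prints left over
        System.out.println(seenByNextTask(true));  // prints null
    }
}
```

The second line corresponds to the behavior in the output above, where each run on pool-1-thread-1 starts from null even though the previous run set a UUID.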

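In short, the key is that the data is captured ahead of time, when the task is wrapped, from the TransmittableThreadLocal instances registered in holder, and is replayed when the task runs. Transmission therefore no longer depends on when the pool thread was created.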

References

  • https://github.com/alibaba/transmittable-thread-local
  • https://mp.weixin.qq.com/s/Y57WCfhAZylXvraD1Eypug
