ThreadLocal Series: TransmittableThreadLocal


Limitations of InheritableThreadLocal

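The previous article covered a pitfall of ThreadLocal: it does not work when execution switches to another thread. That article proposed one workaround and also analyzed the JDK's own solution, InheritableThreadLocal. InheritableThreadLocal is quite limited as well, though: it copies the parent thread's values only when a Thread is initialized. In real projects we almost never create threads with new; we get them from a thread pool, and in that case InheritableThreadLocal stops working. Consider this example: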

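import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.IntStream;

public class InheritableThreadLocalTest2 {

    public static final InheritableThreadLocal<Integer> HOLDER = new InheritableThreadLocal<>();

    private static final ExecutorService FIXED_EXECUTOR = Executors.newFixedThreadPool(2);

    public static void main(String[] args) throws InterruptedException {
        // warm-up task: makes the pool create pool-1-thread-1 BEFORE HOLDER is set
        FIXED_EXECUTOR.submit(() -> System.out.println("ready..."));
        Thread.sleep(20);
        // threads constructed after this line (including pool-1-thread-2) inherit 1
        HOLDER.set(1);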
        print();
        Runnable task = InheritableThreadLocalTest2::print;
        IntStream.range(0, 5).forEach(i -> {
            new Thread(task, "simple task").start();
            FIXED_EXECUTOR.submit(task);
        });
        FIXED_EXECUTOR.shutdown();
    }

    private static void print() {
        System.out.println(currentName() + ":" + HOLDER.get());
    }

    private static String currentName() {
        return Thread.currentThread().getName();
    }

}

Output:

ready...
main:1
simple task:1
pool-1-thread-2:1
simple task:1
pool-1-thread-1:null
pool-1-thread-1:null
simple task:1
pool-1-thread-1:null
simple task:1
pool-1-thread-2:1
simple task:1

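The threads created with new always see the main thread's value, but in the pool some threads see it and some do not, which would cause real trouble in a production project. The reason is construction time: pool-1-thread-1 was created by the "ready..." warm-up task before HOLDER.set(1) was called, so it had nothing to inherit, whereas pool-1-thread-2 was created after the call and copied the value 1. Once created, a pool thread keeps whatever it inherited, no matter which task it runs later.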
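The deciding moment is when the Thread object is constructed, not when it starts running. The JDK-only sketch below (the class name InheritAtConstructionDemo is made up for illustration) checks this directly: a thread constructed before HOLDER.set(1) sees null even though it is started afterwards.

```java
public class InheritAtConstructionDemo {

    static final InheritableThreadLocal<Integer> HOLDER = new InheritableThreadLocal<>();

    // Returns the HOLDER value a new thread observes, depending on whether the
    // Thread object is constructed before or after HOLDER.set(1) in the caller.
    static Integer observe(boolean constructBeforeSet) {
        Integer[] seen = new Integer[1];
        Runnable read = () -> seen[0] = HOLDER.get();
        HOLDER.remove();
        Thread t = constructBeforeSet ? new Thread(read) : null; // values are copied at construction
        HOLDER.set(1);
        if (t == null) {
            t = new Thread(read); // constructed after set: the child copies 1
        }
        t.start(); // starting after set does not trigger another copy
        try {
            t.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            HOLDER.remove();
        }
        return seen[0];
    }

    public static void main(String[] args) {
        System.out.println("constructed before set: " + observe(true));  // null
        System.out.println("constructed after set:  " + observe(false)); // 1
    }
}
```

A pool worker behaves like the first case whenever it was constructed before the value was set, which is exactly what happened to pool-1-thread-1.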

TransmittableThreadLocal

TransmittableThreadLocal is a utility from TTL, a Java library open-sourced by Alibaba. Here is how the official documentation describes it:

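The JDK's InheritableThreadLocal class passes values from a parent thread to a child thread. However, with execution components that pool and reuse threads, such as thread pools, the threads are created by the pool and used over and over; passing ThreadLocal values along the parent-child thread relationship is then meaningless. What the application actually needs is to pass the ThreadLocal values from the moment a task is submitted to the thread pool to the moment the task is executed.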

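The last sentence, bolded in the official text, states the core design idea: what the application actually needs is to pass the ThreadLocal values from the moment a task is submitted to the thread pool to the moment the task is executed. This is very close to the workaround proposed in the previous article. Let's first see TransmittableThreadLocal in action by replacing the InheritableThreadLocal in the example above with a TransmittableThreadLocal and wrapping the task with TtlRunnable: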

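import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.IntStream;

public class InheritableThreadLocalTest2 {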

    //public static final InheritableThreadLocal<Integer> HOLDER = new InheritableThreadLocal<>();
    public static final TransmittableThreadLocal<Integer> HOLDER = new TransmittableThreadLocal<>();

    private static final ExecutorService FIXED_EXECUTOR = Executors.newFixedThreadPool(2);

    public static void main(String[] args) throws InterruptedException {
        FIXED_EXECUTOR.submit(() -> System.out.println("ready..."));
        Thread.sleep(20);
        HOLDER.set(1);
        print();
        Runnable task = InheritableThreadLocalTest2::print;
        // extra step: wrap the task into a decorated TtlRunnable (this captures HOLDER's current value)
        TtlRunnable ttlRunnable = TtlRunnable.get(task);
        IntStream.range(0, 5).forEach(i -> {
            new Thread(ttlRunnable, "simple task").start();
            FIXED_EXECUTOR.submit(ttlRunnable);
        });
        FIXED_EXECUTOR.shutdown();
    }

    private static void print() {
        System.out.println(currentName() + ":" + HOLDER.get());
    }

    private static String currentName() {
        return Thread.currentThread().getName();
    }

}

Output:

ready...
main:1
pool-1-thread-2:1
simple task:1
pool-1-thread-1:1
pool-1-thread-2:1
simple task:1
simple task:1
pool-1-thread-1:1
simple task:1
pool-1-thread-2:1
simple task:1

This is exactly what we want: every thread, pooled or not, now sees the value set in the main thread.

How it works

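At this point we can roughly guess the implementation: capture the parent thread's data, wrap the Runnable, and when the task runs, set the captured data into the executing thread's context.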
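A minimal sketch of that guess, using only a plain ThreadLocal and the JDK (CapturingRunnable and the value "request-42" are hypothetical, not part of TTL): the wrapper reads the value when it is constructed in the submitting thread, sets it before delegating in run(), and restores the worker's own state afterwards.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class CapturingRunnableDemo {

    // A plain ThreadLocal: on its own it never reaches a pool thread.
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    // Hypothetical wrapper: capture on construction, replay in run(), then restore.
    static final class CapturingRunnable implements Runnable {
        private final String captured = CONTEXT.get(); // 1. read in the submitting thread
        private final Runnable delegate;

        CapturingRunnable(Runnable delegate) {
            this.delegate = delegate;
        }

        @Override
        public void run() {
            String backup = CONTEXT.get();  // 2. remember the worker's own value
            CONTEXT.set(captured);          // 3. replay the captured value
            try {
                delegate.run();
            } finally {
                if (backup == null) {       // 4. restore the worker's original state
                    CONTEXT.remove();
                } else {
                    CONTEXT.set(backup);
                }
            }
        }
    }

    // Returns what a task running on a pre-created pool thread sees in CONTEXT.
    static String runInPool(boolean wrap) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            pool.submit(() -> { }).get();   // the worker exists before CONTEXT is set
            CONTEXT.set("request-42");
            String[] seen = new String[1];
            Runnable task = () -> seen[0] = CONTEXT.get();
            pool.submit(wrap ? new CapturingRunnable(task) : task).get();
            return seen[0];
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("plain task:   " + runInPool(false)); // null
        System.out.println("wrapped task: " + runInPool(true));  // request-42
    }
}
```

This only handles one hard-coded ThreadLocal; TTL generalizes the same steps to every TransmittableThreadLocal in use, which is what the holder described next is for.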

First, how is the parent thread's data stored? Look at TransmittableThreadLocal#set:

    @Override
    public final void set(T value) {
        if (!disableIgnoreNullValueSemantics && null == value) {
            // may set null to remove value
            remove();
        } else {
            super.set(value);
            addThisToHolder();
        }
    }

TransmittableThreadLocal extends InheritableThreadLocal, so super.set(value) first stores the value in the current (parent) thread's inheritableThreadLocals map. A quick test:

import com.alibaba.ttl.TransmittableThreadLocal;

public class TransmittableThreadLocalTest1 {

    public static final TransmittableThreadLocal<Integer> TTL_HOLDER = new TransmittableThreadLocal<>();
    public static final InheritableThreadLocal<Integer> ITL_HOLDER = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        TTL_HOLDER.set(1);
        System.out.println(ITL_HOLDER.get());
    }
}

Output:

null

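A null here may look odd at first, but it makes sense. The data always lives in a ThreadLocalMap, and the slot in that map's array is located through the ThreadLocal instance itself (this, via its threadLocalHashCode). TTL_HOLDER and ITL_HOLDER are different instances, so they map to different slots. This is precisely how ThreadLocal is designed: within one thread, data set through one ThreadLocal never gets mixed up with data belonging to another. A debugger confirms it: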

[Screenshot: debugger view of the main thread's inheritableThreadLocals]

At this point the main thread's inheritableThreadLocals is still empty. After the set call, look again:

[Screenshot: debugger view of the main thread's inheritableThreadLocals after the set call]

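In other words, once the TransmittableThreadLocal value is set in the main thread, the current thread's inheritableThreadLocals does get populated, yet ITL_HOLDER.get() still cannot read the value: the two holders are different this objects, so they use different slots.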
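The same isolation can be reproduced with JDK classes only. In this sketch (the class SlotIsolationDemo is hypothetical), two InheritableThreadLocal instances stand in for TTL_HOLDER and ITL_HOLDER, a reasonable stand-in since TransmittableThreadLocal extends InheritableThreadLocal. Setting A fills a slot that B can never read, and a child thread copies A's entry but still has nothing for B.

```java
import java.util.Arrays;

public class SlotIsolationDemo {

    // Two different instances: each is its own key in the thread's inheritableThreadLocals.
    static final InheritableThreadLocal<Integer> A = new InheritableThreadLocal<>();
    static final InheritableThreadLocal<Integer> B = new InheritableThreadLocal<>();

    // Sets only A, then returns {A, B} as seen by the current thread,
    // followed by {A, B} as seen by a child thread constructed afterwards.
    static Integer[] probe() {
        A.set(1);
        Integer[] result = new Integer[4];
        result[0] = A.get();
        result[1] = B.get();          // different key, different slot: null
        Thread child = new Thread(() -> {
            result[2] = A.get();      // copied from the parent at construction
            result[3] = B.get();      // B never had a value, so still null
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            A.remove();
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(probe())); // [1, null, 1, null]
    }
}
```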

That was a small detour. Back to the addThisToHolder method:

    private void addThisToHolder() {
        if (!holder.get().containsKey(this)) {
            holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
        }
    }

It simply puts the current TransmittableThreadLocal into holder, which is defined as follows:

    // Note about holder:
    // 1. holder self is a InheritableThreadLocal(a *ThreadLocal*).
    // 2. The type of value in holder is WeakHashMap<TransmittableThreadLocal<Object>, ?>.
    //    2.1 but the WeakHashMap is used as a *Set*:
    //        - the value of WeakHashMap is *always null,
    //        - and never be used.
    //    2.2 WeakHashMap support *null* value.
    private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
                }
                // overrides childValue, which is called during Thread#init: when a thread is initialized, it gets a copy of the parent's holder map in its inheritableThreadLocals
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };

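So holder is itself an InheritableThreadLocal, which means a child thread inherits the parent's holder data. In other words, the TransmittableThreadLocal class acts as a hub: its static holder keeps track, per thread, of every TransmittableThreadLocal that has been set.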
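To see how a static holder can still be per-thread, here is a simplified model built from JDK classes only. TrackedThreadLocal is a hypothetical name, and the class omits most of what the real TransmittableThreadLocal does (null-value semantics, copyValue, callbacks); it keeps only the registry idea: set registers the instance in a weak-keyed map held by a static InheritableThreadLocal, and childValue gives each new child thread its own copy of that map.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.WeakHashMap;

// Hypothetical, simplified model of TTL's static holder; not the library's code.
public class TrackedThreadLocal<T> extends InheritableThreadLocal<T> {

    // Per-thread registry of instances that currently hold a value.
    // Used as a weak-keyed Set: the map values are always null.
    private static final InheritableThreadLocal<WeakHashMap<TrackedThreadLocal<?>, Object>> HOLDER =
            new InheritableThreadLocal<WeakHashMap<TrackedThreadLocal<?>, Object>>() {
                @Override
                protected WeakHashMap<TrackedThreadLocal<?>, Object> initialValue() {
                    return new WeakHashMap<>();
                }

                // Runs while a child Thread is being constructed: the child gets its own
                // copy of the parent's registry, so later changes stay independent.
                @Override
                protected WeakHashMap<TrackedThreadLocal<?>, Object> childValue(
                        WeakHashMap<TrackedThreadLocal<?>, Object> parentValue) {
                    return new WeakHashMap<>(parentValue);
                }
            };

    @Override
    public void set(T value) {
        super.set(value);
        HOLDER.get().put(this, null); // register this instance for the current thread
    }

    @Override
    public void remove() {
        HOLDER.get().remove(this);    // unregister before clearing the value
        super.remove();
    }

    // Copy of the instances registered in the calling thread.
    static Set<TrackedThreadLocal<?>> registered() {
        return new HashSet<>(HOLDER.get().keySet());
    }

    // Returns {child sees user, child sees trace, main still has user after remove()}.
    static boolean[] demo() {
        TrackedThreadLocal<String> user = new TrackedThreadLocal<>();
        TrackedThreadLocal<String> trace = new TrackedThreadLocal<>();
        user.set("alice");
        trace.set("t-1");
        boolean[] result = new boolean[3];
        Thread child = new Thread(() -> {
            Set<TrackedThreadLocal<?>> inChild = registered(); // inherited copy
            result[0] = inChild.contains(user);
            result[1] = inChild.contains(trace);
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        user.remove();
        trace.remove();
        result[2] = registered().contains(user);
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(demo())); // [true, true, false]
    }
}
```

Using a WeakHashMap as the set means a ThreadLocal that is no longer referenced anywhere else can still be garbage-collected out of the registry, which matches the intent of the "WeakHashMap is used as a Set" note in the holder source.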

Next, TtlRunnable#get. We can guess that this is where the Runnable is wrapped and the parent thread's data is captured:

    @Nullable
    public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
        if (null == runnable) return null;

        if (runnable instanceof TtlEnhanced) {
            // avoid redundant decoration, and ensure idempotency
            if (idempotent) return (TtlRunnable) runnable;
            else throw new IllegalStateException("Already TtlRunnable!");
        }
        return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
    }

    private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        this.capturedRef = new AtomicReference<Object>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

So TtlRunnable is a wrapper around Runnable, and the captured data is stored in capturedRef. Now the capture method:

/**
 * Capture all {@link TransmittableThreadLocal} and registered {@link ThreadLocal} values in the current thread.
 *
 * @return the captured {@link TransmittableThreadLocal} values
 * @since 2.3.0
 */
@NonNull
public static Object capture() {
    return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

Then captureTtlValues:

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        // copyValue is an extension point that decides how each value is copied
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    return ttl2Value;
}

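Put simply, this "copies" (via copyValue) every TransmittableThreadLocal registered in holder, together with its current value, into a new map, ttl2Value.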
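The quotes around "copy" matter. This sketch assumes a reference copy, i.e. a copyValue that returns the value unchanged; whether that is the default in your TTL version is worth checking in its documentation. Under that assumption, re-binding a ThreadLocal after capture does not affect the snapshot, but mutating a shared mutable object does. ShallowCaptureDemo and its capture helper are hypothetical, JDK-only simplifications of captureTtlValues.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ShallowCaptureDemo {

    static final ThreadLocal<String> NAME = new ThreadLocal<>();
    static final ThreadLocal<List<String>> TAGS = new ThreadLocal<>();

    // Hypothetical simplification of captureTtlValues: copy each value REFERENCE
    // into a fresh map (no deep copy, like a copyValue that returns the value as-is).
    static Map<ThreadLocal<?>, Object> capture(List<ThreadLocal<?>> registered) {
        Map<ThreadLocal<?>, Object> snapshot = new HashMap<>();
        for (ThreadLocal<?> tl : registered) {
            snapshot.put(tl, tl.get());
        }
        return snapshot;
    }

    static Map<ThreadLocal<?>, Object> demo() {
        NAME.set("alice");
        TAGS.set(new ArrayList<>(Arrays.asList("a")));
        Map<ThreadLocal<?>, Object> snapshot = capture(Arrays.<ThreadLocal<?>>asList(NAME, TAGS));

        NAME.set("bob");       // re-binding after capture: the snapshot keeps "alice"
        TAGS.get().add("b");   // mutating the shared list: the snapshot sees it too

        NAME.remove();
        TAGS.remove();
        return snapshot;
    }

    public static void main(String[] args) {
        Map<ThreadLocal<?>, Object> snapshot = demo();
        System.out.println(snapshot.get(NAME)); // alice
        System.out.println(snapshot.get(TAGS)); // [a, b]
    }
}
```

If tasks mutate a shared context object in place, a reference copy lets those changes flow back to the submitting thread, which is the situation where customizing the copy step becomes worthwhile.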

Now TtlRunnable#run:

    @Override
    public void run() {
        // fetch the data captured from the parent thread
        Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        // replay the captured data; returns the data this thread had before
        Object backup = replay(captured);
        try {
            runnable.run();
        } finally {
            // restore this thread's original data
            restore(backup);
        }
    }
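The guard at the top of run deserves a closer look. When releaseTtlValueReferenceAfterRun is true, the first run atomically swaps capturedRef to null, so any later or concurrent run finds null (or loses the compareAndSet) and throws. When it is false, the wrapper can run repeatedly; the examples above submit the same ttlRunnable many times without error, which is consistent with the single-argument TtlRunnable.get(task) keeping the reference. OneShotRunnable below is a hypothetical, JDK-only reduction of just that guard.

```java
import java.util.concurrent.atomic.AtomicReference;

public class OneShotDemo {

    // Hypothetical wrapper mirroring the guard at the top of TtlRunnable#run.
    static final class OneShotRunnable implements Runnable {
        private final AtomicReference<Object> capturedRef;
        private final boolean releaseAfterRun;
        private final Runnable delegate;

        OneShotRunnable(Object captured, boolean releaseAfterRun, Runnable delegate) {
            this.capturedRef = new AtomicReference<>(captured);
            this.releaseAfterRun = releaseAfterRun;
            this.delegate = delegate;
        }

        @Override
        public void run() {
            Object captured = capturedRef.get();
            // With releaseAfterRun, exactly one caller wins the CAS to null;
            // every later (or racing) run sees null or a failed CAS and throws.
            if (captured == null || releaseAfterRun && !capturedRef.compareAndSet(captured, null)) {
                throw new IllegalStateException("captured value already released");
            }
            delegate.run();
        }
    }

    // Runs the wrapper twice and reports whether the second run was rejected.
    static boolean secondRunRejected(boolean releaseAfterRun) {
        OneShotRunnable r = new OneShotRunnable("snapshot", releaseAfterRun, () -> { });
        r.run();
        try {
            r.run();
            return false;
        } catch (IllegalStateException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println("release=false, rejected? " + secondRunRejected(false)); // false
        System.out.println("release=true,  rejected? " + secondRunRejected(true));  // true
    }
}
```

Releasing the reference lets the captured values be garbage-collected as soon as the task has run, at the cost of making the wrapper single-use.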

The method is straightforward, and its main flow matches the earlier guess. First, replay:

public static Object replay(@NonNull Object captured) {
    final Snapshot capturedSnapshot = (Snapshot) captured;
    // replay the parent thread's TransmittableThreadLocal data
    // (wrapped in a Snapshot; here we only follow ttl2Value)
    return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

@NonNull
        private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
            WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();

            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();

                // backup
                backup.put(threadLocal, threadLocal.get());
								
                // clear the TTL values that is not in captured
                // avoid the extra TTL values after replay when run task
                if (!captured.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // set TTL values to captured
            setTtlValuesTo(captured);

            // call beforeExecute callback
            doExecuteCallback(true);

            return backup;
        }

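This code first backs up the thread's existing values, clears any TransmittableThreadLocal that is not in the captured snapshot, and then sets the captured values on the current thread: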

private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocal<Object> threadLocal = entry.getKey();
        // set the value on the current thread
        threadLocal.set(entry.getValue());
    }
}

Next, the restore method:

public static void restore(@NonNull Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}

        private static void restoreTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> backup) {
            // call afterExecute callback
            doExecuteCallback(false);

            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();

                // clear the TTL values that is not in backup
                // avoid the extra TTL values after restore
                if (!backup.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // restore TTL values
            setTtlValuesTo(backup);
        }

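restore resets the current thread's TransmittableThreadLocal data to what it was before the task ran. This matters because pool threads are reused: without it, any change a task made to the context would leak into the next task on the same thread. A quick example: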

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;

import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class Test4 {

    public static final TransmittableThreadLocal<String> TTL_HOLDER = new TransmittableThreadLocal<>();

    private static final ExecutorService FIXED_EXECUTOR = Executors.newSingleThreadExecutor();

    public static void main(String[] args) {
        print();
        Runnable task = () -> {
            print();
            // the task changes the context of the pool thread
            TTL_HOLDER.set(UUID.randomUUID().toString());
            print();
        };
        // extra step: wrap the task into a decorated TtlRunnable
        TtlRunnable ttlRunnable = TtlRunnable.get(task);
        new Thread(() -> FIXED_EXECUTOR.submit(ttlRunnable), "T1").start();
        sleepTwoSeconds();
        new Thread(() -> FIXED_EXECUTOR.submit(ttlRunnable), "T2").start();
        sleepTwoSeconds();
        print();
        // let the JVM exit once the pool has finished
        FIXED_EXECUTOR.shutdown();
    }

    private static void sleepTwoSeconds() {
        try {
            Thread.sleep(2000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    private static void print() {
        System.out.println(Thread.currentThread().getName() + ":" + TTL_HOLDER.get());
    }
}

Output:

main:null
pool-1-thread-1:null
pool-1-thread-1:b07df1f9-f085-4204-aa0a-1f16b53f6c65
pool-1-thread-1:null
pool-1-thread-1:580ee3aa-0fe9-4f9d-9214-3b5988c4c477
main:null

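On the second run, the reused pool-1-thread-1 starts from null again: after the first run, restore removed the value the task had set, so it did not leak into the next task.

Overall, the key idea is that values are registered in the static holder ahead of time and captured when the task is wrapped, so transmission no longer depends on when the pool happens to create its threads.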
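The whole capture, replay, run, restore cycle can be modeled with JDK classes only. In this sketch every name (ReplayRestoreDemo, HOLDER, setTracked, apply, wrap) is hypothetical; the registry is a plain, non-inheritable ThreadLocal, and replay and restore share one apply method because both do the same thing: back up the current values, drop any tracked ThreadLocal missing from the target, then set the target values. The callbacks and the Snapshot split of the real library are left out.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.WeakHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ReplayRestoreDemo {

    // Hypothetical per-thread registry playing the role of TTL's holder (not inheritable here).
    static final ThreadLocal<Set<ThreadLocal<Object>>> HOLDER = ThreadLocal.withInitial(
            () -> Collections.newSetFromMap(new WeakHashMap<ThreadLocal<Object>, Boolean>()));

    static final ThreadLocal<Object> USER = new ThreadLocal<>();
    static final ThreadLocal<Object> SCRATCH = new ThreadLocal<>();

    static void setTracked(ThreadLocal<Object> tl, Object value) {
        tl.set(value);
        HOLDER.get().add(tl);
    }

    static Map<ThreadLocal<Object>, Object> capture() {
        Map<ThreadLocal<Object>, Object> snapshot = new HashMap<>();
        for (ThreadLocal<Object> tl : HOLDER.get()) {
            snapshot.put(tl, tl.get());
        }
        return snapshot;
    }

    // Shared by replay and restore: back up current values, drop tracked
    // ThreadLocals missing from target, then set every target value.
    static Map<ThreadLocal<Object>, Object> apply(Map<ThreadLocal<Object>, Object> target) {
        Map<ThreadLocal<Object>, Object> backup = new HashMap<>();
        for (Iterator<ThreadLocal<Object>> it = HOLDER.get().iterator(); it.hasNext(); ) {
            ThreadLocal<Object> tl = it.next();
            backup.put(tl, tl.get());
            if (!target.containsKey(tl)) {
                it.remove();
                tl.remove();
            }
        }
        target.forEach(ReplayRestoreDemo::setTracked);
        return backup;
    }

    static Runnable wrap(Runnable task) {
        Map<ThreadLocal<Object>, Object> captured = capture(); // runs in the submitting thread
        return () -> {
            Map<ThreadLocal<Object>, Object> backup = apply(captured); // replay
            try {
                task.run();
            } finally {
                apply(backup);                                         // restore
            }
        };
    }

    static List<String> demo() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        List<String> log = new ArrayList<>();
        try {
            setTracked(USER, "alice");
            Runnable task = wrap(() -> {
                log.add(USER.get() + "/" + SCRATCH.get()); // context seen on entry
                setTracked(SCRATCH, "leftover");          // change made by the task
            });
            pool.submit(task).get();
            pool.submit(task).get();                       // same worker thread, reused
            pool.submit(() -> { log.add(USER.get() + "/" + SCRATCH.get()); }).get(); // unwrapped
            return log;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            USER.remove();
            HOLDER.get().remove(USER);
        }
    }

    public static void main(String[] args) {
        System.out.println(demo()); // [alice/null, alice/null, null/null]
    }
}
```

The third, unwrapped task sees null for USER because restore removed it: the worker's registry was empty before the first replay. With the real TTL, a pool thread constructed after the value was set would also have inherited it (both TransmittableThreadLocal and holder are InheritableThreadLocals), so its backup would not be empty.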

References

  • https://github.com/alibaba/transmittable-thread-local
  • https://mp.weixin.qq.com/s/Y57WCfhAZylXvraD1Eypug
