相关文章:
InheritableThreadLocal 的局限性
在上一篇文章中分析了 ThreadLocal
使用的注意事项,即不适用于出现线程切换的场景。提出了一种解决思路,也分析了 JDK 的解决方案:InheritableThreadLocal
,但是 InheritableThreadLocal
的使用也有很大的限制,因为它是在 Thread
初始化的时候会保存父线程中的变量,但是实际开发中我们几乎不会去 new
线程,而是会通过线程池去创建线程,这种时候 InheritableThreadLocal
就无法发挥作用了,先看一个例子:
public class InheritableThreadLocalTest2 {
public static final InheritableThreadLocal<Integer> HOLDER = new InheritableThreadLocal<>();
private static final ExecutorService FIXED_EXECUTOR = Executors.newFixedThreadPool(2);
public static void main(String[] args) throws InterruptedException {
FIXED_EXECUTOR.submit(() -> System.out.println("ready..."));
Thread.sleep(20);
HOLDER.set(1);
print();
Runnable task = InheritableThreadLocalTest2::print;
IntStream.range(0, 5).forEach(i -> {
new Thread(task, "simple task").start();
FIXED_EXECUTOR.submit(task);
});
FIXED_EXECUTOR.shutdown();
}
private static void print() {
System.out.println(currentName() + ":" + HOLDER.get());
}
private static String currentName() {
return Thread.currentThread().getName();
}
}
输出结果:
ready...
main:1
simple task:1
pool-1-thread-2:1
simple task:1
pool-1-thread-1:null
pool-1-thread-1:null
simple task:1
pool-1-thread-1:null
simple task:1
pool-1-thread-2:1
simple task:1
会发现 new
出来的线程可以获取 main
线程中的数据,但是线程池中有的线程可以获取 mian
线程中的数据,有的线程不行,这样在实际项目中使用会有很大问题。
TransmittableThreadLocal
TransmittableThreadLocal
是 Alibaba 开源的 Java 库 TTL 中的一个工具 ,可以看下官方的介绍:
JDK
的InheritableThreadLocal
类可以完成父线程到子线程的值传递。但对于使用线程池等会池化复用线程的执行组件的情况,线程由线程池创建好,并且线程是池化起来反复使用的;这时父子线程关系的ThreadLocal
值传递已经没有意义,应用需要的实际上是把 任务提交给线程池时的ThreadLocal
值传递到 任务执行时。
其中加粗的部分也不表达了主要设计思想,即“应用需要的实际上是把任务提交给线程池时的ThreadLocal
值传递到任务执行时”,与在上一篇文章中提到的一种解决思路很类似。先看看 TransmittableThreadLocal 的效果,将上面的例子中的 InheritableThreadLocal
换成使用 TransmittableThreadLocal
:
public class InheritableThreadLocalTest2 {
//public static final InheritableThreadLocal<Integer> HOLDER = new InheritableThreadLocal<>();
public static final TransmittableThreadLocal<Integer> HOLDER = new TransmittableThreadLocal<>();
private static final ExecutorService FIXED_EXECUTOR = Executors.newFixedThreadPool(2);
public static void main(String[] args) throws InterruptedException {
FIXED_EXECUTOR.submit(() -> System.out.println("ready..."));
Thread.sleep(20);
HOLDER.set(1);
print();
Runnable task = InheritableThreadLocalTest2::print;
// 额外的处理,生成修饰了的对象ttlRunnable
TtlRunnable ttlRunnable = TtlRunnable.get(task);
IntStream.range(0, 5).forEach(i -> {
new Thread(ttlRunnable, "simple task").start();
FIXED_EXECUTOR.submit(ttlRunnable);
});
FIXED_EXECUTOR.shutdown();
}
private static void print() {
System.out.println(currentName() + ":" + HOLDER.get());
}
private static String currentName() {
return Thread.currentThread().getName();
}
}
输出结果:
ready...
main:1
pool-1-thread-2:1
simple task:1
pool-1-thread-1:1
pool-1-thread-2:1
simple task:1
simple task:1
pool-1-thread-1:1
simple task:1
pool-1-thread-2:1
simple task:1
发现正是我们想要的,子线程可以获取到父线程到数据。
原理分析
到这里其实可以大致猜想其实现原理:存储父线程的数据,包装 Runable
类,执行任务时子线将父线程的数据设置到子线程到上下文中。
先看如何“存储父线程的数据”,看这个方法 TransmittableThreadLocal#set
:
@Override
public final void set(T value) {
if (!disableIgnoreNullValueSemantics && null == value) {
// may set null to remove value
remove();
} else {
super.set(value);
addThisToHolder();
}
}
TransmittableThreadLocal
继承了 InheritableThreadLocal
,也就是说首先会将数据设置到父线程的 inheritableThreadLocals
中,这里可以先测试一下:
public class TransmittableThreadLocalTest1 {
public static final TransmittableThreadLocal<Integer> TTL_HOLDER = new TransmittableThreadLocal<>();
public static final InheritableThreadLocal<Integer> ITL_HOLDER = new InheritableThreadLocal<>();
public static void main(String[] args) {
TTL_HOLDER.set(1);
System.out.println(ITL_HOLDER.get());
}
}
输出结果:
null
这里输出为 null
貌似有点奇怪,但是想一下也很正常,因为本质上数据都是存在 ThreadLocalMap
里面,我们是要通过 this
去获取到底存在数组的哪个槽里面,而这里 this
是不一样的,这也是 ThreadLocal
设计的时候就不会出现同一个线程的数据会在多个 ThreadLocal
中错乱。具体可以 debug
找出答案:
此时会发现 main
线程的 inheritableThreadLocals
是空的。再看:
也就是在 main
线程中设置了 TransmittableThreadLocal
的值后,会发现 InheritableThreadLocal
会为当前线程的 inheritableThreadLocals
赋值,但是 InheritableThreadLocal#get
是无法获取数据,是因为 this
不一样,所以槽不一样。
这里只是一个小插曲,再看 addThisToHolder
方法:
private void addThisToHolder() {
if (!holder.get().containsKey(this)) {
holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
}
}
直接将当前 TransmittableThreadLocal
存进去了,holder
如下:
// Note about holder:
// 1. holder self is a InheritableThreadLocal(a *ThreadLocal*).
// 2. The type of value in holder is WeakHashMap<TransmittableThreadLocal<Object>, ?>.
// 2.1 but the WeakHashMap is used as a *Set*:
// - the value of WeakHashMap is *always null,
// - and never be used.
// 2.2 WeakHashMap support *null* value.
private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
@Override
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
}
//重写了 childValue 方法,在 Thread#init 的时候会调用,也就是说线程初始化的时候直接将父线程的 inheritableThreadLocals 进行赋值
@Override
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
}
};
发现 holder 就是 InheritableThreadLocal
,也就是说子线程是可以获取到父线程中 holder
的数据的,可以继承。即当前 TransmittableThreadLocal
类是一个中心,它的 holder
维护了所有的 TransmittableThreadLocal
。
再看 TtlRunnable#get
方法,猜测这个方法的作用就是包装 Runable
类,存储父线程的数据:
@Nullable
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
if (null == runnable) return null;
if (runnable instanceof TtlEnhanced) {
// avoid redundant decoration, and ensure idempotency
if (idempotent) return (TtlRunnable) runnable;
else throw new IllegalStateException("Already TtlRunnable!");
}
return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
this.capturedRef = new AtomicReference<Object>(capture());
this.runnable = runnable;
this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}
能够看到 TtlRunnable
是对 Runnable
的一个包装,将数据赋值给了 capturedRef
,接下来看看 capture
方法:
/**
* Capture all {@link TransmittableThreadLocal} and registered {@link ThreadLocal} values in the current thread.
*
* @return the captured {@link TransmittableThreadLocal} values
* @since 2.3.0
*/
@NonNull
public static Object capture() {
return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}
再看 captureTtlValues
方法:
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
//copyValue 是提供的一个可扩展的方法,即如何存储值
ttl2Value.put(threadLocal, threadLocal.copyValue());
}
return ttl2Value;
}
这里说白了就是将 holder
中的所有的数据"copy"过来。
再看 TtlRunnable#run
:
@Override
public void run() {
//获取捕获的父线程数据
Object captured = capturedRef.get();
if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
throw new IllegalStateException("TTL value reference is released after run!");
}
//设置数据,返回当前线程原本就有的数据
Object backup = replay(captured);
try {
runnable.run();
} finally {
//恢复当前线程原有数据
restore(backup);
}
}
这个方法很清晰,主体流程和之前的猜想差不多。先看 replay
方法:
public static Object replay(@NonNull Object captured) {
final Snapshot capturedSnapshot = (Snapshot) captured;
//重放父线程中的 TransmittableThreadLocal数据(这里封装了一个 Snapshot,我们主要关注 ttl2Value)
return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}
@NonNull
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// backup
backup.put(threadLocal, threadLocal.get());
// clear the TTL values that is not in captured
// avoid the extra TTL values after replay when run task
if (!captured.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// set TTL values to captured
setTtlValuesTo(captured);
// call beforeExecute callback
doExecuteCallback(true);
return backup;
}
这段代码就是首先将原来的数据存下来,然后为当前线程设置值:
private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
TransmittableThreadLocal<Object> threadLocal = entry.getKey();
//数据设置到当前线程
threadLocal.set(entry.getValue());
}
}
再看 restore
方法:
public static void restore(@NonNull Object backup) {
final Snapshot backupSnapshot = (Snapshot) backup;
restoreTtlValues(backupSnapshot.ttl2Value);
restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}
private static void restoreTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> backup) {
// call afterExecute callback
doExecuteCallback(false);
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// clear the TTL values that is not in backup
// avoid the extra TTL values after restore
if (!backup.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// restore TTL values
setTtlValuesTo(backup);
}
就是重置之前当前线程的 TransmittableThreadLocal
数据,主要是因为线程池中线程会复用,避免当前任务执行修改了上下文数据会影响下一次线程的使用。可以简单看下这个例子:
public class Test4 {
public static final TransmittableThreadLocal<String> TTL_HOLDER = new TransmittableThreadLocal<>();
private static final ExecutorService FIXED_EXECUTOR = Executors.newSingleThreadExecutor();
public static void main(String[] args) {
print();
Runnable task = ()->{
print();
TTL_HOLDER.set(UUID.randomUUID().toString());
print();
};
// 额外的处理,生成修饰了的对象ttlRunnable
TtlRunnable ttlRunnable = TtlRunnable.get(task);
new Thread(()-> FIXED_EXECUTOR.submit(ttlRunnable),"T!").start();
sleep20();
new Thread(()-> FIXED_EXECUTOR.submit(ttlRunnable),"T2").start();
sleep20();
print();
}
private static void sleep20() {
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
private static void print() {
System.out.println(Thread.currentThread().getName() + ":" + TTL_HOLDER.get());
}
}
输出结果:
main:null
pool-1-thread-1:null
pool-1-thread-1:b07df1f9-f085-4204-aa0a-1f16b53f6c65
pool-1-thread-1:null
pool-1-thread-1:580ee3aa-0fe9-4f9d-9214-3b5988c4c477
main:null
总得来说,最关键的就是提前会将数据存储在全局的 holder
中,这样不会因为线程池线程生成时机问题而造成数据无法传递。
References
- https://github.com/alibaba/transmittable-thread-local
- https://mp.weixin.qq.com/s/Y57WCfhAZylXvraD1Eypug
欢迎关注公众号