Java编程拾遗『ThreadLocal』

本篇文章,我们来介绍一下,Java多线程编程中一个比较常用的工具,线程本地变量——ThreadLocal。ThreadLocal简单的来讲,就是每个线程都有同一个成员变量的独有拷贝。由于每个线程都有成员变量独立的拷贝,所以不存在多线程访问同一成员变量的问题,也就解决了线程安全问题。之前我们介绍的锁,是解决线程安全问题的一个途径,那么本篇文章介绍的线程本地变量是另一种解决线程安全问题的思想,就是通过额外每个线程都分配成员变量的拷贝,来换取加锁带来的并发效率问题,可以讲是一种空间换时间的思想。本篇文章我们就来介绍一下ThreadLocal的用法及实现原理。

1. 基本概念

ThreadLocal是一个泛型类,接受一个类型参数T,它只有一个空的构造方法,有两个主要的public方法:

set就是设置值,get就是获取值,如果没有值,返回null。看上去,ThreadLocal就是一个单一对象的容器,比如:

public static void main(String[] args) {
    ThreadLocal<Integer> local = new ThreadLocal<>();
    local.set(100);
    System.out.println(local.get());
}

输出100,没什么好讲的。ThreadLocal方便的地方主要在多线程的场景,如下:

public class ThreadLocalTest {
    private static ThreadLocal<Integer> local = new ThreadLocal<>();

    public static void main(String[] args) throws InterruptedException {
        Thread child = new Thread(() -> {
            System.out.println("child thread initial: " + local.get());
            local.set(200);
            System.out.println("child thread final: " + local.get());
        });
        
        local.set(100);
        child.start();
        child.join();
        System.out.println("main thread final: " + local.get());
    }
}

运行结果:

child thread initial: null
child thread final: 200
main thread final: 100

说明,main线程对local变量的设置对child线程不起作用,child线程对local变量的改变也不会影响main线程,它们访问的虽然是同一个变量local,但每个线程都有自己的独立的值,这就是线程本地变量的含义

ThreadLocal中除了上述方法,还有一个比较重要的protected方法initialValue,如下:

protected T initialValue()
public void remove()

initialValue用于提供初始值,它是一个protected方法,可以通过匿名内部类的方式提供,当调用get方法时,如果之前没有设置过,会调用initialValue方法获取初始值,默认实现是返回null。remove删掉当前线程对应的值,如果删掉后,再次调用get,会再调用initialValue获取初始值。

public class ThreadLocalTest1 {
    private static ThreadLocal<Integer> local = new ThreadLocal<Integer>(){

        @Override
        protected Integer initialValue() {
            return 100;
        }
    };

    public static void main(String[] args) {
        System.out.println(local.get());
        local.set(200);
        local.remove();
        System.out.println(local.get());
    }
}

运行结果:

100
100

说明当调用get方法时,如果之前没有设置过,会调用initialValue方法获取初始值,默认实现是返回null。remove删掉当前线程对应的值,如果删掉后,再次调用get,会再调用initialValue获取初始值。

Java8之后,ThreadLocal新增了一个静态方法withInitial,可以用来指定ThreadLocal的初始值,达到跟initialValue相同的效果,声明如下:

public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
    return new SuppliedThreadLocal<>(supplier);
}

withInitial参数supplier就是用来指定初始值的,所以上述ThreadLocal声明可以替换为:

private static ThreadLocal<Integer> local = ThreadLocal.withInitial(() -> 100);

2. ThreadLocal应用场景

2.1 实现线程安全

ThreadLocal是实现线程安全的一种方案,比如对于DateFormat/SimpleDateFormat,这两个类是非线程安全类,Java8及以后提供了线程安全的解决方案:DateTimeFormat。而如果要以线程安全的方式使用DateFormat/SimpleDateFormat,可以又以下集中选择:

  • 使用锁
  • 不共享,每次使用DateFormat/SimpleDateFormat都新建一个对象
  • 使用ThreadLocal

下面我们就来介绍以下为什么DateFormat/SimpleDateFormat非线程安全,以及上述三种解决方案。首先来看个例子:

public class ThreadLocalTest2 {
    private static SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");

    private static ExecutorService pool = Executors.newFixedThreadPool(200);

    private static CountDownLatch countDownLatch = new CountDownLatch(100);
    
    public static void main(String[] args) throws InterruptedException {
        Set<String> dates = Collections.synchronizedSet(new HashSet<String>());
        for (int i = 0; i < 100; i++) {
            //获取当前时间
            Calendar calendar = Calendar.getInstance();
            int finalI = i;
            pool.execute(() -> {
                //时间增加
                calendar.add(Calendar.DATE, finalI);
                //通过simpleDateFormat把时间转换成字符串
                String dateString = simpleDateFormat.format(calendar.getTime());
                //把字符串放入Set中
                dates.add(dateString);
                //countDown
                countDownLatch.countDown();
            });
        }
        //阻塞,直到countDown数量为0
        countDownLatch.await();
        //输出去重后的时间个数
        System.out.println(dates.size());

    }
}

上述示例,启动一个线程池,for循环中循环提交100个任务,每个任务的职责是将当前日期加上每次循环的天数,并通过simpleDateFormat格式化,将格式化的结果添加到Set中。那么理论上循环100次,肯定能得到100个不同的结果,set中最终会有100个元素。但是上述代码的执行结果一般不是100,说明SimpleDateFormat类是存在线程安全问题的。

我们来看一下SimpleDateFormat类的format方法实现,就能指导为什么SimpleDateFormat非线程安全了:

SimpleDateFormat中的format方法在执行过程中,会使用一个成员变量calendar来保存时间。我们在声明SimpleDateFormat的时候,使用的是static定义的。那么这个SimpleDateFormat就是一个共享变量,所以,SimpleDateFormat中的calendar也就可以被多个线程访问到。

假设线程1刚刚执行完calendar.setTime把时间设置成2019-10-10,还没等执行完,线程2又执行了calendar.setTime把时间改成了2019-10-11。这时候线程1继续往下执行,拿到的calendar.getTime得到的时间就是线程2改过之后的。

除了format方法以外,SimpleDateFormat的parse方法也有同样的问题,DateFormat也一样。所以DateFormat/SimpleDateFormat是非线程安全类,不要把SimpleDateFormat作为一个共享变量使用。如果要线程安全地使用DateFormat/SimpleDateFormat,可以通过如下方式:

2.1.1 使用锁

如果不存在多个线程同时访问SimpleDateFormat对象,那么就不存在SimpleDateFormat内部的Calendar相互影响的问题。基于这种考虑,最简单的实现就是使用同步锁,保证同一时刻只有一个线程可以访问SimpleDateFormat对象,如下:

for (int i = 0; i < 100; i++) {
    //获取当前时间
    Calendar calendar = Calendar.getInstance();
    int finalI = i;
    pool.execute(() -> {
        //加锁
        synchronized (simpleDateFormat) {
            //时间增加
            calendar.add(Calendar.DATE, finalI);
            //通过simpleDateFormat把时间转换成字符串
            String dateString = simpleDateFormat.format(calendar.getTime());
            //把字符串放入Set中
            dates.add(dateString);
            //countDown
            countDownLatch.countDown();
        }
    });
}

2.1.2 使用局部变量代替成员变量

如果SimpleDateFormat对象为局部变量,那么也不会存在多个线程同时访问SimpleDateFormat对象的情况,如下:

for (int i = 0; i < 100; i++) {
    //获取当前时间
    Calendar calendar = Calendar.getInstance();
    int finalI = i;
    pool.execute(() -> {
        // SimpleDateFormat声明成局部变量
        SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");

        calendar.add(Calendar.DATE, finalI);
        String dateString = simpleDateFormat.format(calendar.getTime());
        dates.add(dateString);
        countDownLatch.countDown();
    });
}

2.1.3 使用ThreadLocal

ThreadLocal可以确保每个线程都可以得到单独的一个SimpleDateFormat对象,那么自然也就不存在竞争问题了。如下:

private static ThreadLocal<SimpleDateFormat> simpleDateFormatThreadLocal = ThreadLocal.withInitial(() ->
            new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));

2.2 保存上下文信息

ThreadLocal的另一个典型用途是保存上下文信息,比如在一个Web服务器中,一个线程执行用户的请求,在执行过程中,很多代码都会访问一些共同的信息,比如请求信息、用户身份信息、数据库连接、当前事务等,它们是线程执行过程中的全局信息,如果作为参数在不同代码间传递,代码会很啰嗦,这时,使用ThreadLocal就很方便,所以它被用于各种框架如Spring中,如下:

public class RequestContext {
    public static class Request { //...
    };

    private static ThreadLocal<String> localUserId = new ThreadLocal<>();
    private static ThreadLocal<Request> localRequest = new ThreadLocal<>();

    public static String getCurrentUserId() {
        return localUserId.get();
    }

    public static void setCurrentUserId(String userId) {
        localUserId.set(userId);
    }

    public static Request getCurrentRequest() {
        return localRequest.get();
    }

    public static void setCurrentRequest(Request request) {
        localRequest.set(request);
    }

    public void clear() {
        localUserId.clear();
        localRequest.clear();
    }
}

使用上,可以通过切面,在切面中调用set方法如setCurrentRequest/setCurrentUserId进行设置,然后就可以在代码的任意其他地方调用get相关方法进行获取了。

但同时要注意的是,在RequestContext使用结束,要注意调用clear清理ThreadLocal,否则有可能会影响其它线程的执行(因为一般web容器都会使用线程池,线程会复用,如果某个线程使用结束不清理的话,ThreadLocal中的变量会影响后续的使用)。这点会在下面单独介绍。

3. ThreadLocal的底层实现

接下来我们来看一下ThreadLocal的底层实现原理,为什么对同一个ThreadLocal对象的get/set,每个线程都能有自己独立的值。先从set方法看起:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

首先获取当前线程,并调用getMap方法获取ThreadLocalMap对象。那么来看一下getMap:

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

返回线程的实例对象的threadLocals成员,它的初始值为null。在null时,ThreadLocal类的set方法调用createMap初始化:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

可以看出,每个线程都有一个Map,类型为ThreadLocalMap,调用set实际上是在线程自己的Map里设置了一个条目,键为当前的ThreadLocal对象,值为value。所以能做到不同线程之间不相互影响。

get方法的代码为:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

通过线程访问到Map,以ThreadLocal对象为键从Map中获取到条目,取其value,如果Map中没有,调用setInitialValue,如下:

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

initialValue()就是之前提到的提供初始值的方法,默认实现就是返回null。Java8提供的withInitial方法返回为SuppliedThreadLocal对象,该类是ThreadLocal的子类,并在该类中重写了initialValue方法,如下:

public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
    return new SuppliedThreadLocal<>(supplier);
}
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

    private final Supplier<? extends T> supplier;

    SuppliedThreadLocal(Supplier<? extends T> supplier) {
        this.supplier = Objects.requireNonNull(supplier);
    }

    @Override
    protected T initialValue() {
        return supplier.get();
    }
}

所以withInitial方法的函数式参数,就是ThreadLocal的初始值。

最后来看一下remove方法:

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

就是将该ThreadLocal对象在线程map中条目删除。

最后我们来总结一下,为什么ThreadLocal可以实现每个线程都有自己的独立拷贝。因为每个线程都有一个Map,对于每个ThreadLocal对象,调用其get/set方法,实际上就是以ThreadLocal对象为键读写当前线程的Map

4. ThreadLocal与线程池

上面我们讲到,ThreadLocal对象可以实现各个子线程隔离,原理是每个线程都有一个用于存储ThreadLocal对象及ThreadLocal对象内部value的map。之前讲的线程池的概念,我们知道线程池中工作线程是复用的,也就是讲同一个线程会处理多个任务。但在使用者看来,多个任务应该就是多个线程,所以多个每个线程的ThreadLocal应该有独立的拷贝。但线程池是通过工作线程来执行任务的,如果存在线程复用的情况,那么就不能保证每个提交的任务都有一个ThreadLocal值的独立拷贝了,任务之间就会相互影响,来看个例子:

public class ThreadLocalTest3 {
    private static ThreadLocal<AtomicInteger> sequencer = ThreadLocal.withInitial(() -> new AtomicInteger(0));

    static class Task implements Runnable {

        @Override
        public void run() {
            AtomicInteger s = sequencer.get();
            int initial = s.getAndIncrement();
            // 期望初始为0
            System.out.println(initial);
        }
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        executor.execute(new Task());
        executor.execute(new Task());
        executor.execute(new Task());
        executor.shutdown();
    }
}

对于异步任务Task而言,它期望的初始值应该总是0,但运行程序,结果却为:

0
0
1

可以发现,第三次执行异步任务,结果就不对了,为什么呢?因为线程池corePoolSize和maximumPoolSize都是2,当前两个任务提交时都是创建工作线程执行任务的,但是第三个任务提交时,因为已经到达corePoolSize,所以会进入等待队列,等工作线程空闲时执行,这时第三个任务就复用了工作线程。同时线程池中的线程在执行完一个任务,执行下一个任务时,其中的ThreadLocal对象并不会被清空,修改后的值带到了下一个异步任务。那怎么办呢?有几种思路:

  1. 第一次使用ThreadLocal对象时,总是先调用set设置初始值,或者如果ThreaLocal重写了initialValue方法,先调用remove
  2. 使用完ThreadLocal对象后,总是调用其remove方法
  3. 使用自定义的线程池

我们分别来看下,对于第一种,在Task的run方法开始处,添加set或remove代码,如下所示:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

对于第二种,将Task的run方法包裹在try/finally中,并在finally语句中调用remove,如下所示:

static class Task implements Runnable {

    @Override
    public void run() {
        sequencer.set(new AtomicInteger(0));
        //或者 sequencer.remove();
        
        AtomicInteger s = sequencer.get();
        //...
    }
}

以上两种方法都比较麻烦,需要更改所有异步任务的代码,另一种方法是扩展线程池ThreadPoolExecutor,它有一个可以扩展的方法beforeExecute,在之前介绍线程池时也介绍过,我们可以通过重写该方法实现在任务执行前的一些前置操作,我们创建一个自定义的线程池MyThreadPool,如下:

static class MyThreadPool extends ThreadPoolExecutor {
    public MyThreadPool(int corePoolSize, int maximumPoolSize,
            long keepAliveTime, TimeUnit unit,
            BlockingQueue<Runnable> workQueue) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    }

    @Override
    protected void beforeExecute(Thread t, Runnable r) {
        try {
            //使用反射清空所有ThreadLocal
            Field f = t.getClass().getDeclaredField("threadLocals");
            f.setAccessible(true);
            f.set(t, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
        super.beforeExecute(t, r);
    }
}

这里,使用反射,找到线程中存储ThreadLocal对象的Map变量threadLocals,重置为null。使用MyThreadPool的示例代码如下:

public static void main(String[] args) {
    ExecutorService executor = new MyThreadPool(2, 2, 0,
            TimeUnit.MINUTES, new LinkedBlockingQueue<Runnable>());
    executor.execute(new Task());
    executor.execute(new Task());
    executor.execute(new Task());
    executor.shutdown();
}

参考链接:

1. Java API

2. 《Java编程的逻辑》

发布了117 篇原创文章 · 获赞 33 · 访问量 8万+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章