詳述 ThreadLocal 的實現原理及其使用方法

Threadlocal是一個線程內部的存儲類,可以在指定線程內存儲數據,並且該數據只有指定線程能夠獲取到,其官方解釋如下:

/**
 * This class provides thread-local variables.  These variables differ from
 * their normal counterparts in that each thread that accesses one (via its
 * {@code get} or {@code set} method) has its own, independently initialized
 * copy of the variable.  {@code ThreadLocal} instances are typically private
 * static fields in classes that wish to associate state with a thread (e.g.,
 * a user ID or Transaction ID).
 */

其大致意思就是,ThreadLocal提供了線程內存儲變量的能力,這些變量不同之處在於每一個線程讀取的變量是對應的互相獨立的,通過setget方法就可以得到當前線程對應的值。

做個不恰當的比喻,從表面上看ThreadLocal相當於維護了一個Mapkey就是當前的線程,value就是需要存儲的對象。至於爲什麼說不恰當,因爲實際上是ThreadLocal的靜態內部類ThreadLocalMap爲每個Thread都維護了一個數組tableThreadLocal確定了一個數組下標,而這個下標就是value存儲的對應位置。

實現原理

ThreadLocal中,最重要的兩個方法就是setget,如果我們理解了這兩個方法的實現原理,那麼也就可以說我們理解了ThreadLocal的實現原理。

ThreadLocal 的 get 方法

首先,我們來看一下ThreadLocalset方法。

 public void set(T value) {
      //獲取當前線程
      Thread t = Thread.currentThread();
      //實際存儲的數據結構類型
      ThreadLocalMap map = getMap(t);
      //如果存在map就直接set,沒有則創建map並set
      if (map != null)
          map.set(this, value);
      else
          createMap(t, value);
  }
  
ThreadLocalMap getMap(Thread t) {
      //thread中維護了一個ThreadLocalMap
      return t.threadLocals;
 }
 
void createMap(Thread t, T firstValue) {
      //實例化一個新的ThreadLocalMap,並賦值給線程的成員變量threadLocals
      t.threadLocals = new ThreadLocalMap(this, firstValue);
}

如上述代碼所示,我們可以看出來每個線程持有一個ThreadLocalMap對象。每創建一個新的線程Thread都會實例化一個ThreadLocalMap並賦值給成員變量threadLocals,使用時若已經存在threadLocals則直接使用已經存在的對象;否則的話,新創建一個ThreadLocalMap並賦值給threadLocals變量。

    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

如上述代碼所示,其爲Thread類中關於threadLocals變量的聲明。

接下來,我們看一下createMap方法中的實例化過程,主要就是創建ThreadLocalMap對象。

//Entry爲ThreadLocalMap靜態內部類,對ThreadLocal的若引用
//同時讓ThreadLocal和儲值形成key-value的關係
static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
           super(k);
            value = v;
    }
}

//ThreadLocalMap構造方法
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        //內部成員數組,INITIAL_CAPACITY值爲16的常量
        table = new Entry[INITIAL_CAPACITY];
        //位運算,結果與取模相同,計算出需要存放的位置
        //threadLocalHashCode比較有趣
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        setThreshold(INITIAL_CAPACITY);
}

通過上面的代碼不難看出在實例化ThreadLocalMap時創建了一個長度爲 16 的Entry數組。通過hashCodelength位運算確定出一個索引值i,這個i就是被存儲在table數組中的位置。

前面講過每個線程Thread持有一個ThreadLocalMap類型的變量threadLocals,結合此處的構造方法可以理解成每個線程Thread都持有一個Entry型的數組table,而一切的讀取過程都是通過操作這個數組table完成的。

thread-local

顯然tablesetget的焦點,在看具體的setget方法前,先看下面這段代碼。

//在某一線程聲明瞭ABC三種類型的ThreadLocal
ThreadLocal<A> sThreadLocalA = new ThreadLocal<A>();
ThreadLocal<B> sThreadLocalB = new ThreadLocal<B>();
ThreadLocal<C> sThreadLocalC = new ThreadLocal<C>();

由前面我們知道對於一個Thread來說只有持有一個ThreadLocalMap,所以 A、B、C 對應同一個ThreadLocalMap對象。爲了管理 A、B、C,於是將他們存儲在一個數組的不同位置,而這個數組就是上面提到的Entry型的數組table

那麼問題來了, A、B、C 在table中的位置是如何確定的?爲了能正常夠正常的訪問對應的值,肯定存在一種方法計算出確定的索引值i,代碼如下:

//ThreadLocalMap中set方法。
  private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            //獲取索引值,這個地方是比較特別的地方
            int i = key.threadLocalHashCode & (len-1);

            //遍歷tab如果已經存在則更新值
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            
            //如果上面沒有遍歷成功則創建新值
            tab[i] = new Entry(key, value);
            int sz = ++size;
            //滿足條件數組擴容x2
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

ThreadLocalMap中的set方法與構造方法中,能看到以下代碼片段:

  • int i = key.threadLocalHashCode & (len-1)
  • int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1)

簡而言之,就是將threadLocalHashCode進行一個位運算(取模)得到索引ithreadLocalHashCode代碼如下:

    //ThreadLocal中threadLocalHashCode相關代碼.
    
    private final int threadLocalHashCode = nextHashCode();

    /**
     * The next hash code to be given out. Updated atomically. Starts at
     * zero.
     */
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
     * The difference between successively generated hash codes - turns
     * implicit sequential thread-local IDs into near-optimally spread
     * multiplicative hash values for power-of-two-sized tables.
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     * Returns the next hash code.
     */
    private static int nextHashCode() {
        //自增
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

因爲static的原因,在每次new ThreadLocal時因爲threadLocalHashCode的初始化,會使threadLocalHashCode值自增一次,增量爲0x61c88647。其中,0x61c88647是斐波那契散列乘數,它的優點是通過它散列(hash)出來的結果分佈會比較均勻,可以很大程度上避免hash衝突,已初始容量 16 爲例,hash並與 15 位運算計算數組下標結果如下:

hashCode 數組下標
0x61c88647 7
0xc3910c8e 14
0x255992d5 5
0x8722191c 12
0xe8ea9f63 3
0x4ab325aa 10
0xac7babf1 1
0xe443238 8
0x700cb87f 15

總結如下:

  • 對於某一個ThreadLocal來講,其索引值i是確定的,在不同線程之間訪問時訪問的是不同的table數組的同一位置即都爲table[i],只不過這個不同線程之間的table是獨立的。
  • 對於同一線程的不同ThreadLocal來講,這些ThreadLocal實例共享一個table數組,然後每個ThreadLocal實例在table中的索引i是不同的。

ThreadLocal 的 set 方法

在瞭解完set方法的實現原理之後,我們在來看一下ThreadLocal中的get方法。

//ThreadLocal中get方法
public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}
    
//ThreadLocalMap中getEntry方法
private Entry getEntry(ThreadLocal<?> key) {
       int i = key.threadLocalHashCode & (table.length - 1);
       Entry e = table[i];
       if (e != null && e.get() == key)
            return e;
       else
            return getEntryAfterMiss(key, i, e);
   }

如上述代碼所示,get方法就是通過計算出的索引從數組的對應位置取值,其中getMap獲取的是Thread類中的threadLocals變量。

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

在取值的時候,又分爲兩種情況,如果獲取的map爲空,則調用setInitialValue設置初始值,默認值爲null,我們也可以在創建ThreadLocal的時候覆寫其initialValue方法,以實現自定義默認值的目的;如果獲取的map非空,則調用getEntry方法返回對應的值e,並當e不爲null時,強轉爲實際的類型,否則,同樣調用setInitialValue設置初始值。

ThreadLocal 的特性

ThreadLocalsynchronized都是爲了解決多線程中相同變量的訪問衝突問題,不同的點是:

  • synchronized是通過線程等待,犧牲時間來解決訪問衝突;
  • ThreadLocal是通過每個線程單獨一份存儲空間,犧牲空間來解決衝突,並且相比於synchronizedThreadLocal具有線程隔離的效果,只有在線程內才能獲取到對應的值,線程外則不能訪問到想要的值。

正因爲ThreadLocal的線程隔離特性,所以它的應用場景相對來說更爲特殊一些。當某些數據是以線程爲作用域並且不同線程具有不同的數據副本的時候,就可以考慮採用ThreadLocal實現。但是在使用ThreadLocal的時候,需要我們考慮內存泄漏的風險。

至於爲什麼會有內存泄漏的風險,則是因爲在我們使用ThreadLocal保存一個value時,會在ThreadLocalMap中的數組插入一個Entry對象,按理說keyvalue都應該以強引用保存在Entry對象中,但在ThreadLocalMap的實現中,key被保存到了WeakReference對象中。

這就導致了一個問題,ThreadLocal在沒有外部強引用時,發生 GC 時會被回收,但Entry對象和value並沒有被回收,因此如果創建ThreadLocal的線程一直持續運行,那麼這個Entry對象中的value就有可能一直得不到回收,從而發生內存泄露。既然已經發現有內存泄露的隱患,自然有應對的策略。在調用ThreadLocalget方法時會自動清除ThreadLocalMapkeynullEntry對象,其觸發邏輯就在getEntry方法中:

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

enull或者e.get()不等於key時,進入getEntryAfterMiss的邏輯:

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

e不爲nulle.get()等於null時,執行expungeStaleEntry的邏輯,也就是真正刪除過期Entry的方法:

       /**
         * Expunge a stale entry by rehashing any possibly colliding entries
         * lying between staleSlot and the next null slot.  This also expunges
         * any other stale entries encountered before the trailing null.  See
         * Knuth, Section 6.4
         *
         * @param staleSlot index of slot known to have null key
         * @return the index of the next null slot after staleSlot
         * (all between staleSlot and this slot will have been checked
         * for expunging).
         */
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

這樣對應的value就不會 GC Roots 可達,從而在下次 GC 的時候就可以被回收了。但我們要知道,這僅是在調用ThreadLocalget方法之後,纔有可能執行的邏輯;特別地,當我們誤用“先getset”的使用邏輯時,就更會加大內存泄漏的風險。因此,ThreadLocal的最佳實踐就是在使用完ThreadLocal之後,使用finally關鍵字顯示調用ThreadLocalremove方法,防止內存泄漏

使用方法

假設,有這樣一個類:

@Data
@AllArgsConstructor
public class Counter{
	private int count;
}

我們希望多線程訪問Counter對象時,每個線程各自保留一份count計數,那可以這麼寫:

ThreadLocal<Counter> threadLocal = new ThreadLocal<>();
threadLocal.set(new Counter(0));
Counter counter = threadLocal.get();

如果我們不想每次調用的時候都去初始化,則可以重寫ThreadLocalinitValue()方法給ThreadLocal設置一個對象的初始值:

ThreadLocal<Counter> threadLocal = new ThreadLocal<Counter>() {
    @Override
    protected Counter initialValue() {
        return new Counter(0);
    }
};

如上述代碼所示,這樣每次再調用threadLocal.get()的時候,會去判斷當前線程是否存在Counter對象,如果不存在則調用initValue()方法進行初始化。

@Slf4j
public class MyThreadLocal<T> extends ThreadLocal<T>{
    public T get() {
        try {
            return super.get();
        } catch (Exception e) {
           log.error("獲取ThreadLocal值失敗!");
           return null;
        } finally {
            super.remove();
        }
    }
}

如上述代碼所示,遵循ThreadLocal最佳實現,我們可以創建一個MyThreadLocal類,繼承ThreadLocal並覆寫其get方法。


參考資料

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章