深入理解 ThreadLocal

前言

上篇文章 深入理解 Handler 消息機制 中提到了獲取線程的 Looper 是通過 ThreadLocal 來實現的:

public static @Nullable Looper myLooper() {
    return sThreadLocal.get();
}

每個線程都有自己的 Looper,它們之間不應該有任何交集,互不干擾,我們把這種變量稱爲 線程局部變量 。而 ThreadLocal 的作用正是存儲線程局部變量,每個線程中存儲的都是獨立存在的數據副本。如果你還是不太理解,看一下下面這個簡單的例子:

public static void main(String[] args) throws InterruptedException {

    ThreadLocal<Boolean> threadLocal = new ThreadLocal<Boolean>();
    threadLocal.set(true);

    Thread t1 = new Thread(() -> {
        threadLocal.set(false);
        System.out.println(threadLocal.get());
    });

    Thread t2 = new Thread(() -> {
        System.out.println(threadLocal.get());
    });

    t1.start();
    t2.start();
    t1.join();
    t2.join();
    System.out.println(threadLocal.get());
}

執行結果是:

false
null
true

可以看到,我們在不同的線程中調用同一個 ThreadLocal 的 get() 方法,獲得的值是不同的,看起來就像 ThreadLocal 爲每個線程分別存儲了不同的值。那麼這到底是如何實現的呢?一起來看看源碼吧。

以下源碼基於 JDK 1.8 , 相關文件:

Thread.java

ThreadLocal.java

ThreadLocal

首先 ThreadLocal 是一個泛型類,public class ThreadLocal<T>,支持存儲各種數據類型。它對外暴露的方法很少,基本就 get()set()remove() 這三個。下面依次來看一下。

set()

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t); // 獲取當前線程的 ThreadLocalMap
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value); // 創建 ThreadLocalMap
}

這裏出現了一個新東西 ThreadLocalMap,暫且就把他當做一個普通的 Map。從 map.set(this, value) 可以看出來這個 map 的鍵是 ThreadLocal 對象,值是要存儲的 value 對象。其實看到這,ThreadLocal 的原理你應該基本都明白了。

每一個 Thread 都有一個 ThreadLocalMap ,這個 Map 以 ThreadLocal 對象爲鍵,以要保存的線程局部變量爲值。這樣就做到了爲每個線程保存不同的副本。

首先通過 getMap() 函數獲取當前線程的 ThreadLocalMap :

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

原來 Thread 還有這麼一個變量 threadLocals

/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class.
*
* 存儲線程私有變量,由 ThreadLocal 進行管理
*/
ThreadLocal.ThreadLocalMap threadLocals = null;

默認爲 null,所以第一次調用時返回 null ,調用 createMap(t, value) 進行初始化:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

get()

set() 方法是向 ThreadLocalMap 中插值,那麼 get() 就是在 ThreadLocalMap 中取值了。

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t); // 獲取當前線程的 ThreadLocalMap
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result; // 找到值,直接返回
        }
    }
    return setInitialValue(); // 設置初始值
}

首先獲取 ThreadLocalMap,在 Map 中尋找當前 ThreadLocal 對應的 value 值。如果 Map 爲空,或者沒有找到 value,則通過 setInitialValue() 函數設置初始值。

private T setInitialValue() {
    T value = initialValue(); // 爲 null
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

protected T initialValue() {
    return null;
}

setInitialValue()set() 邏輯基本一致,只不過 value 是 null 而已。這也解釋了文章開頭的例子會輸出 null。當然,在 ThreadLocal 的子類中,我們可以通過重寫 setInitialValue() 來提供其他默認值。

remove()

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

remove() 就更簡單了,根據鍵直接移除對應條目。

看到這裏,ThreadLocal 的原理好像就說完了,其實不然。ThreadLocalMap 是什麼樣的一個哈希表呢?它是如何解決哈希衝突的?它是如何添加,獲取和刪除元素的?可能會導致內存泄露嗎?

其實 ThreadLocalMap 纔是 ThreadLocal 的核心。ThreadLocal 僅僅只是提供給開發者的一個工具而已,就像 Handler 一樣。帶着上面的問題,來閱讀 ThreadLocalMap 的源碼,體會 JDK 工程師的鬼斧神工。

ThreadLocalMap

Entry

ThreadLocalMap 是 ThreadLocal 的靜態內部類,它沒有直接使用 HashMap,而是一個自定義的哈希表,使用數組實現,數組元素是 Entry

static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

Entry 類繼承了 WeakReference<ThreadLocal<?>>,我們可以把它看成是一個鍵值對。鍵是當前的 ThreadLocal 對象,值是存儲的對象。注意 ThreadLocal 對象的引用是弱引用,值對象 value 的引用是強引用。ThreadLocal 使用弱引用其實很好理解,源碼註釋中也告訴了我們答案:

To help deal with very large and long-lived usages, the hash table entries use WeakReferences for keys

Thread 持有 ThreadLocalMap 的強引用,ThreadLocalMap 中的 Entry 的鍵是 ThreadLocal 引用。如果線程長期存活或者使用了線程池,而 ThreadLocal 在外部又沒有任何強引用了,這種情況下如果 ThreadLocalMap 的鍵仍然使用強引用 ThreadLocal,就會導致 ThreadLocal 永遠無法被垃圾回收,造成內存泄露。

圖片來源:https://www.jianshu.com/p/a1cd61fa22da

那麼,使用弱引用是不是就萬無一失了呢?答案也是否定的。同樣是上面說到使用情況,線程長期存活,由於 Entry 的 key 使用了弱引用,當 ThreadLocal 不存在外部強引用時,可以在 GC 中被回收。但是根據可達性分析算法,仍然存在着這麼一個引用鏈:

Current Thread -> ThreadLocalMap -> Entry -> value

key 已經被回收了,此時 key == null。那麼,value 呢?如果線程長期存在,這個針對 value 的強引用也會一直存在,外部是否對 value 指向的對象還存在其他強引用也不得而知。所以這裏還是有機率發生內存泄漏的。就算我們不知道外部的引用情況,但至少在這裏應該是可以切斷 value 引用的。

所以,爲了解決可能存在的內存泄露問題,我們有必要對於這種 key 已經被 GC 的過期 Entry 進行處理,手動釋放 value 引用。當然,JDK 中已經爲我們處理了,而且處理的十分巧妙。下面就來看看 ThreadLocalMap 的源碼。

構造函數

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    setThreshold(INITIAL_CAPACITY);
}

table 是存儲 Entry 的數組,初始容量 INITIAL_CAPACITY 是 16。

firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1) 是 ThreadLocalMap 計算哈希的方式。&(2^n-1) 其實等同於 % 2^n,位運算效率更高。

threadLocalHashCode 是如何計算的呢?看下面的代碼:

private static final int HASH_INCREMENT = 0x61c88647;

private static AtomicInteger nextHashCode = new AtomicInteger();

private final int threadLocalHashCode = nextHashCode();

private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

0x61c88647 是一個增量,每次取哈希都要再加上這個數字。又是一個神奇的數字,讓我想到了 Integer 源碼中的 52429 這個數字,見 走進 JDK 之 Integer0x61c88647 背後肯定也有它的數學原理,總之肯定是爲了效率。

原理就不去探究了,其實我也不知道是啥原理。不過我們可以試用一下,看看效果如何。按照上面的方式來計算連續幾個元素的哈希值,也就是在 Entry 數組中的位置。代碼如下:

public class Test {

    private static final int INITIAL_CAPACITY = 16;
    private static final int HASH_INCREMENT = 0x61c88647;
    private static AtomicInteger nextHashCode = new AtomicInteger();

    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    private static int hash() {
        return nextHashCode() & (INITIAL_CAPACITY - 1);
    }

    public static void main(String[] args) {

        for (int i = 0; i < 8; i++) {
            System.out.println(hash());
        }
    }
}

運算結果如下:

0
7
14
5
12
3
10
1

計算結果分佈還是比較均勻的。既然是哈希表,肯定就會存在哈希衝突的情況。那麼,ThreadLocalMap 是如何解決哈希衝突呢?很簡單,看一下 nextIndex() 方法。

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

在不超過 len 的情況下直接加 1,否則置 0。其實這樣又可以看成一個環形數組。

接下來看看 ThreadLocalMap 的數據是如何存儲的。

set()

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1); // 當前 key 的哈希,即在數組 table 中的位置

    for (Entry e = tab[i];
        e != null; // 循環直到碰到空 Entry
        e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) { // 更新 key 對應的值
            e.value = value;
            return;
        }

        if (k == null) { // 替代過期 entry
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}
  1. 通過 key.threadLocalHashCode & (len-1) 計算出初始的哈希值
  2. 不斷調用 nextIndex() 直到找到空 Entry
  3. 在第二步遍歷過程中的每個元素,要處理兩種情況:

    (1). k == key,說明當前 key 已存在,直接更新值即可,直接返回

    (2). k == null, 注意這裏的前置條件是 entry != null。說明遇到過期 Entry,直接替換
  4. 不屬於 3 中的兩種情況,則將參數中的鍵值對插入空 Entry 處
  5. cleanSomeSlots()/rehash()

先來看看第三步中的第二種特殊情況。Entry 不爲空,但其中的 key 爲空,什麼時候會發生這種情況呢?對,就是前面說到內存泄漏時提到的 過期 Entry。我們都知道 Entry 的 key 是弱引用的 ThreadLocal,當外部沒有它的任何強引用時,下次 GC 時就會將其回收。所以這時候的 Entry 理論上也是無效的了。

由於這裏是在 set() 方法插入元素的過程中發現了過期 Entry,所以只要將要插入的 Entry 直接替換這個 key==null 的 Entry 就可以了,這就是 replaceStaleEntry() 的核心邏輯。

replaceStaleEntry()

private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    // 向前找到第一個過期條目
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
        (e = tab[i]) != null;
        i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i; // 記錄前一個過期條目的位置

    // Find either the key or trailing null slot of run, whichever occurs first
    // 向後查找,直到找到 key 或者 空 Entry
    for (int i = nextIndex(staleSlot, len);
        (e = tab[i]) != null;
        i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run.
        if (k == key) {

            // 如果在向後查找過程中發現 key 相同的 entry 就覆蓋並且和過期 entry 進行交換
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            // 如果在查找過程中還未發現過期 entry,那麼就以當前位置作爲 cleanSomeSlots 的起點
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        // 如果向前未搜索到過期 entry,而在向後查找過程遇到過期 entry 的話,後面就以此時這個位置
        // 作爲起點執行 cleanSomeSlots
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    // 如果在查找過程中沒有找到可以覆蓋的 entry,則將新的 entry 插入在過期 entry
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    // 在上面的代碼運行過程中,找到了其他的過期條目
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

看起來挺累人的。在我理解,replaceStaleEntry 只是做一個標記的作用,在各種情況下最後都會調用 cleanSomeSlots 來真正的清理過期條目。

你可以看到 ``

cleanSomeSlots()

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i); // 需要清理的 Entry
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

參數 n 表示掃描控制。初始情況下掃描 log2(n) 次,如果遇到過期條目,會再掃描 log2(table.length)-1 次。在 set() 方法中調用,參數 n 表示元素的個數。在 replaceStaleEntry 中調用,參數 n 表示的是數組 table 的長度。

注意 do 循環裏面的判斷條件:e != null && e.get() == null ,還是那些 Entry 不爲空,key 爲空的過期條目。發現過期條目之後,調用 expungeStaleEntry() 去清理。

expungeStaleEntry()

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    // 清空 staleSlot 處的 過期 entry
    // 將 value 置空,保證不會因爲這裏的強引用造成 memory leak
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    // 繼續搜索直到遇到 tab 中的空 entry
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
        (e = tab[i]) != null;
        i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) { // 搜索過程中遇到過期條目,直接清理
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // key 還沒有被回收
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i; // 此時從 staleSlot 到 i 之間不存在過期條目
}

直接將 entry.valueentry 都置空,消除內存泄露的隱患。注意這裏僅僅只是置空,並不是回收對象。因爲你不知道 value 在外部的引用情況,只需要管好自己的引用就可以了。

除此之外,不甘寂寞的 expungeStaleEntry() 又發起了一次掃描,直到碰到空 Entry未知。期間遇到的過期 Entry 要置空。

整個 set() 方法就看完了,原理很簡單,但是其中關於內存泄漏的預防處理十分複雜,看的我一度放棄了,也讓我對源碼閱讀產生了一些疑問。有些時候是不是沒有必要逐行去玩去完全理解?比如這一系列關於內存泄露的處理,核心思想就是清理 Entry 不爲 null 但 key 爲 null 的過期條目。理解了核心思想,對於其中複雜的細節處理是不是沒有必要去深究?不知道你怎麼看,歡迎在評論區寫下你的看法。

下面來看一看 getgetEntry 方法。

getEntry()

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key) // 直接命中
        return e;
    else
        // 未直接命中,線性探測,繼續往後找
        return getEntryAfterMiss(key, i, e);
}

getEntry() 比較粗暴,上來直接根據哈希值查找 table 數組,如果直接命中,就返回。未直接命中,調用 getEntryAfterMiss() 繼續查找。

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    // 向後查找直到遇到空 entry
    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key) // get it
            return e;
        if (k == null) // key 等於 null,清理過期 entry
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len); // 繼續向後查找
        e = tab[i];
    }
    return null;
}

調用 nextIndex() 向後查找,直到遇到 空 Entry,也就是隊尾:

  • k==key,說明找到了對應 Entry
  • k==null,說明遇到了過期 Entry,調用 expungeStaleEntry() 處理

對過期 Entry 的處理真的是無處不在,就是爲了最大程度的降低內存泄漏發生的機率。那麼有沒有什麼一勞永逸的辦法呢?那就是 ThreadLocalMapremove() 方法。

remove()

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
        e != null;
        e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

直接清除當前 ThreadLocal 對應的 Entry,根本上避免了發生內存泄露。所以,當我們不再需要使用 ThreadLocal 中的相應數據時,調用一下 remove() 方法肯定是個好習慣。

雖然在長期存活的線程(例如線程池)中使用 ThreadLocal 併發生內存泄漏是一個小概率事件,但 JDK 開發者卻爲此多寫了很多代碼。我們在使用中也要多加註意,仔細考慮是否會涉及到內存泄露的問題。

End

最後說說在網上看到的一個觀點,ThreadLocal 比 Synchronized 更適合解決線程同步問題。

首先這個問題本身就不是那麼嚴謹。ThreadLocal 是用來解決線程同步問題的嗎?表面上看,ThreadLocal 的機制的確是線程安全的,但它並不是爲了解決多線程訪問同一個變量的競爭問題,而是給每一個線程都提供單獨的變量,有些文章稱之爲 數據備份,但它們並不是備份,每一個都是獨立存在的,互不干擾,並不存在什麼同步問題。

ThreadLocalSynchronized 的應用場景也是千差萬別的。例如銀行的轉賬場景,涉及多個賬戶同時轉賬的多線程同步問題,ThreadLocal 根本就沒法解決,即使每個線程都單獨保存着用戶的餘額也沒法解決併發問題。ThreadLocal 在 Android 中的典型應用就是 Looper,每個線程都有自己的 Looper 對象,它們都是獨立工作,互不干擾的。

關於 ThreadLocal 就說到這裏了。後續分享的方向主要集中在兩塊,一方面是 AOSP 源碼的閱讀和解析,另一方面是 Kotlin 和 Java 相關特性的對比,敬請期待!

文章首發微信公衆號: 秉心說 , 專注 Java 、 Android 原創知識分享,LeetCode 題解。

更多最新原創文章,掃碼關注我吧!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章