深入理解ThreadLocal

前言

併發是Java開發中繞不開的一個話題。現代處理器都是多核心,想要更好地榨乾機器的性能,多線程編程是必不可少,所以,線程安全是每位Java Engineer的必修課。

應對線程安全問題,可大致分爲兩種方式:

  1. 同步: 用Synchronized關鍵字,或者用java.util.concurrent.locks.Lock工具類給臨界資源加鎖。
  2. 避免資源爭用: 將全局資源放在ThreadLocal變量中,避免併發訪問。

本文將介紹第二種方式:ThreadLocal的實現原理以及爲什麼能保證線程安全。

ThreadLocal

下面是ThreadLocal的一個常見使用場景:

public class ThreadLocalTest {
    // 一般都將ThreadLocal定義爲靜態變量
    private static final ThreadLocal<DateFormat> format = new ThreadLocal<DateFormat>(){
        // 初始化ThreadLocal的值
        protected DateFormat initialValue() {
            return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        }
    };

    public static void main(String[] args) {
        // 啓動20個線程
        for (int i = 0; i < 20; i++) {
            new Thread(() -> {
                try {
                    // 得到SimpleDateFormat在本線程中的副本
                    DateFormat localFormat = format.get();
                    // 解析日期,這裏並不會報錯
                    Date date = localFormat.parse("2000-11-11 11:11:11");
                    System.out.println(date);
                } catch (ParseException e) {
                    e.printStackTrace();
                }
            }).start();
        }
    }
}

大家應該都知道,Java中SimpleDateFormat不是線程安全的,參考這篇文章。然而上述代碼的確不會報錯,說明ThreadLocal確實能保證併發安全。

源碼解析

ThreadLocal概覽

上面的例子中,我們調用了ThreadLocalinitialValueget方法,且來看一下get方法的實現:

// 此類的作者是兩個大神,前者是《Effective Java》的作者,後者是Java併發包的作者,併發大師
/*
 * @author  Josh Bloch and Doug Lea
 * @since   1.2
 */
public class ThreadLocal<T> {
    public T get() {
        // 得到當前線程
        Thread t = Thread.currentThread();
        // 根據當前線程,拿到一個Map,暫且可以將之類比爲HashMap鍵值對形式
        // 可見這個Map是與本線程相關的
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            // 通過this從Map中拿Entry,說明Map中的Key就是ThreadLocal變量本身
            // value就是ThreadLocal中所保存的對象
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        // 若Map沒有初始化(map == null),或者當前ThreadLocal變量沒有初始化(e == null)
        // 則調用此方法完成初始化
        return setInitialValue();
    }

    // 原來,這個ThreadLocalMap只是線程的一個成員變量!
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
}

public class Thread implements Runnable {
    // Thread類中定義了一個全局變量ThreadLocalMap
    // 用來存放本線程中所有的ThreadLocal類型變量,初始值爲null
    ThreadLocal.ThreadLocalMap threadLocals = null;
}

get方法,我們可以得到如下信息:

  1. ThreadLocal變量保存在一個Map中,而這個Map正是Thread類的一個全局變量。這也是ThreadLocal實現線程安全的一個關鍵點:各個線程都有自己的Map,每個線程操作的都是自己的ThreadLocal變量副本,互不影響。
  2. ThreadLocalMap保存線程中所有的ThreadLocal變量,ThreadLocal變量是Key,ThreadLocal所對應的值爲Value。(在本文開始的例子中,Key爲format變量,Value爲initialValue方法返回的值new SimpleDateFormat(“yyyy-MM-dd HH:mm:ss”)
  3. ThreadLocal是懶加載的,當發現ThreadLocalMap或者當前ThreadLocal變量未初始化時,會調用setInitialValue方法進行初始化。

ThreadLocal

ThreadLocal其他方法

繼續來看setInitialValue方法做了什麼事情:

    private T setInitialValue() {
        // 調用initialValue方法初始化
        // 這個方法即爲我們定義ThreadLocal變量的時候重寫的方法
        T value = initialValue();
        Thread t = Thread.currentThread();
        // 獲取當前線程的ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null)
            // 如果Map已經初始化好了,那直接初始化當前ThreadLocal變量:
            // 將自己(當前ThreadLocal變量)作爲key,保存的值作爲value,set到Map裏面去
            map.set(this, value);
        else
            // 如果Map還未初始化,則初始化Map
            createMap(t, value);
        return value;
    }

    // 默認的initialValue方法定義爲protected,就是給我們重寫的
    protected T initialValue() {
        return null;
    }

    void createMap(Thread t, T firstValue) {
        // 新建一個ThreadLocalMap,賦值給當前Thread
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

其他還有setremove方法,很簡單這裏不另外講解。

難道ThreadLocal就此結束了麼?有這麼簡單麼?當然沒有。因爲ThreadLocalMapThread的一個成員變量,所以它的生命週期跟線程是一樣長的。也就是說,只要線程還沒有被銷燬,那麼Map就會常駐內存,無法被GC,很容易造成內存泄漏。那ThreadLocal是如何解決的呢?

答案是弱引用,Java中的引用類型,可以參考這篇文章

ThreadLocalMap

ThreadLocalMapThreadLocal的一個內部類。Java中有現成的類HashMap,而ThreadLocal又費勁千辛萬苦自己實現了一個ThreadLocalMap,就是爲防止內存泄漏。

下面我們來探祕ThreadLocalMap,它跟普通的HashMap有什麼區別。

ThreadLocalMap的數據結構

static class ThreadLocalMap {

    // 內部類Entry繼承了WeakReference
    static class Entry extends WeakReference<ThreadLocal<?>> {
        // ThreadLocal變量中保存的值
        Object value;

        // 可以看到,Entry只是簡單的Key-Value,並沒有類似HashMap中的鏈表
        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }
    // ThreadLocalMap默認大小
    private static final int INITIAL_CAPACITY = 16;
    // 此Entry數組,就是所有ThreadLocal存放的地方
    private Entry[] table;
}

ThreadLocalMap維護了一個Entry數組(沒有鏈表,這是跟HashMap不一樣的地方),用來存放線程中所有的ThreadLocal變量。Entry繼承了WeakReference,並關聯了ThreadLocal,當外界沒有其他強引用指向ThreadLocal對象時,該ThreadLocal對象會在下一次GC時被內存回收,也就是Entry中的Key會被回收掉,所以下面會看到清理key爲null的Entry的操作。

Set操作

HashMap遇到哈希衝突的時候,是通過在同一個Hash Key上建立鏈表來解決的。既然ThreadLocalMap只維護了一個Entry數組,那它是怎麼解決哈希衝突的呢?我們來看set方法的源碼:

    private void set(ThreadLocal<?> key, Object value) {
        Entry[] tab = table;
        int len = tab.length;
        // 根據ThreadLocal的hashcode,計算出在table中的槽位(index)
        int i = key.threadLocalHashCode & (len-1);
        // 從位置i開始,逐個往後循環,找到第一個空的槽位(條件e == null)
        for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();
            // 如果key相等,則直接將舊value覆蓋掉,換成新value
            if (k == key) {
                // 新值替換掉舊值,並return掉
                e.value = value;
                return;
            }
            // key == null,說明弱引用之前已經被內存回收,則將值設在此槽位
            if (k == null) {
                // 該方法後面再解析
                replaceStaleEntry(key, value, i);
                return;
            }
        }

        // 走到這裏,這個i 是從key真正所在的hash槽之後數,第一個非空槽位
        // 將value包裝成Entry,放到位置i中
        tab[i] = new Entry(key, value);
        int sz = ++size;
        // 查找是否有Entry已經被回收
        // 如果找到有Entry被回收,或者table的size大於閾值,執行rehash操作
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }
    
    // 獲取下一個index。其實就是i + 1。當超出table長度的時候,歸0重新來
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }

ThreadLocalMap是用開放地址發來解決哈希衝突的。如果目標槽位已經有值了,首先判斷該值是不是就是自己。如果是,那就替換舊值;如果不是,再判斷該槽位的值是否有效(槽位上的ThreadLocal變量有沒有被垃圾回收),如果無效,則直接設置在該槽位,並執行一些清理操作。如果該槽位上是一個有效的值,那麼往後繼續尋找,直到找到空槽位爲止。流程大概如下:

ThreadLocalMap

清理無效的Entry

到這裏,我們應該帶着一個疑問:弱引用清除的只是Entry中的key,也就是ThreadLocal變量,而Entry本身依然佔據着table中的槽位。那代碼中是在哪裏清理這些無效的Entry的呢?我們重點看一下上面沒有分析的兩個方法replaceStaleEntrycleanSomeSlots

cleanSomeSlots

    // 顧名思義,清除部分槽位,默認掃描log(n)個槽位
    private boolean cleanSomeSlots(int i, int n) {
        boolean removed = false;
        Entry[] tab = table;
        int len = tab.length;
        do {
            i = nextIndex(i, len);
            Entry e = tab[i];
            // 注意無效Entry的判斷條件是,e.get() == null
            // 即Entry中保存的弱引用已經被GC,這種情況需要將對應Entry清除
            if (e != null && e.get() == null) {
                // 如果發現有無效entry,那n會重新設置爲table的長度
                // 即會繼續查找log(n)個槽位,判斷有沒有無效Entry
                n = len;
                removed = true;
                // 調用expungeStaleEntry方法清除i位置的槽位
                i = expungeStaleEntry(i);
            }
        // 循環條件爲n右移一位,即除以2。所以默認是循環log(n)次
        } while ( (n >>>= 1) != 0);
        // 如果有槽位被清除,返回true
        return removed;
    }

    private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        // 將i位置的槽位置爲空
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;

        Entry e;
        int i;
        // 繼續往後檢查是否有無效Entry,直到遇到空的槽位tab[i]==null爲止
        for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();
            // 如果Entry無效,將其清除
            if (k == null) {
                e.value = null;
                tab[i] = null;
                size--;
            } else {
                // 重新計算hash值h
                int h = k.threadLocalHashCode & (len - 1);
                // 如果新hash值h不等於當前位置的槽位值i,這種情況需要rehash
                // 給當前i位置的e,重新找更合理的槽位
                if (h != i) {
                    // 將i位置置空
                    tab[i] = null;
                    // 從h位置往後找第一個空槽位
                    while (tab[h] != null)
                        h = nextIndex(h, len);
                    // 將e放在第一個空槽位上
                    tab[h] = e;
                }
            }
        }
        // 返回接下來第一個空槽位的下標
        return i;
    }

cleanSomeSlots方法會掃描部分的槽位,查看是否有無效的Entry。如果沒有找到,那麼只掃描log(n)個槽位;如果有找到無效槽位,則會清除該槽位,並額外再掃描log(n)個槽位,以此類推。
清空槽位的工作是expungeStaleEntry方法做的,除了清除當前位置的Entry之外,它還會檢查往後連續的非空Entry,並清除其中無效值。同時還會判斷並處理rehash。這裏爲什麼要rehash?因爲前面有無效Entry被清除掉了,如果後面的Entry是因爲hash衝突而被延到後面的,就可以把後面的Entry移到前面空出來的位置上,從而提高查詢效率。

cleanSomeSlots舉例

CleanSomeSlots Example

上圖的情況,我們分兩種情況討論:

  • 如果從i=2開始找:

    1. tab[2]所在位置爲null,繼續循環i=nextIndex(i, len)=nextIndex(2, 8)=3
    2. tab[3]所在位置(k3,v3)有效,繼續循環i=nextIndex(i, len)=nextIndex(3, 4)=0
    3. tab[0]所在位置(k1,v1)有效,繼續循環i=nextIndex(i, len)=nextIndex(0, 2)=1
    4. tab[1]所在位置(k2,v2)有效,繼續循環i=nextIndex(i, len)=nextIndex(1, 1)=0
    5. tab[0]所在位置(k1,v1)有效,n==0結束
  • 如果從i=11開始找:

    1. tab[11]所在位置(null,v7)無效,調用expungeStaleEntry方法,expungeStaleEntry方法清空tab[11],並會往後循環判斷。因爲tab[12]位置(null,v8)無效,所以tab[12]也會被清空;tab[13]位置(k9,v9)有效,則會判斷是否需要給(k9,v9)重新放位置。如果對k9執行rehash之後依然是12,則不作處理;如果對k9執行rehash之後是11,說明該元素是因爲hash碰撞被放到了12的位置,那麼需要把元素放到tab[11]的位置。expungeStaleEntry方法返回第一個爲null的下標14,n重新設置爲16,i=nextIndex(i, len)=nextIndex(14, 16)=15
    2. tab[15]所在位置(k10,v10)有效,繼續循環i=nextIndex(i, len)=nextIndex(15, 8)=0
    3. tab[0]所在位置(k1,v1)有效,繼續循環i=nextIndex(i, len)=nextIndex(0, 2)=1
    4. tab[1]所在位置(k2,v2)有效,繼續循環i=nextIndex(i, len)=nextIndex(1, 1)=0
    5. tab[0]所在位置(k1,v1)有效,n==0結束

replaceStaleEntry方法

    private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        Entry e;
        // slotToExpunge記錄了包含staleSlot的連續段上,第一個無效Entry的下標
        int slotToExpunge = staleSlot;
        // 往前遍歷非空槽位,找到第一個無效Entry的下標,記錄爲slotToExpunge
        for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
            if (e.get() == null)
                slotToExpunge = i;
        // 往後遍歷非空段,查找key所在的位置,即檢查key之前是否之前已經被添加過
        // 爲什麼到tab[i]==null爲止?因爲空的槽之後的hash值肯定已經不一樣
        for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();
            if (k == key) {
                // 如果找到了key,那麼說明此key之前已經添加過,直接覆蓋舊值
                // 因爲staleSlot小於i,需要將兩個槽位的值進行交換,以提高查詢效率
                // 而被換到i處的無效Entry,會在之後的cleanSomeSlots被清除掉
                e.value = value;
                tab[i] = tab[staleSlot];
                tab[staleSlot] = e;
                // 如果slotToExpunge的值並沒有變,說明往前查找的過程中並未發現無效Entry
                // 那麼以當前位置作爲cleanSomeSlots的起點
                if (slotToExpunge == staleSlot)
                    slotToExpunge = i;
                // 這兩個方法都已經分析過,從slotToExpunge位置開始清理無效Entry
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                return;
            }

            // 如果前面往前查找沒有發現無效Entry,且此處的Entry無效(k==null)
            // 那麼將說明i處是第一個無效Entry,將slotToExpunge計爲i
            if (k == null && slotToExpunge == staleSlot)
                slotToExpunge = i;
        }
        // 如果key沒有找到,說明這是一個新Entry,那麼直接新建一個Entry放在staleSlot位置
        tab[staleSlot].value = null;
        tab[staleSlot] = new Entry(key, value);
        if (slotToExpunge != staleSlot)
            // 這兩個方法都已經分析過,從slotToExpunge位置開始清理無效Entry
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    }

這個方法其實就是三個步驟:

  1. 往後查找該key在table中是否存在。如果存在,即之前已經set過該key,那麼需要覆蓋掉舊值,並且將key所在元素移到staleSlot位置。(爲什麼要移位置?因爲原元素所在的位置i,肯定在staleSlot之後,所以將元素往前放到staleSlot上可以提高查詢效率,並避免後續的rehash操作。)
  2. 如果key不存在,說明是新set的操作,直接新建Entry,放在staleSlot位置。
  3. 調用cleanSomeSlots方法,清除無效的Entry

其他方法

剩下的方法都比較簡單,解析見源碼註釋,不另外解釋
get方法:

    // get操作的方法
    private Entry getEntry(ThreadLocal<?> key) {
        int i = key.threadLocalHashCode & (table.length - 1);
        Entry e = table[i];
        // i位置元素即爲要找的元素,直接返回
        if (e != null && e.get() == key)
            return e;
        else
            // 否則調用getEntryAfterMiss方法
            return getEntryAfterMiss(key, i, e);
    }

    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
        Entry[] tab = table;
        int len = tab.length;
        // 從i位置開始,往後遍歷查找,直到空槽位爲止。爲什麼到空槽位爲止?
        // 根據開地址法,空槽位之後的元素hash值肯定已經不一樣,沒必要再繼續
        while (e != null) {
            ThreadLocal<?> k = e.get();
            // key相等,這就是目標元素,直接返回
            if (k == key)
                return e;
            // key爲null,則是無效元素,調用expungeStaleEntry方法清除i位置的元素
            if (k == null)
                expungeStaleEntry(i);
            else
                // 繼續尋找下一個元素
                i = nextIndex(i, len);
            e = tab[i];
        }
        // 沒有找到目標元素,返回null
        return null;
    }

remove方法:

    private void remove(ThreadLocal<?> key) {
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
        // 還是一樣的遍歷邏輯
        for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
            // 找到目標元素
            if (e.get() == key) {
                e.clear();
                // 調用expungeStaleEntry方法清除i位置的元素
                expungeStaleEntry(i);
                return;
            }
        }
    }

resize方法

    // 當元素個數大於threshold(默認是table長度的2/3)時,需要resize
    private void resize() {
        Entry[] oldTab = table;
        int oldLen = oldTab.length;
        // 新table長度是舊table的2倍
        int newLen = oldLen * 2;
        Entry[] newTab = new Entry[newLen];
        int count = 0;
        // 遍歷舊table
        for (int j = 0; j < oldLen; ++j) {
            Entry e = oldTab[j];
            if (e != null) {
                ThreadLocal<?> k = e.get();
                // 如果key爲null,則這是個無效Entry,直接跳過(將值置爲空方便GC)
                if (k == null) {
                    e.value = null; // Help the GC
                } else {
                    // 根據新table的長度重新計算hash值
                    int h = k.threadLocalHashCode & (newLen - 1);
                    // 根據開地址法,從h開始找到第一個空槽位
                    while (newTab[h] != null)
                        h = nextIndex(h, newLen);
                    // 將該值放到該位置
                    newTab[h] = e;
                    count++;
                }
            }
        }
        // 設置新table的一些參數
        setThreshold(newLen);
        size = count;
        table = newTab;
    }

總結

本文從代碼層面,深入介紹了ThreadLocal的實現原理。
ThreadLocal可以保證線程安全,是因爲它給爲每個線程都創建了一個變量的副本。每個線程訪問的都是自己內部的變量,不會有併發衝突。
作爲線程內部變量,它跟局部變量有什麼區別呢?一般ThreadLocal都被定義爲static,也就是說,每個線程只需要創建一份,生命週期跟線程一樣。而局部變量生命週期跟方法與方法一樣,每調用一次方法,創建一次變量,方法結束,對象銷燬。ThreadLocal可以避免一些大對象的重複創建銷燬。

ThreadLocalMapEntry繼承自WeakReference,當沒有其他的強引用指向ThreadLocal變量時,該ThreadLocal變量會在下次GC中被回收。對於被回收掉的ThreadLocal變量,不會顯式地去清理,而是在接下來的getsetremove操作中去檢查刪除掉這些無效ThreadLocal變量所在的Entry,防止可能的內存泄漏。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章