ThreadLocal、InheritableThreadLocal、ThreadLocalRandom

ThreadLocal

ThreadLocalJDk包提供的,它提供了線程的本地變量,也就是如果你創建了一個ThreadLocal變量,那麼訪問這個變量的每個線程都會有這個變量的一個本地副本。

當多個線程操作這個變量時,實際操作的是自己本地內存裏面的變量,從而避免了線程安全問題。

使用示例

public class ThreadLocalTest {
   private static ThreadLocal<String> threadLocal = new ThreadLocal<>();
    
   private static void print(String str) {
      System.out.println(str + ":" + threadLocal.get());
      threadLocal.remove();
   }
    
   public static void main(String[] args) {
      new Thread(() -> {
         threadLocal.set("test t1 ThreadLocal variable");
         print("t1");
         System.out.println("t1 remove after:" + threadLocal.get());
      }).start();

      new Thread(() -> {
         threadLocal.set("test t2 ThreadLocal variable");
         print("t2");
         System.out.println("t2 remove after:" + threadLocal.get());
      }).start();
   }
}

/*
t1:test t1 ThreadLocal variable
t2:test t2 ThreadLocal variable
t1 remove after:null
t2 remove after:null
*/

源碼

Thread 類中有兩個變量

/* ThreadLocal values pertaining to this thread. This map is maintained
 * by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;

/*
 * InheritableThreadLocal values pertaining to this thread. This map is
 * maintained by the InheritableThreadLocal class.
 */
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

每個線程的本地變量不是存放在ThreadLocal實例裏面,而是存放在在調用線程的threadLocals變量裏面。

ThreadLocal就是一個工具殼,通過set方法把value值放入線程的threadLocals變量裏面並存放起來,當調用線程的get方法時,再從當前線程的threadLocals變量裏面取出。

如果線程一直不終止,那麼這個本地變量會一直存放在調用線程的threadLocals變量裏面,所以不需要本地變量的時可以通過remove方法將其刪除。

set

public void set(T value) {
	Thread t = Thread.currentThread();
    // 獲取當前線程的 threadLocals 變量
	ThreadLocalMap map = getMap(t);
	if (map != null)
		map.set(this, value);
	else
    // 第一次調用時就創建 ThreadLocalMap
		createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

get

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // threadLocals 變量爲空,則初始化當前線程的 threadLocals 變量
    return setInitialValue();
}

private T setInitialValue() {
    // 初始化 value 爲 null
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

protected T initialValue() {
    return null;
}

remove

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

InheritableThreadLocal

同一個ThreadLocal變量在父線程中被設置後,在子線程中是獲取不到的。

InheritableThreadLocal可以解決這個問題。

使用示例

public class InheritableThreadLocalTest {

   private static ThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

   public static void main(String[] args) {
      inheritableThreadLocal.set("hello world");
      new Thread(() -> {
         // 輸入 inheritableThreadLocal 中的值
         System.out.println("thread: " + inheritableThreadLocal.get());
      }).start();
      System.out.println("main: " + inheritableThreadLocal.get());
      // 習慣性回收不用的變量
      inheritableThreadLocal.remove();
   }
}
/*
main: hello world
thread: hello world
*/

源碼

InheritableThreadLocal重寫了三個方法childValuegetMapcreateMap,使用了Thread的變量inheritableThreadLocals.

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

childValueThread的構建方法調用時執行。

public Thread(Runnable target) {
    init(null, target, "Thread-" + nextThreadNum(), 0);
}

private void init(ThreadGroup g, Runnable target, String name, long stackSize) {
    init(g, target, name, stackSize, null);
}

private void init(ThreadGroup g, Runnable target, String name, long stackSize, AccessControlContext acc) {
    // 獲取父線程
    Thread parent = currentThread();
    // ......
    if (parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    /* Stash the specified stack size in case the VM cares */
    this.stackSize = stackSize;

    /* Set thread ID */
    tid = nextThreadID();
}

static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}

private ThreadLocalMap(ThreadLocalMap parentMap) {
    // 獲取父線程 ThreadLocalMap 的所有 Entry
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    // Set the resize threshold to maintain at worst a 2/3 load factor.
    setThreshold(len);
    // 初始化當前線程的 table 值
    table = new Entry[len];
	// 遍歷父線程的 table
    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            @SuppressWarnings("unchecked")
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            if (key != null) {
               	// 調用 InheritableThreadLocal 重寫的方法 childValue
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                // 將新生成的 Entry 方到當前線程的 table 中
                table[h] = c;
                size++;
            }
        }
    }
}

Random

public int nextInt(int bound) {
    if (bound <= 0)
        throw new IllegalArgumentException(BadBound);
	// 1.根據老的種子生成新的種子
    int r = next(31);
    // 2.根據新的種子計算隨機數
    int m = bound - 1;
    if ((bound & m) == 0)  // i.e., bound is a power of 2
        r = (int)((bound * (long)r) >> 31);
    else {
        for (int u = r; u - (r = u % bound) + m < 0; u = next(31))
            ;
    }
    return r;
}

// 採用CAS的方式生成新的種子,多線程下進行CAS只會有一個線程會成功,所以會造成大量線程進行自旋重試,這會降低併發性能
protected int next(int bits) {
    long oldseed, nextseed;
    AtomicLong seed = this.seed;
    do {
        oldseed = seed.get();
        nextseed = (oldseed * multiplier + addend) & mask;
    } while (!seed.compareAndSet(oldseed, nextseed));
    return (int)(nextseed >>> (48 - bits));
}

ThreadLocalRandom

爲了彌補Random的缺陷。新增了ThreadLocalRandom,在JUC包下。

ThreadLocalRandom繼承了Random,並重寫了nextInt等方法,沒有使用父類的原子性種子變量。

ThreadLocalRandom中並沒有存放具體的種子,而是存放在Thread中的threadLocalRandomSeed變量裏面。

當線程調用ThreadLocalRandomcurrent方法時,會初始化此種子變量。

主要代碼實現邏輯

// Unsafe mechanics
private static final sun.misc.Unsafe UNSAFE;
private static final long SEED;
private static final long PROBE;
private static final long SECONDARY;
static {
    try {
        UNSAFE = sun.misc.Unsafe.getUnsafe();
        Class<?> tk = Thread.class;
        SEED = UNSAFE.objectFieldOffset
            (tk.getDeclaredField("threadLocalRandomSeed"));
        PROBE = UNSAFE.objectFieldOffset
            (tk.getDeclaredField("threadLocalRandomProbe"));
        SECONDARY = UNSAFE.objectFieldOffset
            (tk.getDeclaredField("threadLocalRandomSecondarySeed"));
    } catch (Exception e) {
        throw new Error(e);
    }
}

current

public static ThreadLocalRandom current() {
    // 判斷 threadLocalRandomProbe 是否爲0
    if (UNSAFE.getInt(Thread.currentThread(), PROBE) == 0)
        // 計算當前線程的初始化種子變量
        localInit();
    return instance;
}

static final void localInit() {
    int p = probeGenerator.addAndGet(PROBE_INCREMENT);
    int probe = (p == 0) ? 1 : p; // skip 0
    long seed = mix64(seeder.getAndAdd(SEEDER_INCREMENT));
    Thread t = Thread.currentThread();
    UNSAFE.putLong(t, SEED, seed);
    UNSAFE.putInt(t, PROBE, probe);
}

nextInt

public int nextInt(int bound) {
    if (bound <= 0)
        throw new IllegalArgumentException(BadBound);
    // 根據當前線程種的種子計算新種子
    int r = mix32(nextSeed());
    // 根據新種子和bound計算隨機數
    int m = bound - 1;
    if ((bound & m) == 0) // power of two
        r &= m;
    else { // reject over-represented candidates
        for (int u = r >>> 1; u + m - (r = u % bound) < 0;  u = mix32(nextSeed()) >>> 1)
            ;
    }
    return r;
}

nextSeed

final long nextSeed() {
    Thread t; long r; // read and update per-thread seed
    UNSAFE.putLong(t = Thread.currentThread(), SEED,
                   r = UNSAFE.getLong(t, SEED) + GAMMA);
    return r;
}

首先使用r = UNSAFE.getLong(t, SEED)獲取當前線程中threadLocalRandomSeed變量的值,然後在種子的基礎上加GAMMA值作爲新種子。

然後使用UNSAFE.putLong把新種子放入當前線程的threadLocalRandomSeed變量中。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章