ThreadLocal用法與實現原理

ThreadLocal用法與實現原理

1. 對ThreadLocal的理解

Java API中對ThreadLocal的描述是:該類提供了線程局部 (thread-local) 變量。這些變量不同於它們的普通對應物,因爲訪問某個變量(通過其 get 或 set 方法)的每個線程都有自己的局部變量,它獨立於變量的初始化副本。ThreadLocal 實例通常是類中的 private static 字段,它們希望將狀態與某一個線程(例如,用戶 ID 或事務 ID)相關聯。

我們先來看一個例子:

package com.threadlocal.demo;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

/**
 * 數據庫操作輔助類
 * 
 * @author 小明
 *
 */
public class DbSession {

    private static String DRIVER = "com.mysql.jdbc.Driver"; // 驅動字符串
    private static String URL = "jdbc:mysql:///test"; // 連接字符串
    private static String USER = "root"; // 用戶名
    private static String PASSWORD = "123456"; // 密碼
    private static Connection connection; // 連接對象

    /**
     * 加載驅動
     */
    static {
        try {
            Class.forName(DRIVER);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
    }

    /**
     * 打開數據庫連接
     * 
     * @return 連接對象
     */
    public static Connection openConnection() {
        if (connection == null) {
            try {
                connection = DriverManager.getConnection(URL, USER, PASSWORD);
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        return connection;
    }

    /**
     * 關閉連接資源
     */
    public static void closeConnection() {
        if (connection != null) {
            try {
                connection.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
}

這是一個數據庫連接管理類,在單線程中使用是沒有任何問題的,但是在多線程中使用就會出現線程安全問題:這裏打開和關閉連接的2個方法都沒有進行同步,很可能在openConnection()方法中會多次創建connection對象。由於connection是共享變量,那麼就有必要在調用connection的地方使用同步來保障線程安全,因爲很可能一個線程在使用connection進行數據庫操作(如讀數據),而另外一個線程調用closeConnection()關閉鏈接。這時,我們可以使得線程同步來解決這個問題。

那麼我們要實現線程同步,使用同步方法或是同步塊的作法是否可行呢?我們想像一下,在線程同步時,一個線程在使用connection進行數據庫操作的時候,其他線程只有等待,這就將大大影響程序執行效率。

我們實現線程同步,其主要目的是保障共享資源的線程安全,那麼這兒的connection連接資源是否真的需要共享呢?事實上,這是不需要的。假如每個線程中都有一個connection變量,各個線程之間對connection變量的訪問實際上是沒有依賴關係的,是相互獨立的,即一個線程不需要關心其他線程是否對這個connection進行了修改。

既然是這樣,那我們就不使用靜態成員吧,將所有變量與方法的static都去掉。這樣做又會不會有什麼問題呢?

這時,如果我們要使用連接對象來操作數據,就得先創建DbSession對象,然後通過對象來調用相應的建立連接、關閉連接的方法。這就又出現了一個問題:服務器壓力增大,並且嚴重影響程序執行性能。由於在方法中需要頻繁的打開和關閉數據庫連接,這樣不僅嚴重影響程序執行效率,還可能導致服務器壓力巨大,這就好比我們要過一條河,每次過河前先收集原材料在河上建一座橋,過河之後又將橋拆掉,那麼在建橋和拆橋的時候,所消耗的資源是非常多的。

那我們到底使用什麼方法能夠使得這個問題比較完美的解決呢,這就是我們的ThreadLocal類。

2. 深入理解ThreadLocal

從Java API來看,ThreadLocal主要表達了下面幾種觀點:

  1. ThreadLocal不是線程,是線程的一個局部變量,可以先簡單理解爲線程類的屬性。
  2. 每個線程有自己的一個ThreadLocal,它是變量的一個副本(也稱拷貝),所以修改它不影響其他線程。
  3. ThreadLocal在類中通常定義爲靜態類變量。

ThreadLocal在每個線程中對該變量會創建一個副本,即每個線程內部都會有一個該變量,且在線程內部任何地方都可以使用,線程之間互不影響,這樣一來就不存在線程安全問題,也不會嚴重影響程序執行性能。

需要注意的是,雖然ThreadLocal能夠解決上面第1節所說的問題,但是由於在每個線程中都創建了副本,所以要考慮它對資源的消耗,比如內存的佔用會比不使用ThreadLocal要大,這是典型的“以空間換時間”的設計方式,而我們以前用到的synchronized同步是“以時間換空間”的設計方式。

2.1 常用方法介紹

ThreadLocal方法介紹:

返回類型 方法 說明
T get() 返回此線程局部變量的當前線程副本中的值。
protected T initialValue() 返回此線程局部變量的當前線程的“初始值”,一般是用來在使用時進行重寫的,它是一個延遲加載方法。
void remove() 移除此線程局部變量當前線程的值。
void set(T value) 將此線程局部變量的當前線程副本中的值設置爲指定值。

改進後的數據庫操作輔助類:

package com.threadlocal.demo;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

/**
 * 數據庫操作輔助類
 * 
 * @author 小明
 *
 */
public class DbSession {

    private static String DRIVER = "com.mysql.jdbc.Driver"; // 驅動字符串
    private static String URL = "jdbc:mysql:///test"; // 連接字符串
    private static String USER = "root"; // 用戶名
    private static String PASSWORD = "123456"; // 密碼
    private static ThreadLocal<Connection> threadLocal = new ThreadLocal<Connection>(); // ThreadLocal對象

    /**
     * 加載驅動
     */
    static {
        try {
            Class.forName(DRIVER);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
    }

    /**
     * 打開數據庫連接
     * 
     * @return 連接對象
     */
    public static Connection openConnection() {
        Connection connection = threadLocal.get(); // 獲取ThreadLocal中保存的連接對象
        if (connection == null) {
            try {
                connection = DriverManager.getConnection(URL, USER, PASSWORD); // 創建連接對象
                threadLocal.set(connection); // 將連接對象保存到ThreadLocal對象中
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        return connection;
    }

    /**
     * 關閉連接資源
     */
    public static void closeConnection() {
        Connection connection = threadLocal.get(); // 獲取ThreadLocal中保存的連接對象
        if (connection != null) { // 不爲空則釋放資源
            try {
                threadLocal.set(null); // 將ThreadLocal中的連接對象置空
                connection.close(); // 關閉連接對象
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
}

2.2 深入源碼(JDK1.6)

我們先看一下get()方法:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null)
            return (T)e.value;
    }
    return setInitialValue();
}

先取得當前線程,然後通過getMap(t)方法獲取到一個map,map的類型爲ThreadLocalMap。map不爲空,則獲取到key-value鍵值對,如果獲取成功,則返回value值。如果map爲空,則調用setInitialValue()方法返回value。

再進一層,我們看看getMap()作了什麼:

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

在getMap中,返回當前線程t中的一個成員變量threadLocals,threadLocals又是什麼?

ThreadLocal.ThreadLocalMap threadLocals = null;

它實際上是一個ThreadLocalMap,是ThreadLocal類的一個靜態內部類,我們繼續取部分ThreadLocalMap的實現:

static class ThreadLocalMap {

    private Entry[] table;

    static class Entry extends WeakReference<ThreadLocal> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal k, Object v) {
            super(k);
            value = v;
        }
    }

    private Entry getEntry(ThreadLocal key) {
        int i = key.threadLocalHashCode & (table.length - 1);
        Entry e = table[i];
        if (e != null && e.get() == key)
            return e;
        else
            return getEntryAfterMiss(key, i, e);
    }

    private void set(ThreadLocal key, Object value) {

        // We don't use a fast path as with get() because it is at
        // least as common to use set() to create new entries as
        // it is to replace existing ones, in which case, a fast
        // path would fail more often than not.

        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);

        for (Entry e = tab[i];
        e != null;
        e = tab[i = nextIndex(i, len)]) {
            ThreadLocal k = e.get();

            if (k == key) {
                e.value = value;
                return;
            }

            if (k == null) {
                replaceStaleEntry(key, value, i);
                return;
            }
        }

        tab[i] = new Entry(key, value);
        int sz = ++size;
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
    }
}

我們看到ThreadLocalMap的Entry繼承了WeakReference,使用ThreadLocal作爲鍵。通過getEntry()方法獲取到key-value鍵值對。set()方法將key-value鍵值對映射保存到table數組中,key存在,則替換value,key不存在,則保存新的映射。

再繼續看setInitialValue()方法:

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

initialValue()方法中返回null值,接下來仍然會去獲取ThreadLocalMap對象,不爲空,則設置鍵值對,若爲空,再創建ThreadLocalMap對象:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

接着再來看一下ThreadLocal的set()方法:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

set()方法相對比較簡單:獲取當前線程的引用,獲取該線程對應的map,如果map存在更新緩存值,否則創建並存儲。

至此,可能部分讀者已經明白了ThreadLocal是如何爲每個線程創建變量的副本的:

在每個線程Thread內部有一個ThreadLocal.ThreadLocalMap類型的成員變量threadLocals,這個threadLocals就是用來存儲實際的變量副本的,鍵(key)爲當前ThreadLocal變量,值(value)爲變量副本(即T類型的變量)。

初始時,在Thread裏面,threadLocals爲空,當通過ThreadLocal調用get()方法或者set()方法,就會對Thread類中的threadLocals進行初始化,並且以當前ThreadLocal變量爲key,以ThreadLocal要保存的副本變量爲value,存到threadLocals中。

然後在當前線程裏面,如果要使用副本變量,就可以通過get()方法在threadLocals裏面查找。

最常見的ThreadLocal使用場景是用來解決數據庫連接、Session管理等問題。

2.3 示例

下面我們來看一個示例,重現一個關於日期解析的問題,在重現這個問題之前還是要先來看一下SimpleDateFormat類中的parse()方法:

public Date parse(String text, ParsePosition pos) {
    // …… // 處理

    calendar.clear(); // 清空所有時間字段值

    // ……  // 處理

    parsedDate = calendar.getTime(); // 獲取時間

    // …… // 處理
}

這兒我只抽取出來這兩條語句,其中calendar是一個Calendar對象引用,它用來儲存和這個SimpleDateFormat相關的日期信息。如果SimpleDateFormat是static的,那麼多個Thread之間就會共享這個SimpleDateFormat,同時也就共享這個Calendar的引用。

問題重現:

package com.threadlocal.demo;

import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;

/**
 * 日期解析示例
 * 
 * @author 小明
 *
 */
public class DateParseDemo {
    public static void main(String[] args) {
        // 啓動線程1,解析"2015-10-1 00:00:00",休眠2秒鐘
        new DateParseThread("線程1", "2015-10-1 00:00:00", true, 2000).start();
        // 啓動線程2,解析"2012-3-8 15:37:22",不休眠
        new DateParseThread("線程2", "2012-3-8 15:37:22", false, 0).start();
    }
}

/**
 * 日期解析線程類
 * 
 * @author 小明
 *
 */
class DateParseThread extends Thread {
    private String name; // 線程名稱
    private String dateString; // 要解析的日期字符串
    private boolean isSleep; // 是否休眠
    private long sleepTime; // 休眠時長

    public DateParseThread(String name, String dateString, boolean isSleep,
            long sleepTime) {
        super();
        this.name = name;
        this.dateString = dateString;
        this.isSleep = isSleep;
        this.sleepTime = sleepTime;
    }

    @Override
    public void run() {
        if (isSleep) {
            try {
                Thread.sleep(sleepTime);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }

        Date date = DateParseUtil.parse(dateString);

        System.out.println("線程:" + this.name + ",日期時間爲:" + date);
    }
}

/**
 * 日期解析輔助類
 * 
 * @author 小明
 *
 */
class DateParseUtil {
    private static final String PATTERN = "yyyy-MM-dd HH:mm:ss"; // 格式
    private static SimpleDateFormat sdf; // SimpleDateFormat對象
    static {
        sdf = new SimpleDateFormat(PATTERN); // 創建基於給定模式的SimpleDateFormat對象
    }

    /**
     * 日期解析
     * 
     * @param dateString
     *            待解析字符串
     * @return 日期
     */
    public static Date parse(String dateString) {
        try {
            return sdf.parse(dateString);
        } catch (ParseException e) {
            e.printStackTrace();
        }

        return null;
    }
}

我們在執行前先在 calendar.clear()calendar.getTime() 之間打個斷點,然後使用Debug模式來執行這段代碼。

線程1啓動之後進入sleep(休眠)狀態,線程2啓動起來後卡在斷點處(暫停執行),這時calendar的日期時間是:”2012-3-8 15:37:22”。當線程1從休眠中醒來後接着向下執行,當執行到斷點處時,因爲calendar是被共享的資源,所以它的日期時間又變爲:”2015-10-1 00:00:00”。最後讓兩個線程斷點繼續執行,結果可想而知了:

線程:線程1,日期時間爲:Thu Oct 01 00:00:00 CST 2015
線程:線程2,日期時間爲:Thu Oct 01 00:00:00 CST 2015

在實際業務中,我們不會是在Debug模式下運行,但如果線程1調用了sdf.parse(),並且進行了calendar.clear()後還未執行calendar.getTime()的時候,線程2又調用了sdf.parse(),這時候線程2也執行了sdf.clear()方法,這樣就導致線程1的calendar數據被清空了(實際上線程1,2同時被清空了);又或者當線程1執行了calendar.clear()後被掛起,這時候線程2開始調用sdf.parse()並順利結束,這樣線程1的calendar內存儲的日期時間就變成了後來線程2設置的calendar的日期時間值。

那麼我們怎麼解決這種問題呢,最簡單地就是將靜態的SimpleDateFormat改爲實例SimpleDateFormat,這樣每個線程都會有一個自己的SimpleDateFormat實例。但使用這種方法,在高併發的情況下會大量的創建SimpleDateFormat對象以及銷燬SimpleDateFormat對象,這樣是非常耗費資源的。

我們就可以使用ThreadLocal來優化,將DateParseUtil代碼修改如下:

/**
 * 日期解析輔助類
 * 
 * @author 小明
 *
 */
class DateParseUtil {
    private static final String PATTERN = "yyyy-MM-dd HH:mm:ss"; // 格式

    private static ThreadLocal<SimpleDateFormat> threadLocal = new ThreadLocal<SimpleDateFormat>();

    /**
     * 獲取SimpleDateFormat對象
     * 
     * @return SimpleDateFormat對象
     */
    public static SimpleDateFormat getSimpleDateFormat() {
        SimpleDateFormat simpleDateFormat = threadLocal.get();
        if (simpleDateFormat == null) {
            simpleDateFormat = new SimpleDateFormat(PATTERN);
            threadLocal.set(simpleDateFormat);
        }
        return simpleDateFormat;
    }

    /**
     * 日期解析
     * 
     * @param dateString
     *            待解析字符串
     * @return 日期
     */
    public static Date parse(String dateString) {
        try {
            return getSimpleDateFormat().parse(dateString);
        } catch (ParseException e) {
            e.printStackTrace();
        }

        return null;
    }
}

重新在Debug模式下運行,結果爲:

線程:線程2,日期時間爲:Thu Mar 08 15:37:22 CST 2012
線程:線程1,日期時間爲:Thu Oct 01 00:00:00 CST 2015

這樣,在兩個線程中解析的時間就正確了。

當然,DateParseUtil類也可以修改爲:

/**
 * 日期解析輔助類
 * 
 * @author 小明
 *
 */
class DateParseUtil {
    private static final String PATTERN = "yyyy-MM-dd HH:mm:ss"; // 格式

    private static ThreadLocal<SimpleDateFormat> threadLocal = new ThreadLocal<SimpleDateFormat>() {
        protected SimpleDateFormat initialValue() {
            return new SimpleDateFormat(PATTERN);
        };
    };

    /**
     * 獲取SimpleDateFormat對象
     * 
     * @return SimpleDateFormat對象
     */
    public static SimpleDateFormat getSimpleDateFormat() {
        return threadLocal.get();
    }

    /**
     * 日期解析
     * 
     * @param dateString
     *            待解析字符串
     * @return 日期
     */
    public static Date parse(String dateString) {
        try {
            return getSimpleDateFormat().parse(dateString);
        } catch (ParseException e) {
            e.printStackTrace();
        }

        return null;
    }
}

這裏創建一個ThreadLocal類變量,創建時用了一個匿名類,覆蓋了initialValue()方法,主要作用是創建時初始化實例。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章