深入理解Python的TLS機制和Threading.local()

1.背景介紹
我之前寫過一個關於Python的TLS機制的淺淺析,大家可以參考這個文章,首先,我們再來熟悉熟悉什麼是TLS機制。

1.1 Thread Local Storage(線程局部存儲)
這個概念最早是相對於全局變量來說的,就是我們在編程的時候,會涉及到希望所有線程都能夠共享訪問同一個變量,在 Python/Go/C 中,我們就可以定義一個全局變量,這樣Global Variable 對多個線程就是可見的,因爲同一個進程所有線程共享地址空間,大家都可以操作。例如,一個全局的配置變量或單實例對象,所有線程就可以很方便訪問了,但是僅僅這樣有一個前提,就是這個變量的併發操作必須是冪等的,讀寫不影響我們程序的正確性。但是往往多線程共同操作一個全局變量,就會影響程序的正確性,因此我們必須枷鎖,比如經典的併發加操作。

import threading
count = 0
lock = threading.RLock()
def inc():

global count, lock
with lock:
    count += 1

上面那個例子很多博客用來做ThreadLocal變量的講解,實際上我覺得是有誤導的,不恰當的。因爲這種共享變量,你必須枷鎖,因爲他的目的就是爲了大家一起去更新一個共享變量,多線程環境下必須枷鎖。就算你使用ThreadLocal替換也沒用,ThreadLocal能替換這個Count變量讓所有線程單獨存儲一份麼,不滿足需求。你單獨存一份,更改之後還得把結果再次寫回到全局變量去更新,那寫會的過程還是得枷鎖。除非使用Golang中的單Channel更新機制,才能避免枷鎖。

所以ThreadLocal變量使用強調的側重點不在這裏,更多的是在編程範式上面。其實就是有些時候,我們某個變量類型很多函數或者類都需要用,但是我又不想寫死在代碼裏,每次傳遞參數都要傳遞這個類或者變量,因爲一旦這個類發生類型上的變化,可能對於靜態類型的語言,很多地方就得修改參數,而且這種變量一直在程序代碼的參數傳遞中層層出現,你如果寫過代碼就會有感覺,有時候你設計的函數API好像一層層的得把一個參數傳遞進去,即使某些層好像用不到這個參數。

def getMysqlConn(passwd, db, host="localhost", port=3306, user="root", charset='utf8'):

conn = MySQLdb.connect(host=host, port=port, user=user, passwd=passwd, db=db, charset=charset)
return conn

def func1(zzz, passwd, db, host="localhost", port=3306, user="root", charset='utf8'):

conn = getMysqlConn(passwd, db, host, port, user, charset)
...

def func2(xxx,yyy,zzz, passwd, db, host="localhost", port=3306, user="root", charset='utf8'):

...
func1(zzz,passwd,db,host,port,user,charset)

上面的代碼你可能會瘋掉。那麼你可能就考慮想把這個參數提出來,當成全局變量算了,哪一層用到了直接用就好了,不能讓他無緣無故的不停的被當成局部變量傳參。文章Alternatives to global variables and passing the same value over a long chain of calls描述了這個問題,但是這個時候出現的問題就是,可能其他代碼線程會不可控的更改這個變量,導致你的程序發生未知錯誤。你把這種參數變成全局的暴露出來,那麼基於的假設就是該參數不會被隨意修改!一旦這個假設崩塌,你的程序可能會發生災難後果。這不符合軟件設計的開閉原則。所以我們使用TLS技術化解這種矛盾。

那麼我們就設計了一種方案,就是有這樣一種變量,他是全局的,但是每個線程在訪問的時候都會存儲一份成爲自己的局部變量,修改就不會相互影響了。比如 Linux/Unix的 C 程序庫 libc的全局變量errno, 這個其實就是TLS的例子。當系統調用從內核空間返回用戶空間時,如果系統調用出錯,那麼便設置errno的值爲一個負值,這樣就不需要每次在函數內部定義局部變量。但是當多線程的概念和技術被提出後,這套機制就不再適用了,可以使用局部變量,但是不太可能去更改已有的代碼了,比較好的解決方案是讓每個線程都有自己的errno。實際上,現在的C庫函數不是把出錯代碼寫入全局量errno,而是通過一個函數__errno_location()獲取一個地址,再把出錯代碼寫入該地址,其意圖就是讓不同的線程使用不同的出錯代碼存儲地點,而errno,現在一般已經變成了一個宏定義。每一個線程都會維護自己的一份,修改不影響其他線程。

這是不是意味着ThreadLocal對象不用枷鎖了? 其實這個ThreadLocal和同步沒有關係,他僅僅是提供了一種方便每個線程快速訪問變量的方式,但是如果這個對象本身有些共享狀態需要大家一起維護(比如Count++),你就必須枷鎖,儘管每個線程操作的是ThreadLocal副本。維基百科上有以下原話:

A second use case would be multiple threads accumulating information into a global variable. To avoid a race condition, every access to this global variable would have to be protected by a mutex. Alternatively, each thread might accumulate into a thread-local variable (that, by definition, cannot be read from or written to from other threads, implying that there can be no race conditions). Threads then only have to synchronise a final accumulation from their own thread-local variable into a single, truly global variable.

比如我們寫了一個共享的Manager類,這個類可能是用來做數據庫連接,網絡連接或者其他的做底層管理功能。我們有很多線程需要使用這個Manager的某些功能,並且這種類不是用來表示一種狀態,供所有線程併發修改其狀態並將最終修改的結果表現在該類上面(上面count的例子)。Manager只是可以提供給線程使用某些功能,然後每個線程可以把這個Manager複製一份成爲自己的局部變量,自己可以隨意修改,但是不會影響到其他線程,因爲是複製的一份。但是如果你需要讓管理器記錄所有的連接操作次數,那麼多線程對立面的某些變量訪問比如Count就需要枷鎖了。

2.TLS 在Python中的運用和實現
2.1 簡單使用
ThreadLocal不僅僅可以解決全局變量訪問衝突,其實還有其他好處,在PEP266中有提到,ThreadLocal變量是可以減少指令加速運算的,因爲全局變量往往需要更多的指令(需要for loop)來做查詢訪問,而ThreadLocal 之後,有了索引表,直接可以一條指令找到這個對象。

import threading

userName = threading.local()

def SessionThread(userName_in):

userName.val = userName_in
print(userName.val)   

Session1 = threading.Thread(target=SessionThread("User1"))
Session2 = threading.Thread(target=SessionThread("User2"))

start the session threads

Session1.start()
Session2.start()

wait till the session threads are complete

Session1.join()
Session2.join()
上述Threadlocal的實現原理類似有一個全局的詞典,詞典的key是線程id,value就是共享的全局變量的副本。每次訪問全局變量的時候,你訪問到的其實是副本,只是Python使用黑魔法幫我們屏蔽了這個userName.val 的訪問細節,其實他訪問的是詞典中的對應線程所擁有的對象副本。

2.2 實現源碼分析
all = ["local"]
class _localbase(object):

__slots__ = '_local__key', '_local__args', '_local__lock'

def __new__(cls, *args, **kw):
    # 新建一個類對象
    self = object.__new__(cls)
    # 在主線程中初始化這個這個全局對象的某些屬性,比如 `_local__key`, 這個key以後會用作其他線程使用全局變量副本的查詢依據,以後每個線程都會根據這個key來查找自己的局部副本數據
    key = '_local__key', 'thread.local.' + str(id(self))
    object.__setattr__(self, '_local__key', key)
    object.__setattr__(self, '_local__args', (args, kw))
    # 多線程會併發設置全局變量的屬性,這時候會併發訪問設置屬性,因此需要一把全局鎖,進行互斥操作
    object.__setattr__(self, '_local__lock', RLock())

    if (args or kw) and (cls.__init__ is object.__init__):
        raise TypeError("Initialization arguments are not supported")

    # We need to create the thread dict in anticipation of
    # __init__ being called, to make sure we don't call it
    # again ourselves.
    dict = object.__getattribute__(self, '__dict__')
    current_thread().__dict__[key] = dict

    return self

def _patch(self):

# 拿到全局的key
key = object.__getattribute__(self, '_local__key')
# 在當前線程中根據key找到線程的私有數據副本,並替換掉 ThreadLocal自己的__dict__屬性。如果沒有,就創建一個,並添加
d = current_thread().__dict__.get(key)
if d is None:
    d = {}
    # 線程還沒得私有數據副本,創建一個並加入線程自己的屬性中
    current_thread().__dict__[key] = d
    # 替換ThreadLocal的__dict__爲當前線程的私有數據詞典d
    object.__setattr__(self, '__dict__', d)

    # we have a new instance dict, so call out __init__ if we have
    # one
    # 這段的意思其實是,如果原來的全局變量ThreadLocal 本身有一些其他的屬性和數據,那麼直接替換掉一個新dict之後,以前的數據就丟失了,這裏我們必須初始化以前的數據到新dict中
    cls = type(self)
    if cls.__init__ is not object.__init__:
        args, kw = object.__getattribute__(self, '_local__args')
        cls.__init__(self, *args, **kw)
else:
    object.__setattr__(self, '__dict__', d)

class local(_localbase):

def __getattribute__(self, name):
    lock = object.__getattribute__(self, '_local__lock')
    lock.acquire()
    try:
        _patch(self)
        return object.__getattribute__(self, name)
    finally:
        lock.release()

def __setattr__(self, name, value):
    if name == '__dict__':
        raise AttributeError(
            "%r object attribute '__dict__' is read-only"
            % self.__class__.__name__)
    # 拿到早已經在主線程設置的共享的一把鎖
    lock = object.__getattribute__(self, '_local__lock')
    lock.acquire()
    try:
        _patch(self)# 關鍵代碼,這個patch會導致 Threadlocal 這個數據的__dict__直接被換成了所在線程自己的私有數據, Python 裏面有很多這種patch的替換手段,就是直接把基礎庫的某些功能和函數直接替換成了第三方庫的比如monkey patch
        # 再次設置屬性的時候,設置的__dict__ 其實不是 Threadlocal 自己的屬性了,是而是當前所在線程的__dict__的某一個key-value 副本數據的value,這個value是一個dict
        # object 的setattr默認行爲其實就是在自己的__dict__對象中添加一對key-pair,但是現在他的__dict__已經更換成所在線程的一個數據副本詞典了,黑魔法替換就在這裏
        return object.__setattr__(self, name, value)
    finally:
        lock.release()

def __delattr__(self, name):
    if name == '__dict__':
        raise AttributeError(
            "%r object attribute '__dict__' is read-only"
            % self.__class__.__name__)
    lock = object.__getattribute__(self, '_local__lock')
    lock.acquire()
    try:
        _patch(self)
        return object.__delattr__(self, name)
    finally:
        lock.release()

def __del__(self):
    import threading

    key = object.__getattribute__(self, '_local__key')

    try:
        # We use the non-locking API since we might already hold the lock
        # (__del__ can be called at any point by the cyclic GC).
        threads = threading._enumerate()
    except:
        # If enumerating the current threads fails, as it seems to do
        # during shutdown, we'll skip cleanup under the assumption
        # that there is nothing to clean up.
        return

    for thread in threads:
        try:
            __dict__ = thread.__dict__
        except AttributeError:
            # Thread is dying, rest in peace.
            continue

        if key in __dict__:
            try:
                del __dict__[key]
            except KeyError:
                pass # didn't have anything in this thread

from threading import current_thread, RLock

data = local()
print (data.__dict__)
def t(x):

global data
data.x = x
data.y = 1
print (current_thread().__dict__)
print (data.__dict__)

t1 = threading.Thread(target=t, args = (777,))
t2 = threading.Thread(target=t, args = (888,))
print current_thread().__dict__
t1.start()
t2.start()
t1.join()
t2.join()
print(data.__dict__)
關鍵技術就在patch上面,Python 裏面有很多這種patch的替換手段,就是直接把基礎庫的某些功能和函數直接替換成了第三方庫的比如monkey patch. 再次設置屬性的時候,設置的 dict 其實不是ThreadLocal自己的,是而是當前所在線程的__dict__ 的某一個key-value 副本數據,key 就是線程訪問的某個TLS變量生成的(一個線程可以有很多TLS變量,每個有不同的key),value是一個dict. object的 setattr默認行爲其實就是在自己的__dict__對象中添加一對key-pair,但是現在他的__dict__已經更換成所在線程的一個數據副本詞典dict了,黑魔法替換就在這裏.

下面的例子展示了Python黑魔法的一個替換詞典的方式,可以運行看看

class A:

def substitute(self, d):
    object.__setattr__(self, '__dict__', d)

a = A()
a.y = 3
old_dict = a.__dict__
print(old_dict)
d = {'x':1}
a.substitute(d)
print(a.__dict__)
a.y = 777
print(a.__dict__)
print(d)

OUTPUT

{'y': 3}
{'x': 1}
{'x': 1, 'y': 777}
{'x': 1, 'y': 777}
如果A本身已經含有一些數據,那就不能簡單的直接複製了,還需要初始化以前的數據填充新的詞典,這也是在源碼中看到的。

from threading import current_thread
class A:

def __new__(cls, *args, **kw):
    self = object.__new__(cls)
    setattr(cls, '_local__args', (args, kw))
    return self

def __init__(self, *args, **kw):
    self.shared_x = kw["shared_x"]
    self.shared_y = kw["shared_y"]
def substitute(self, d):
    object.__setattr__(self, '__dict__', d)
    cls = type(self)
    if cls.__init__ is not object.__init__:
        print("7---------------")
        args, kw = getattr(self, '_local__args')
        cls.__init__(self, *args, **kw)

a = A(shared_x=111, shared_y=222)
a.y = 3
old_dict = a.__dict__
print(old_dict)
d = {'x':1}
a.substitute(d)
print(a.__dict__)
a.y = 777
print(a.__dict__)
print(d)
print(old_dict)
下圖就是訪問每個線程訪問過程,實際上操作的是線程自己的私有數據副本。同時需要注意的還是那句話,使用 ThreadLocal對象不意味着你的程序不需要再枷鎖,比如這個 ThreadLocal 對象可能又引用了其他共享狀態的對象,那麼就要對這個共享狀態對象的操作進行枷鎖實現同步和互斥。

ThreadLocal 實現過程
3 TLS 在Java 中的運用和實現
3.1 簡單使用
public class ThreadLocalExample {

public static class MyRunnable implements Runnable {

    private ThreadLocal threadLocal = new ThreadLocal();

    @Override
    public void run() {
        threadLocal.set((int) (Math.random() * 100D));
        try {
        Thread.sleep(2000);
        } catch (InterruptedException e) {

        }
        System.out.println(threadLocal.get());
    }
}

public static void main(String[] args) {
     MyRunnable sharedRunnableInstance = new MyRunnable();
     Thread thread1 = new Thread(sharedRunnableInstance);
     Thread thread2 = new Thread(sharedRunnableInstance);
     thread1.start();
     thread2.start();
}

}
3.2 源碼實現
有了Python版本的分析,Java版本就不再多做解釋,感興趣的可以看看源碼,實現原理肯定都是大同小異,只是語言上的差異,導致 Java 不可能像Python這種動態類型語言一樣靈活。

需要每個線程都維護一個 key-value 集合數據結構,記錄每個線程訪問到的 TLS 變量副本,這樣每個線程可以根據 key 來找到相應的 TLS副本數據,對副本數據進行真實的操作,而不是TLS全局變量或者靜態類(Java中)。在Python中直接很簡單的使用了動態數據綁定的詞典數據結構,在Java中稍顯麻煩,需要實現一個類似Map的結構,ThreadLocal.get() 方法其實本質上也是和Python中一樣,先獲取當前線程自己的ThreadLocalMap對象(就是每個線程維護的TLS key-value集合啦)。再從ThreadLocalMap對象中找出當前的ThreadLocal變量副本,和HashMap一樣的採用了鏈地址法的hash結構。可以參考文章Java 多線程(7): ThreadLocal 的應用及原理。Java 裏一般是採用泛型規定你共享的變量類型,然後每個線程維護該變量的副本。

  1. 小結
    TLS技術的使用和屬性:

解決多線程編程中的對同一變量的訪問衝突的一種技術,TLS會爲每一個線程維護一個和該線程綁定的變量的副本。而不是無止盡的傳遞局部參數的方式編程。
每一個線程都擁有自己的變量副本,並不意味着就一定不會對TLS變量中某些操作枷鎖了。
Java平臺的java.lang.ThreadLocal和Python 中的threading.local()都是TLS技術的一種實現,。
TLS使用的缺陷是,如果你的線程都不退出,那麼副本數據可能一直不被GC回收,會消耗很多資源,比如線程池中,線程都不退出,使用TLS需要非常小心。
TLS技術的實現原理:
需要每個線程都維護一個 key-value集合數據結構,記錄每個線程訪問到的 TLS變量副本,這樣每個線程可以根據 key來找到相應的 TLS副本數據,對副本數據進行真實的操作,而不是TLS全局變量或者靜態類(Java中).

TLS變量自己會根據當前調用他的Thread對象,根據Thread對象得到該線程維護的 TLS 副本集合,然後進一步根據當前TLS的key,查到到key對一個的TLS副本數據。這樣就給每個線程造成一種假象,以爲大家可以同時更新一個全局共享變量或者靜態類對象。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章