ThreadPoolExecutor源碼分析

ThreadPoolExecutor



前言

Python中,ThreadPoolExecutor對Thread做了進一步封裝。在Thread基礎之上,使得多線程開發更簡單了。另一方面,由於還存在ProcessPoolExecutor類,多線程與多進程的開發接口得到了統一。

在整個過程中,需要理清ThreadPoolExecutor的成員方法Future的成員方法

ThreadPoolExecutor

一個簡單的多線程

ThreadPoolExecutor在concurrent.futures模塊下,一個簡單的多線程代碼如下:

import time
from concurrent.futures import ThreadPoolExecutor

def print_hello():
    for i in range(10):
        time.sleep(1)
        print("h{} hello".format(i))

def print_world():
    for i in range(10):
        time.sleep(1)
        print("w{} world".format(i))

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=2)
    task1th = executor.submit(print_hello)
    task2ed = executor.submit(print_world)

# 輸出:
h0 hello
w0 world
h1 hello
w1 world
h2 hello
...

可以見到,打印結果中的“hello”和“world”是交叉出現,這符合多線程行爲。

submit

在上面的demo中,ThreadPoolExecutor(max_workers=2)表示創建一個線程池,而它的管理人就是這裏的實例對象executor,executor有一個submit()方法,用來提交子線程需要執行的任務——在這裏分別是函數print_hello()和函數print_world(),每個任務對應一個線程。跟**threading.Thread()**不同,你不需要用什麼命令讓它“動”起來(threading.Thread()中需要start()),當你submit之後,子線程就去執行了。

下面是submit()方法的源碼:

# submit源碼
class ThreadPoolExecutor(_base.Executor):
	...
    def submit(self, fn, *args, **kwargs):
	    with self._shutdown_lock:
		    ...
	        f = _base.Future()
	        ...
	        self._adjust_thread_count()  # 在submit中執行了_adjust_thread_count()
	        return f  # 返回Future的對象
	                
    def _adjust_thread_count(self):
	    num_threads = len(self._threads)
        if num_threads < self._max_workers:
            thread_name = '%s_%d' % (self._thread_name_prefix or self,
                                     num_threads)
            # 創建一個線程對象t
            t = threading.Thread(name=thread_name, target=_worker,
                                 args=(weakref.ref(self, weakref_cb),
                                       self._work_queue,
                                       self._initializer,
                                       self._initargs))
            t.daemon = True
            t.start()  # 啓動線程
            ...

shutdown

注意,在_adjust_thread_count()中並沒有執行join()方法,也就是說,子線程執行的同時,主線程也會向下執行。ThreadPoolExecutor的成員函數中有一個shutdown(),當參數wait=True時可以用來阻塞主線程,其本質是調用了每個子線程的join()

# shutdown源碼
class ThreadPoolExecutor(_base.Executor):
	...
    def shutdown(self, wait=True):
	    ...
        if wait:
            for t in self._threads:  #  self._threads是用來存放子線程的集合
                t.join()  # 調用每個子線程的join方法

在前面代碼的基礎上做些小修改,shutdown()的作用立竿見影:

...
if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=2)
    task1th = executor.submit(print_hello)
    task2ed = executor.submit(print_world)
	# 使用shutdown方法
    executor.shutdown()
    print("zty")

# 輸出:
w0 world
h0 hello
...
w9 world
h9 hello
zty  # zty最後輸出
...
if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=2)
    task1th = executor.submit(print_hello)
    task2ed = executor.submit(print_world)
    # 不使用shutdown方法
    print("zty")

# 輸出:
zty  # zty最先輸出
w0 world
h0 hello
...
w9 world
h9 hello

map

提交線程任務除了submit()外,還提供了map()方法。此map與Python內置的map在使用上相似,它可以批量啓動相同函數的不同線程。

def print_num(num):
    print(num)

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=3)
    executor.map(print_num, [1, 2, 3])

# 輸出:
1
2
3

上述代碼在運行過程中啓動了三個子線程,三個子線程又各自只打印了一個數字。事實上,在Python內部,map()的實現也是基於submit()。

# map源碼
class Executor(object):
    def map(self, fn, *iterables, timeout=None, chunksize=1):
	    ...
        fs = [self.submit(fn, *args) for args in zip(*iterables)]
        ...

map()函數接收參數的設計我認爲比較巧妙,在zip()的幫忙下,你可以傳遞多個可迭代對象。

def print_num(num, alpha):
    print(num, alpha)

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=3)
    executor.map(print_num, [1, 2, 3], ["a", "b", "c"])

# 輸出:
1 a  # 注意成對輸出
2 b
3 c

最後,map會返回一個生成器,裏面放着線程的運行結果。

# map源碼
class Executor(object):
    def map(self, fn, *iterables, timeout=None, chunksize=1):
		...
        def result_iterator():
            try:
                # reverse to keep finishing order
                fs.reverse()
                while fs:
                    # Careful not to keep a reference to the popped future
                    if timeout is None:
                        yield fs.pop().result()  # 調用future的result()方法拿到結果
                    else:
                        yield fs.pop().result(end_time - time.time())
                        # 調用future的result()方法拿到結果
            ...
        return result_iterator()  # 返回這個生成器

使用示例:

def print_num(n1, n2):
    return n1+n2

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=3)
    results = executor.map(print_num, [1, 2, 3], [11, 12, 13])
    for result in results:
        print(result)
# 輸出:
12
14
16

構造函數接收參數

迴歸到ThreadPoolExecutor類,查看它的__init__()方法:

# __init__源碼
class ThreadPoolExecutor(_base.Executor):
    def __init__(self, max_workers=None, thread_name_prefix='',
                 initializer=None, initargs=()):
	    ....

其接收四個參數,意義分別爲:

  • max_workers 表示允許的最大線程數量
  • thread_name_prefix 表示線程名字的前綴(thread_0、thread_1、thread_2中的thread)
  • initializer 表示子線程執行前需要執行的函數
  • initargs 表示initializer函數接收的參數
def print_num(num):
    print(num)

def print_two():
    print("two")

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=1, 
							      initializer=print_num, 
							      initargs=("one",))
    task1th = executor.submit(print_two)

# 輸出:
one  # 在打印two之前,先打印了one
two

有以下幾點需要注意:

  • initializer代表的函數在調用submit()後被執行,而不是executor初始化時;
  • initializer代表的函數出現異常,後面的線程將不再執行。

整個邏輯被寫在了_work()函數中:

# concurrent/futures/thread.py下_work()源碼
def _worker(executor_reference, work_queue, initializer, initargs):
    if initializer is not None:
        try:
            initializer(*initargs)  # 執行初始化函數
        except BaseException:
            ...
            return
    try:
        while True:
            work_item = work_queue.get(block=True)  # 從隊列中取出線程
            if work_item is not None:
                work_item.run()  # 執行線程
                # Delete references to object. See issue16284
                del work_item
                continue
            ...

Future

submit()會返回Future對象。前面已經貼過submit的源碼,這裏不做重複。Future對象包含以下幾個重要方法:

  • add_done_callback() 接收一個函數名,當線程執行完後調用傳入函數;
  • result() 用於獲取線程執行的結果;
  • exception() 用於獲取線程執行過程中存在的異常。

add_done_callback()result()的使用示例:

def print_and_return_num():
    print(512)
    return 512

def print_three(*args):
    print("結束")

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=1)
    task = executor.submit(print_and_return_num)
    
    print("程序執行結果:", task.result())  # 打印線程的執行結果
    task.add_done_callback(print_three)   # 線程結束後執行print_three

# 輸出:
512
程序執行結果: 512
結束

我們再來看看add_done_callback()的源碼:

# add_done_callback源碼
class Future(object):
	...
    def add_done_callback(self, fn):
        ...
        with self._condition:
            if self._state not in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED]:
                self._done_callbacks.append(fn)
                return
        fn(self)  # fn被調用

也就是說,add_done_callback接收的函數fn,必須接收一個參數,Python會在調用它時,默認把對象(self)傳進去。


同時,Future還提供了獲取線程狀態的三個方法:

  • cancelled() 當線程狀態是CANCELLED或者CANCELLED_AND_NOTIFIED,返回True;
  • running() 當線程狀態是RUNNING,返回True;
  • done() 當線程狀態是CANCELLEDCANCELLED_AND_NOTIFIED或者FINISHED,返回True。

還有一個我暫時不知道有什麼用的方法:cancel(),官方文檔對其釋義:

Attempt to cancel the call. If the call is currently being executed and cannot be cancelled then the method will return False, otherwise the call will be cancelled and the method will return True.
當程序正在執行時(RUNNING)或者處於不可以被取消的狀態時(FINISHED),返回False。否則取消調用,並且返回True。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章