Python子線程異常捕獲 & Python自定義優化線程池

前言

雖然CPython因爲GIL的存在導致多線程的併發計算能力大打折扣,但是在i/o密集型的場景時,使用多線程還是能帶來效率上的飛躍。近期在使用多線程時遇到了主線程無法捕獲子線程拋出的異常問題,這裏記錄一下解決的辦法。

需求
將某一指定目錄下所有的文件(包含子目錄下的文件)中所有被$[]$字符包裹起來的變量替換成指定的值,這是一個典型的io密集的場景,因此考慮使用多線程提升效率

原demo

def main():
      conf_map = {'DB_HOST': "X.X.X.X",'DB_USER': "root",'DB_PASSWD': 'abcd1234'}
      conf_files=['/etc/abc/', '/var/abc']
      thpool = ThreadPoolExecutor(5)
      for file in conf_files:
          thpool.submit(replace_config, file, conf_map)
      thpool.shutdown(wait=True)

def replace_config(file, conf_map, tmp_conf_path):
    with open(file, 'r') as f:
        content = f.read()

    # 需替換的變量的樣式爲: $[PASSWORD]$
    wrapper_pattern = re.compile('\$\[([\w-]+)\]\$')
    var_list = wrapper_pattern.findall(content)

    for var in var_list:
        try:
            value = conf_map[var]
            wrapper = "$[%s]$" % var
            content = content.replace(wrapper, value)
        except KeyError:
            print('key error')
            os.system("mv {} /tmp".format(tmp_conf_path))
            raise Exception('[%s]文件中存在未知的key: %s' % (file, var))

    with open(file, 'w') as f:
        f.write(content)    

demo內直接使用python3.2版本以後引入的ThreadPoolExecutor庫使用多線程,在子線程無異常時是正常運行的,但是在子線程出現異常時(比如子線程的內部邏輯裏發現了不存在key時, except KeyError會捕獲到異常),但是你會發現子線程異常終止了,主線程但是卻沒有異常拋出。

經過一番搜索,在python官方手冊中瞭解到了原因:
在這裏插入圖片描述

參考官方文檔鏈接:
https://docs.python.org/zh-cn/3/library/_thread.html

原因即爲:

使用start()方法啓動子線程時,解釋器會爲子線程開闢獨立的棧空間,主線程自然就無法獲取子線程棧的信息。當線程異常中止時,會自行退出而不會將此異常raise到主線程。那麼得知了原因,就可以找到解決的辦法了。思路是繼承標準庫的Thread類,進行一些小的改寫封裝。

修改後的demo

class ReplaceThread(Thread):
    def __init__(self, file, conf_map):
        super(ReplaceThread, self).__init__()
        self.file = file
        self.conf_map = conf_map

        self.exitcode = 0
        self.exception = None

    def run(self):
        try:
            self._run()
        except Exception as e:
            # 如果捕獲到異常,返回值改爲不等於0
            self.exitcode = 1
            self.exception = e

    def _run(self):
        with open(self.file, 'r') as f:
            content = f.read()

        # 需替換的變量的樣式爲: $[PASSWORD]$
        wrapper_pattern = re.compile('\$\[([\w-]+)\]\$')
        var_list = wrapper_pattern.findall(content)

        for var in var_list:
            try:
                value = self.conf_map[var]
                wrapper = "$[%s]$" % var
                content = content.replace(wrapper, value)
            except KeyError:
                raise Exception('[%s]文件中存在未知的key: %s‘ % (self.file, var))

        with open(self.file, 'w') as f:
            f.write(content)
            
# 主線程無法直接捕獲子線程內的異常,因此自定義了Thread類,在子線程內定義其出現異常時的返回值,在主線程內根據返回值
# 來判斷是否出現異常,並進行下一步操作
 def main():
    conf_files=['/etc/abc/', '/var/abc']
    conf_map = {'DB_HOST': "X.X.X.X",'DB_USER': "root",'DB_PASSWD': 'abcd1234'}
    t_objs = []
    for file in conf_files:
        t = ReplaceThread(file=file, conf_map=conf_map)
        t.start()
        t_objs.append(t)
	
    for t in t_objs:
        t.join()
        if t.exitcode != 0:
            os.system("mv {} /tmp".format(tmp_conf_path))
            raise Exception(t.exception)  

如此這般,在主線程裏通過自定義的子線程返回值來判斷子線程是否有異常,如果子線程有異常則主線程接替拋出子線程裏的異常。這裏另外還要注意,子線程的join()操作要放到start()操作全部完成了之後再進行,避免主線程被子線程阻塞,這樣就變成了串行執行多線程就失去了意義了。

問題
這裏是以迭代對象來循環啓動多線程的,假設迭代對象數量很長,那就會啓動成百上千個線程,這是不願意看到的,爲了避免這種情況,可自定義線程池,僅需對上方demo中的main()方法再做一點小改動即可。

實現

 def main():
    conf_files=['/etc/abc/', '/var/abc']
    conf_map = {'DB_HOST': "X.X.X.X",'DB_USER': "root",'DB_PASSWD': 'abcd1234'}
    
    # 實現的方法是對迭代的對象進行截取,每次只截取前10個對象,執行完這10個對象的操作後再截取隨後後的10個對象,直到迭代對象爲空。
    n = 0
    pool_size = 10
    files = conf_files
    file_num = len(files)
    while files:
        t_objs = []
        start_index = n * pool_size
        end_index = (n + 1) * pool_size

        for i in range(pool_size):
            current_index = start_index + i
            if current_index < file_num:
                file = self.conf_files[current_index]
                t = ReplaceThread(file=file, conf_map=conf_map)
                t.start()
                t_objs.append(t)

        for t in t_objs:
            t.join()
            if t.exitcode != 0:
                os.system("rm -rf {}".format(base_conf_path))
                raise Exception(t.exception)

        n += 1
        files = conf_files[end_index:]

複雜度優化:
上方的方法是對迭代的對象列表進行切片,每次只截取前10個對象,執行完這10個對象的操作後再截取隨後後的10個對象,直到迭代對象爲空。這種方式使用list的切片,時間複雜度爲O(k),k爲截取長度。

有沒有更好的方式?這裏列舉一種複雜度更低的方式:

 def main():
    conf_files=['/etc/abc/', '/var/abc']
    conf_map = {'DB_HOST': "X.X.X.X",'DB_USER': "root",'DB_PASSWD': 'abcd1234'}
    
    pool_size = 10
    files = self.conf_files
    file_num = len(files)
    while files:
        if file_num < pool_size:
            pool_size = file_num

        t_objs = []
        for i in range(pool_size):
            file = files.pop()
            t = ReplaceThread(file=file, conf_map=conf_map)
            t.start()
            t_objs.append(t)

        for t in t_objs:
            t.join()
            if t.exitcode != 0:
                os.system("rm -rf {}".format(tmp_conf_path))
                raise Exception(t.exception)

        file_num -= pool_size

說明:

利用了順序表(list)的 尾部操作/獲取長度操作 時間複雜度均爲O(1)的特性,每次操作list的尾部元素,這個方式的複雜度更低

總結

子線程異常處理問題由此就得以解決,通過一些小改動也可以實現自定義的低複雜度線程池。重要的事情只說一遍:遇到問題查官方文檔

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章