[源碼解析] 深度學習分佈式訓練框架 horovod (14) --- 彈性訓練發現節點 & State

[源碼解析] 深度學習分佈式訓練框架 horovod (14) --- 彈性訓練發現節點 & State

0x00 摘要

Horovod 是Uber於2017年發佈的一個易於使用的高性能的分佈式訓練框架,在業界得到了廣泛應用。

本系列將通過源碼分析來帶領大家瞭解 Horovod。本文是系列第十四篇,看看horovod 如何動態發現節點 和 狀態信息。

本系列其他文章鏈接如下:

[源碼解析] 深度學習分佈式訓練框架 Horovod (1) --- 基礎知識

[源碼解析] 深度學習分佈式訓練框架 horovod (2) --- 從使用者角度切入

[源碼解析] 深度學習分佈式訓練框架 horovod (3) --- Horovodrun背後做了什麼

[源碼解析] 深度學習分佈式訓練框架 horovod (4) --- 網絡基礎 & Driver

[源碼解析] 深度學習分佈式訓練框架 horovod (5) --- 融合框架

[源碼解析] 深度學習分佈式訓練框架 horovod (6) --- 後臺線程架構

[源碼解析] 深度學習分佈式訓練框架 horovod (7) --- DistributedOptimizer

[源碼解析] 深度學習分佈式訓練框架 horovod (8) --- on spark

[源碼解析] 深度學習分佈式訓練框架 horovod (9) --- 啓動 on spark

[源碼解析] 深度學習分佈式訓練框架 horovod (10) --- run on spark

[源碼解析] 深度學習分佈式訓練框架 horovod (11) --- on spark --- GLOO 方案

[源碼解析] 深度學習分佈式訓練框架 horovod (12) --- 彈性訓練總體架構

[源碼解析] 深度學習分佈式訓練框架 horovod (13) --- 彈性訓練之 Driver

0x01 設計點

本文對應架構圖中的 Host Discovery 部分,因爲是被 Driver Main 調用,所以把兩部分一起展示出。

img

發現節點機制的幾個關鍵設計點如下:

  • 有節點變化時候,如何即時發現?Horovod是通過定期調用完成。
  • 發現節點變化時候,如何通知各個worker? Horovod通過構建了一個通知機制完成。即,每個worker把自己註冊到WorkerNotificationManager 之上,當有節點變化時候,WorkerNotificationManager 會逐一通知這些worker。
  • worker得到通知之後,如何處理?Horovod 把worker的狀態在深度框架上進一步封裝成各種State,得到通知之後就會調用State的對應callback函數,或者同步狀態,或者進行其他處理。

0x02 發現機制

這部分代碼主要在:horovod/runner/elastic/discovery.py。

2.1 發現腳本

HostDiscoveryScript 的主要作用就是保存腳本(程序啓動時候設置進來),然後當執行 find_available_hosts_and_slots 的時候,調用這個發現腳本,得到 host 信息。

該腳本的輸出的格式 就是調用 horovodrun 時候 的 host 參數格式,比如:

$ sh ./discover_hosts.sh    # 運行腳本,輸出節點信息
10.68.32.2:4
10.68.32.3:4
10.68.32.4:4

定義如下:

class HostDiscoveryScript(HostDiscovery):
  
    def __init__(self, discovery_script, slots):
        self._discovery_script = discovery_script # 設定腳本
        self._default_slots = slots # 審定slots
        super(HostDiscoveryScript, self).__init__()

    def find_available_hosts_and_slots(self):
        stdout = io.StringIO()
        # 執行發現腳本
        exit_code = safe_shell_exec.execute(self._discovery_script, stdout=stdout)

        # 讀取腳本輸出,解析出來host信息
        host_slots = {}
        lines = set(stdout.getvalue().strip().split('\n'))
        for line in lines:
            host = line
            if ':' in line:
                host, slots = line.split(':')
                host_slots[host] = int(slots)
            else:
                host_slots[host] = self._default_slots
        return host_slots

2.2 HostManager

HostManager 是 host discovery 的核心,作用是維護當前 host 以及 狀態,其主要變量是:

  • self._current_hosts : 當前的 host 信息,包括 slot,assign order 等等;
  • self._hosts_state :當前的 host 狀態,包括黑名單,event 等;
  • self._discovery :可以認爲是對 發現腳本 的一個封裝,用來動態執行 發現腳本,獲取 host 信息;
class HostManager(object):
    def __init__(self, discovery):
        self._current_hosts = DiscoveredHosts(host_slots={}, host_assignment_order=[])
        self._hosts_state = defaultdict(HostState)
        self._discovery = discovery

    def update_available_hosts(self):
        # TODO(travis): also check for hosts removed from the blacklist in the future
        # 檢查更新,給出是添加,還是刪除節點
        def check_update(cur_host_slots, prev_host_slots):
            res = HostUpdateResult.no_update

            for prev_h in prev_host_slots:
                if prev_h not in cur_host_slots:
                    # prev_h is a removed host
                    res |= HostUpdateResult.removed

            for h in cur_host_slots:
                if h not in prev_host_slots:
                    # h is an added host
                    res |= HostUpdateResult.added
                elif cur_host_slots[h] > prev_host_slots[h]:
                    # h has more slots added
                    res |= HostUpdateResult.added
                elif cur_host_slots[h] < prev_host_slots[h]:
                    # h has removed some slots
                    res |=  HostUpdateResult.removed
            return res

        prev_host_slots = self._current_hosts.host_slots
        prev_host_assignment_order = self._current_hosts.host_assignment_order
        host_slots = self._discovery.find_available_hosts_and_slots()
        
        if prev_host_slots != host_slots: # 有修改
            # 找到不在黑名單裏的host
            available_hosts = set([host for host in host_slots.keys() if not self._hosts_state[host].is_blacklisted()])
            # 找到host的order
            host_assignment_order = HostManager.order_available_hosts(available_hosts, prev_host_assignment_order)
            self._current_hosts = DiscoveredHosts(host_slots=host_slots,
                                                  host_assignment_order=host_assignment_order)
            # 檢查更新
            return check_update(self._current_hosts.host_slots, prev_host_slots)
        else: # 沒修改就不更新
            return HostUpdateResult.no_update

HostManager 核心邏輯是 update_available_hosts 方法,就是用來發現可用的 host。

2.2.1 order_available_hosts

order_available_hosts 的作用是:確保最老的host被賦予最低的rank,即rank 0,因爲最老的host最有可能擁有原來訓練的模型以及訓練狀態,這些信息需要在下一輪新迭代之前,發給所有worker。

    @staticmethod
    def order_available_hosts(available_hosts, prev_host_assignment_order):
        # We need to ensure this list preserves relative order to ensure the oldest hosts are assigned lower ranks.
        host_assignment_order = [host for host in prev_host_assignment_order if host in available_hosts]
        known_hosts = set(host_assignment_order)
        for host in available_hosts:
            if host not in known_hosts:
                host_assignment_order.append(host)
        return host_assignment_order

2.3 配置

我們看看是發現腳本如何配置進入HostManager之中。

首先,發現腳本是在_run_elastic之中配置。

def _run_elastic(args):
    # construct host discovery component
    if args.host_discovery_script:
        # 如果參數中有設置發現腳本,則賦值爲discover_hosts
        discover_hosts = discovery.HostDiscoveryScript(args.host_discovery_script, args.slots)
    elif args.hosts: # 如果參數設置好了hosts,則賦值爲discover_hosts
        _, available_host_slots = hosts.parse_hosts_and_slots(args.hosts)
        if len(available_host_slots) < 2:
            raise ValueError('Cannot run in fault tolerance mode with fewer than 2 hosts.')
        discover_hosts = discovery.FixedHosts(available_host_slots)
    else: # 拋出異常
        raise ValueError('One of --host-discovery-script, --hosts, or --hostnames must be provided')

    # 配置進入setting
    settings = elastic_settings.ElasticSettings(discovery=discover_hosts,
                                                .....)

    env = os.environ.copy()
    config_parser.set_env_from_args(env, args)
    gloo_run_elastic(settings, env, args.command)

其次,發現腳本被設置到ElasticSettings之中。

class ElasticSettings(BaseSettings):
    def __init__(self, discovery, min_np, max_np, elastic_timeout, reset_limit, **kwargs):
        self.discovery = discovery

當啓動時候,會設置到 ElasticDriver 之中。

def start(self):
    """Starts the Horovod driver and services."""
    self.rendezvous = RendezvousServer(self.settings.verbose)
    self.driver = ElasticDriver(
        rendezvous=self.rendezvous,
        discovery=self.settings.discovery, # 在這裏設置發現腳本
        min_np=self.settings.min_np,
        max_np=self.settings.max_np,
        timeout=self.settings.elastic_timeout,
        reset_limit=self.settings.reset_limit,
        verbose=self.settings.verbose)

最後,建立HostManager時候,會設置發現腳本。

class ElasticDriver(object):
    def __init__(self, rendezvous, discovery, min_np, max_np, timeout=None, reset_limit=None, verbose=0):
        self._rendezvous = rendezvous
        self._host_manager = HostManager(discovery) # 設置腳本

0x03 如何調用

3.1 無限循環線程

HostManager 的調用邏輯是在 ElasticDriver 類中。

ElasticDriver 在初始化時候,生成一個後臺線程 _discovery_thread。

self._discovery_thread = threading.Thread(target=self._discover_hosts)

3.1.1 定時探尋

_discovery_thread 之中,會運行_discover_hosts。

ElasticDriver._discover_hosts 會:

  • 首先調用 self._host_manager.update_available_hosts(self._host_manager.current_hosts, update_res)得到最新的host狀態;
  • 其次,如果新 host 狀態已經發生的變化,於是就調用 _notify_workers_host_changes 和 _wait_hosts_cond.notify_all 來通知大家有 host 變化了;
def _discover_hosts(self):
    first_update = True
    while not self._shutdown.is_set():
        self._wait_hosts_cond.acquire()
        try:
            # 得到最新的host狀態
            update_res = self._host_manager.update_available_hosts()
            if update_res != HostUpdateResult.no_update:
                self._notify_workers_host_changes(self._host_manager.current_hosts, update_res)
                self._wait_hosts_cond.notify_all() # 通知大家有 host 變化
        except RuntimeError as e:
            if first_update:
                # Misconfiguration, fail the job immediately
                self._shutdown.set()
                self._wait_hosts_cond.notify_all() # 通知大家有 host 變化
                raise
            # Transient error, retry until timeout
            logging.warning(str(e))
        finally:
            self._wait_hosts_cond.release()
        first_update = False
        self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS)

邏輯如下,是一個 thread loop 定時運行:

 <--------------------^
+                     |
|       thread loop   |
|                     |
|    +----------------+----------------------+
|    |  ElasticDriver._discovery_thread      |
|    |                                       |
|    |                                       |
|    |                                       |
|    |                                       |
|    |   HostManager.update_available_hosts  |
|    |                                       |
|    +----------------+----------------------+
|                     ^
|                     |
v                     |
+-------------------->+

3.1.2 通知變化

如果發現有host 變化,則調用 self._notify_workers_host_changes 來通知。

即,當Driver的定時進程通過節點發現腳本發現某一個節點被標記爲新增或者移除時,它將 調用 _notify_workers_host_changes 發送一個通知到所有workers

邏輯如下:

 <--------------------^
+                     |
|       thread loop   |
|                     |
|    +----------------+-----------------------------------------------+
|    |  ElasticDriver._discovery_thread                               |
|    |                                                                |
|    |                                                                |
|    |   HostManager.update_available_hosts                           |
|    |                +                                               |
|    |                |                                               |
|    |                |                                               |
|    |                v                                               |
|    |                                      YES                       |
|    |       update_res != no_update ???  +--------+                  |
|    |                +                            |                  |
|    |                |                            |                  |
|    |                |                            v                  |
|    |                | NO                                            |
|    |                |             _notify_workers_host_changes      |
|    |                v                                               |
|    +----------------------------------------------------------------+
|                     |
|                     |
|                     |
v                     |
+-------------------->+

具體如下:

def _notify_workers_host_changes(self, current_hosts, update_res):
    next_host_assignments = {}
    if current_hosts.count_available_slots() >= self._min_np:
        # Assignments are required to be stable via contract
        next_host_assignments, _ = self._get_host_assignments(current_hosts)

    if next_host_assignments == self.host_assignments:
        # Skip notifying workers when host changes would not result in changes of host assignments
        return

    coordinator_slot_info = self.get_coordinator_info()
    # 獲取 WorkerNotificationClient
    coordinator_client = self.get_worker_client(coordinator_slot_info)

    timestamp = _epoch_time_s()
    coordinator_client.notify_hosts_updated(timestamp, update_res) # 通知

get_worker_client 函數就是獲取 WorkerNotificationClient,然後調用 WorkerNotificationClient 來進行通知,所以下面我們接下來看 WorkerNotificationClient。

def get_worker_client(self, slot_info):
    return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))

具體如下:

 <--------------------^
+                     |
|       thread loop   |
|                     |
|    +----------------+------------------------------------+
|    | ElasticDriver._discovery_thread                     |
|    |                +                                    |
|    |                |                                    |
|    |                v                                    |
|    |   HostManager.update_available_hosts                |
|    |                +                                    |
|    |                |                                    |
|    |                |                                    |
|    |                v                     YES            |                       +---------------------------+
|    |       update_res != no_update ???  +-----+          |                       |                           |
|    |                +                         |          |                       |                           |
|    |                |                         |          |                       | WorkerNotificationClient  |
|    |                |                         v          | notify_hosts_updated  |                           |
|    |                | NO                                 |                       |                           |
|    |                |     _notify_workers_host_changes+------------------------> |                           |
|    |                v                                    |                       |                           |
|    +-----------------------------------------------------+                       +---------------------------+
|                     |
|                     |
|                     |
v                     |
+-------------------->+

手機如下:

3.2 如何通知

就是利用 WorkerNotificationClient 發送 HostsUpdatedRequest

3.2.1 WorkerNotificationClient

可以看到,WorkerNotificationService 繼承了 network.BasicService,所以 WorkerNotificationClient 就是作爲 WorkerNotificationService 的操作接口,從而給 WorkerNotificationService 發送 HostsUpdatedRequest。

class WorkerNotificationClient(network.BasicClient):
    def __init__(self, addresses, key, verbose, match_intf=False):
        super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
                                                       addresses,
                                                       key,
                                                       verbose,
                                                       match_intf=match_intf)

    def notify_hosts_updated(self, timestamp, update_res):
        self._send(HostsUpdatedRequest(timestamp, update_res))

3.2.2 WorkerNotificationService

WorkerNotificationService 會響應 HostsUpdatedRequest。

class WorkerNotificationService(network.BasicService):
    NAME = 'worker notification service'

    def __init__(self, key, nic, manager):
        super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
                                                        key,
                                                        nic)
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return network.AckResponse()

        return super(WorkerNotificationService, self)._handle(req, client_address)

3.2.3 WorkerNotificationManager

handle_hosts_updated 會逐一通知註冊在WorkerNotificationManager 上的 listener(就是用戶代碼中的 State)

WorkerNotificationManager 是在 horovod/common/elastic.py 構建,每一個host上運行一個。

notification_manager = WorkerNotificationManager()

具體定義如下:

class WorkerNotificationManager(object):
    def __init__(self):
        self._lock = threading.Lock()
        self._service = None
        self._listeners = set()

    def init(self, rendezvous_addr=None, rendezvous_port=None,
             nic=None, hostname=None, local_rank=None):
        with self._lock:
            if self._service:
                return

            rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
            if not rendezvous_addr:
                return

            rendezvous_port = rendezvous_port if rendezvous_port is not None else \
                int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
            nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
            hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
            local_rank = local_rank if local_rank is not None else \
                int(os.environ.get(HOROVOD_LOCAL_RANK))

            secret_key = secret.make_secret_key()
            self._service = WorkerNotificationService(secret_key, nic, self)

            value = (self._service.addresses(), secret_key)
            put_data_into_kvstore(rendezvous_addr,
                                  rendezvous_port,
                                  PUT_WORKER_ADDRESSES,
                                  self._create_id(hostname, local_rank),
                                  value)

    def register_listener(self, listener):
        self._listeners.add(listener)

    def remove_listener(self, listener):
        self._listeners.remove(listener)

    def handle_hosts_updated(self, timestamp, update_res):
        for listener in self._listeners:
            listener.on_hosts_updated(timestamp, update_res)

3.2.4 通知 State

我們再梳理以下流程:

  • 當Driver的定時進程通過節點發現腳本發現某一個節點被標記爲新增或者移除時,它將發送一個通知到所有workers。
  • 每一個 worker 有自己對應的 State,都被存儲於 WorkerNotificationManager . _listeners
  • _host_messages 會在state 之中註冊host的變化,就是往其 _host_messages 之中放入"host 有變化" 的消息。
  • 因爲這個消息不是一定要立即處理的,所以這裏只是先放入 State 的隊列之中

邏輯如下:

 <--------------------^
+                     |
|       thread loop   |
|                     |
|    +----------------+------------------------------------+
|    | ElasticDriver._discovery_thread                     |
|    |                +                                    |
|    |                |                                    |
|    |                v                                    |
|    |   HostManager.update_available_hosts                |
|    |                +                                    |
|    |                |                                    |
|    |                |                                    |
|    |                v                     YES            |
|    |       update_res != no_update ???  +-----+          |                       +--------------------------+                       +----------------------------+
|    |                +                         |          |                       |                          |                       |                            |
|    |                |                         |          |                       | WorkerNotificationClient |                       | WorkerNotificationService  |
|    |                |                         v          | notify_hosts_updated  |                          |  HostsUpdatedRequest  |                            |
|    |                | NO                                 |                       |                          |                       |                            |
|    |                |     _notify_workers_host_changes+------------------------> |                          | +-------------------> |                            |
|    |                v                                    |                       |                          |                       |                            |
|    +-----------------------------------------------------+                       +--------------------------+                       +----------------+-----------+
|                     |                                                                                                                                |
|                     |                                                                                                                                |
|                     |                                                                                                           handle_hosts_updated |
v                     |                                                                                                                                |
+-------------------->+                                                                                                                                v
                                                                                                                                    +------------------+-----------+
                                                                                                                                    |                              |
                                                                                                                                    | WorkerNotificationManager    |
                                                                    +-----------+  +----------+        +----------+                 |                              |
                                                                    |           |  |          |        |          |                 |                              |
                                                                    |  State 1  |  | State 2  | ...... | State n  |   <---------------------+  _listeners          |
                                                                    |           |  |          |        |          |                 |                              |
                                                                    +-----------+  +----------+        +----------+                 |                              |
                                                                                                                                    |                              |
                                                                          ^              ^                   ^                      |                              |
                                                                          |              |                   |                      |                              |
                                                         on_hosts_updated |              | on_hosts_updated  |  on_hosts_updated    |                              |
                                                                          |              |                   |                      |                              |
                                                                          +--------------+-------------------+-------------------------+  handle_hosts_updated     |
                                                                                                                                    |                              |
                                                                                                                                    +------------------------------+

手機如下:

3.2.5 何時處理

何時處理這個通知?在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被調用時,state.check_host_updates 會從 _host_messages 中讀取消息,積累更新。

如 check_host_updates 方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時拋出 HostsUpdateInterrupt 異常,具體同步使用 _bcast_object(然後內部調用到了 MPI)。

我們接下來就會在 State 的介紹之中,講解check_host_updates 。

0x04 狀態抽象

Horovod 實現了一個 State 對象,這是把機器訓練的模型又做了一步抽象。

每一個Worker擁有一個 State 對象。

  • Horovod 把所有需要在workers之間同步的變量都放進 hvd.elastic.State (比如model parameters,optimizer state,當前epoch和batch進度等等)對象之中。

  • State 對象的作用是定期存儲訓練狀態,在需要時候從 State 對象中恢復機器學習的狀態。這樣在某些worker發生意外錯誤時,可以避免因爲狀態被損壞而無法恢復現場。

  • 比如,假設一個worker剛好在參數更新過程中突然掛掉,而此時部分梯度更新可能只更新到一半,這個狀態是不可逆而又無法繼續,導致參數是被損壞狀態無法用於恢復訓練。

4.1 State

State 的作用是:在不同的 worker 之中跟蹤內存狀態

主要變量&方法是:

  • on_reset : 當需要重啓狀態時候調用;
  • on_hosts_updated :當有 host 變化時候調用,即 向 _host_messages 這個 queue 放入一個消息;
  • commit :用戶會定期調用此函數,會存儲狀態(state)到內存,檢查 host 更改
    • 當有異常發生時,會拋出一個 HorovodInternalError 異常,當 hvd.elastic.run 捕獲到這個異常後,會利用最新一次commit中恢復所有狀態。
    • 因爲commit狀態代價高昂(比如如參數量太大會導致耗時過長),所以需要在"每個batch的處理時間"與"如果出錯,訓練需要從多久前的狀態恢復"之間選取一個平衡點。比如,如果你每訓練10個batches就commit一次,你就把複製時間降低了10倍。但是當發生錯誤時,你需要回滾到10個batches前的狀態。
  • check_host_updates : 會從 _host_messages 中讀取消息,積累更新,如方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時拋出異常。具體同步使用 _bcast_object(然後內部調用到了 MPI);
    • 如果節點發現腳本可以預見到某個節點是需要被移除或新增,Elastic Horvod可以避免回滾操作。當Driver的定時進程通過節點發現腳本發現某一個節點被標記爲新增或者移除時,它將發送一個通知到所有workers,於是在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被調用時,會拋出一個 HostsUpdateInterrupt 異常。這個異常類似於 HorovodInternalError 異常,但是參數狀態等不會從最近一次commit中恢復,而是從當前實時的參數中恢復。
    • 一般來說,如果你的硬件設施是可靠與穩定的,並且你的編排系統會在任務節點移除時提供足夠的告警,你就可低頻次調用 state.commit() 函數,同時只在每個batch結束時調用相對不耗時的 state.check_host_updates() 來檢查節點變更情況。
  • _reset_callbacks :用戶可以註冊一些回調函數到 hvd.elastic.State 對象中,用於響應worker成員發生變化的情況。
    • 比如回調函數可以處理如下情況:
      1. 當worker數量發生改變時,學習率需要根據新的world size進行相應改變。
      2. 對數據集進行重新分區。
    • 這些回調函數會在"Horovod被重啓之後"和"狀態在節點間同步之前"這兩個階段中間被調用。

具體定義如下:

class State(object):
    """State representation used for tracking in memory state across workers.

    Args:
        bcast_object: Function used to broadcast a variable from rank 0 to the other workers.
        get_rank: Function that returns the current rank of this worker.
    """
    def __init__(self, bcast_object, get_rank):
        self._bcast_object = bcast_object
        self._rank = get_rank
        self._host_messages = queue.Queue()
        self._last_updated_timestamp = 0
        self._reset_callbacks = []

    def on_reset(self):
        self._host_messages = queue.Queue()
        self.reset()
        for callback in self._reset_callbacks:
            callback()

    def on_hosts_updated(self, timestamp, update_res):
        self._host_messages.put((timestamp, update_res))

    def commit(self):
        self.save()
        self.check_host_updates()

    def check_host_updates(self):
        """Checks that a notification has been sent indicating that hosts can be added or will be removed.

        Raises a `HostsUpdatedInterrupt` if such a notification has been received.
        """
        # Iterate through the update messages sent from the server. If the update timestamp
        # is greater than the last update timestamp, then trigger a HostsUpdatedException.
        # 遍歷更新消息,如果更新時間戳大於上次更新時間戳,就觸發一個HostUpdateResult
        last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
        all_update = HostUpdateResult.no_update
        while not self._host_messages.empty():
            timestamp, update = self._host_messages.get()
            if timestamp > last_updated_timestamp:
                last_updated_timestamp = timestamp
                all_update |= update

        # In order to ensure all workers raise the exception at the same time, we need to sync
        # the updated state across all the workers.
        # TODO(travis): this should be a max allreduce to account for changes in rank 0
        # 會從 `_host_messages` 中讀取消息,積累更新,如方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時拋出異常。具體同步使用 `_bcast_object`(然後內部調用到了 MPI)
        prev_timestamp, self._last_updated_timestamp, all_update = \
            self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

        # At this point, updated state is globally consistent across all ranks.
        if self._last_updated_timestamp > prev_timestamp:
            raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)

因此,我們加入 Commit 之後,邏輯如圖:

 <--------------------^
+                     |
|       thread loop   |
|                     |
|    +----------------+------------------------------------+
|    | ElasticDriver._discovery_thread                     |
|    |                +                                    |
|    |                |                                    |
|    |                v                                    |
|    |   HostManager.update_available_hosts                |
|    |                +                                    |
|    |                |                                    |
|    |                |                                    |
|    |                v                     YES            |
|    |       update_res != no_update ???  +-----+          |                       +--------------------------+                       +----------------------------+
|    |                +                         |          |                       |                          |                       |                            |
|    |                |                         |          |                       | WorkerNotificationClient |                       | WorkerNotificationService  |
|    |                |                         v          | notify_hosts_updated  |                          |  HostsUpdatedRequest  |                            |
|    |                | NO                                 |                       |                          |                       |                            |
|    |                |     _notify_workers_host_changes+------------------------> |                          | +-------------------> |                            |
|    |                v                                    |                       |                          |                       |                            |
|    +-----------------------------------------------------+                       +--------------------------+                       +----------------+-----------+
|                     |                                                                                                                                |
|                     |                                                                                                                                |
|                     |                                                             _bcast_object                                 handle_hosts_updated |
v                     |                                                                                                                                |
+-------------------->+                                                  +-------------+----------------------+                                        v
                                                                         |             |                      |                     +------------------+-----------+
                                                                         |             |                      |                     |                              |
                                                                         v             v                      v                     | WorkerNotificationManager    |
       +--------------------+                                       +----+------+  +---+------+        +------+---+                 |                              |
       |                    |                                       |           |  |          |        |          |                 |                              |
       |   Python xxx.py    +-------------------------------------> |  State 1  |  | State 2  | ...... | State n  |   <---------------------+  _listeners          |
       |                    |        commit / check_host_updates    |           |  |          |        |          |                 |                              |
       +--------------------+                                       +-----------+  +----------+        +----------+                 |                              |
                                                                                                                                    |                              |
                                                                          ^              ^                   ^                      |                              |
                                                                          |              |                   |                      |                              |
                                                         on_hosts_updated |              | on_hosts_updated  |  on_hosts_updated    |                              |
                                                                          |              |                   |                      |                              |
                                                                          +--------------+-------------------+-------------------------+  handle_hosts_updated     |
                                                                                                                                    |                              |
                                                                                                                                    +------------------------------+

具體如下:

我們接下來介紹各級派生類。

4.2 ObjectState

ObjectState 的目的是組裝成 simple Python objects。

class ObjectState(State):
    """State for simple Python objects.

    Every object is specified as a keyword argument, and will be assigned as an attribute.

    Args:
        bcast_object: Horovod broadcast object function used to sync state dictionary.
        get_rank: Horovod rank function used to identify is this process is the coordinator.
        kwargs: Properties to sync, will be exposed as attributes of the object.
    """
    def __init__(self, bcast_object, get_rank, **kwargs):
        self._bcast_object = bcast_object
        self._saved_state = kwargs
        self._set_attrs()
        super(ObjectState, self).__init__(bcast_object=bcast_object, get_rank=get_rank)

    def save(self):
        new_state = {}
        for attr in self._saved_state.keys():
            new_state[attr] = getattr(self, attr)
        self._saved_state = new_state

    def restore(self):
        self._set_attrs()

    def sync(self):
        if self._saved_state:
            self._saved_state = self._bcast_object(self._saved_state)
            self._set_attrs()

    def _set_attrs(self):
        for attr, value in self._saved_state.items():
            setattr(self, attr, value)

4.3 TensorFlowKerasState

Horovod 默認已提供標準的TensorFlow,Keras和PyTorch的狀態保持和恢復實現,如果需要在某些場景下自定義,可以重載 hvd.elastic.State 這個對象。

TensorFlowKerasState 是 TensorFlow Keras model and optimizer 的狀態抽象。

初始化函數中,會設置各種相關變量,比如廣播函數。

class TensorFlowKerasState(ObjectState):

    def __init__(self, model, optimizer=None, backend=None, **kwargs):
        self.model = model
        if not _model_built(model):
            raise ValueError('Model must be built first. Run `model.build(input_shape)`.')

        self.optimizer = optimizer or model.optimizer
        self.backend = backend
        self._save_model()

        if not backend or _executing_eagerly():
            self._bcast_model = lambda: _broadcast_model(self.model, self.optimizer, backend=self.backend)
            bcast_object = broadcast_object
        else:
            # For TensorFlow v1, we need to reuse the broadcast op to prevent incrementing the uids
            bcast_op = broadcast_variables(_global_variables(), root_rank=0)
            self._bcast_model = lambda: self.backend.get_session().run(bcast_op)
            bcast_object = broadcast_object_fn(session=self.backend.get_session())

        super(TensorFlowKerasState, self).__init__(bcast_object=bcast_object,
                                                   get_rank=rank,
                                                   **kwargs)

具體實現了幾個方法,基本就是 存儲,恢復 state,同步。

def save(self):
    self._save_model()
    super(TensorFlowKerasState, self).save()

def restore(self):
    self._load_model()
    super(TensorFlowKerasState, self).restore()

def sync(self):
    self._bcast_model()
    self._save_model()
    super(TensorFlowKerasState, self).sync()

def _save_model(self):
    if _executing_eagerly():
        self._saved_model_state = [tf.identity(var) for var in self.model.variables]
        self._saved_optimizer_state = [tf.identity(var) for var in self.optimizer.variables()]
    else:
        self._saved_model_state = self.model.get_weights()
        self._saved_optimizer_state = self.optimizer.get_weights()

def _load_model(self):
    if _executing_eagerly():
        for var, saved_var in zip(self.model.variables, self._saved_model_state):
            var.assign(saved_var)
        for var, saved_var in zip(self.optimizer.variables(), self._saved_optimizer_state):
            var.assign(saved_var)
    else:
        self.model.set_weights(self._saved_model_state)
        self.optimizer.set_weights(self._saved_optimizer_state)

4.4 Restore

我們看到了,restore 會從內存中恢復模型。

def restore(self):
    self._load_model()
    super(TensorFlowKerasState, self).restore()

於是,我們有一個問題:何時調用restore?

發現是如果 horovod 捕獲了 HorovodInternalError 之後,會用 restore 來恢復。

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync()

                try:
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore() # 在這裏調用
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync

                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper

0x05 總結

我們再次重複一下,發現節點機制的幾個關鍵設計點:

  • 有節點變化時候,如何即時發現?Horovod是通過定期調用完成。
  • 發現節點變化時候,如何通知各個worker? Horovod通過構建了一個通知機制完成。即,每個worker把自己註冊到WorkerNotificationManager 之上,當有節點變化時候,WorkerNotificationManager 會逐一通知這些worker。
  • worker得到通知之後,如何處理?Horovod 把worker的狀態在深度框架上進一步封裝成各種State,得到通知之後就會調用State的對應callback函數,或者同步狀態,或者進行其他處理。

簡化版總體邏輯如下:

                                                         +-----------------------------v
                                                         ^        thread loop          |
                                                         |                             |
                                        +----------------+----------------------+      |
                                        |  ElasticDriver._discovery_thread      |      |
         _notify_workers_host_changes   |                                       |      |
                                        |                                       |      |
                     +------------------+                                       |      |
                     |                  |                                       |      |
                     |                  |   HostManager.update_available_hosts  |      |
                     |                  |                                       |      |
                     |                  +-----------------+---------------------+      |
                     |                                    ^                            |
                     |                                    |                            |
                     |                                    |                            |
                     |                                    +----------<---------------+ v
                     v

+---------------------------+   HostsUpdatedReques   +----------------------------+ handle_hosts_updated +----------------------------+
|                           |                        |                            |                      |                            |
| WorkerNotificationClient  +----------------------> |  WorkerNotificationService | +------------------> |  WorkerNotificationManager |
|                           |                        |                            |                      |                            |
+---------------------------+                        +----------------------------+                      +--------+-------------------+
                                                                                                                  |
                                                                                                                  |
                                                                                                                  |   on_hosts_updated
                                                                                                                  |
                                                                                                                  v
                                                                                                             +----+---+
                                                                                                             | State  |
                                                                                                             +--------+

手機如下:

至此,發現節點部分介紹完畢,因爲本文只是使用了 WorkerNotificationService 完成通知,但是沒有深入介紹,所以下一篇介紹內部廣播和通知機制。

0xEE 個人信息

★★★★★★關於生活和技術的思考★★★★★★

微信公衆賬號:羅西的思考

如果您想及時得到個人撰寫文章的消息推送,或者想看看個人推薦的技術資料,敬請關注。

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章