[Source Code Analysis] Deep Learning Distributed Training Framework horovod (15) --- Broadcast & Notification

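0x00 Summary

Horovod is an easy-to-use, high-performance distributed training framework released by Uber in 2017, and it is widely used in industry.

This series walks through Horovod by analyzing its source code. This is the fifteenth article in the series; it looks at how Horovod's elastic training broadcasts state and sends notifications.

The other articles in this series are: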

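[Source Code Analysis] Deep Learning Distributed Training Framework Horovod (1) --- Basics

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (2) --- From the User's Perspective

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (3) --- What Horovodrun Does Behind the Scenes

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (4) --- Networking Basics & Driver

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (5) --- The Fusion Framework

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (6) --- Background Thread Architecture

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (7) --- DistributedOptimizer

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (8) --- on spark

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (9) --- Launching on spark

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (10) --- run on spark

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (11) --- on spark --- The GLOO Approach

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (12) --- Overall Architecture of Elastic Training

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (13) --- The Driver in Elastic Training

[Source Code Analysis] Deep Learning Distributed Training Framework horovod (14) --- Node Discovery & State in Elastic Training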

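0x01 The Question

First, a question: why does elastic training need broadcasting?

The answer: after either of two exceptions is caught, state has to be broadcast to every worker.

1.1 HorovodInternalError

To see why HorovodInternalError needs it, look at the fault-tolerance mechanism:

  • The hvd.elastic.run decorator catches the exception;
  • If it is a HorovodInternalError, the state is restored to the last commit. Because this error comes from a failed collective such as allreduce, all workers have stopped at this point;
  • The driver runs a new rendezvous over the nodes that are still alive so that the Horovod context can be reinitialized;
  • Once the new communicator has been built, the worker with rank 0 broadcasts its model to the other workers;
  • All workers resume training from the iteration at which they stopped.

Because variables must be broadcast from rank 0 to the other processes, a broadcast mechanism is required.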

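1.2 HostsUpdatedInterrupt

Now for the reasons behind HostsUpdatedInterrupt.

  • When the driver process, through the host discovery script, finds that a node has been marked as added or removed, it sends a notification to all workers. The next time state.commit() or the lighter-weight state.check_host_updates() is called, a HostsUpdatedInterrupt is raised. This exception is handled much like HorovodInternalError, except that the parameter state is not restored from the last commit; training continues from the current, live parameters.
  • check_host_updates reads the messages in _host_messages and accumulates the updates. As the comment in the method says, it synchronizes this state across all workers so that they all raise the exception at the same time.
  • The synchronization uses _bcast_object (which internally ends up calling MPI).

A broadcast is needed to synchronize this state across the workers: they are all training normally, so something has to interrupt all of them at once before the communication ring can be rebuilt. The goal is for every worker to raise HostsUpdatedInterrupt at the same time.

Next, recall the flowchart from the previous article; this article refines some of its internal steps.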

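0x02 Broadcast Mechanism

Let's dissect the broadcast mechanism. Because broadcasting is tightly coupled to the specific framework, we use TensorFlow as the example; the code lives in horovod/tensorflow/elastic.py.

2.1 Broadcast Implementation

horovod/tensorflow/elastic.py contains the TensorFlow-specific implementation, and it handles different TensorFlow versions differently.

2.1.1 TensorFlowKerasState

Take TensorFlowKerasState as an example. Because it needs to broadcast things, its initializer sets up _bcast_model to broadcast the model and bcast_object to broadcast objects, and it uses broadcast_variables to broadcast variables.

It also provides a sync function that performs the broadcast; you can see that it calls _bcast_model.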

class TensorFlowKerasState(ObjectState):
    def __init__(self, model, optimizer=None, backend=None, **kwargs):
        # ... (self.model, self.optimizer and self.backend are set here; omitted)

        if not backend or _executing_eagerly():
            # TF2 / eager mode: set up the broadcast functions
            self._bcast_model = lambda: _broadcast_model(self.model, self.optimizer, backend=self.backend)
            bcast_object = broadcast_object
        else:
            # For TensorFlow v1, we need to reuse the broadcast op to prevent incrementing the uids
            # TF1 graph mode: build the broadcast op once and run it on every sync
            bcast_op = broadcast_variables(_global_variables(), root_rank=0)
            self._bcast_model = lambda: self.backend.get_session().run(bcast_op)
            bcast_object = broadcast_object_fn(session=self.backend.get_session())

        # ... bcast_object is then handed to the ObjectState base-class constructor (omitted)

    def sync(self):
        self._bcast_model()  # broadcast the model from rank 0
        self._save_model()
        super(TensorFlowKerasState, self).sync()

2.1.2 Broadcasting the Model

_broadcast_model broadcasts the model variables and the optimizer variables.

def _broadcast_model(model, optimizer, backend):
    if _executing_eagerly():
        # TensorFlow 2.0 or TensorFlow eager
        broadcast_variables(model.variables, root_rank=0)  # broadcast the model variables
        broadcast_variables(optimizer.variables(), root_rank=0)  # broadcast the optimizer variables
    else:
        bcast_op = broadcast_variables(_global_variables(), root_rank=0)
        backend.get_session().run(bcast_op)

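2.1.3 Broadcasting Variables

Variable broadcasting is implemented in horovod/tensorflow/functions.py. broadcast_variables broadcasts variables from the root rank (rank 0 in these call sites) to all other processes.

The implementation again differs by TensorFlow version.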

def _make_subgraph(f):
    return tf.function(f)

@_cache
def _make_broadcast_group_fn():
    if _executing_eagerly():
        # Eager mode will parallelize independent control flow
        def broadcast_group(variables, root_rank):  # defined here
            for var in variables:
                var.assign(broadcast(var, root_rank))  # collective broadcast; root_rank is the source

        return _make_subgraph(broadcast_group)
    else:
        # Graph mode requires an Op
        def broadcast_group(variables, root_rank):  # defined here
            # tf.group() creates one op that bundles all the ops passed to it; once it has run,
            # every one of those ops has completed. tf.group() produces no output.
            return tf.group(*[var.assign(broadcast(var, root_rank))  # collective broadcast
                              for var in variables])

        return broadcast_group

def broadcast_variables(variables, root_rank):
    """Broadcasts variables from root rank to all other processes.
    """
    broadcast_group = _make_broadcast_group_fn()
    return broadcast_group(variables, root_rank)  # defined above
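
The effect of broadcast_group can be sketched without TensorFlow or Horovod. The pure-Python sketch below rests on stated assumptions: each rank's variables live in a plain dict, and the collective is replaced by a hypothetical `simulated_broadcast` that simply reads the root rank's value. It shows only the end result (every variable on every rank is overwritten with the root's value), not how the real op moves data.

```python
from typing import Dict, List

# Hypothetical stand-in for per-rank model state: world[rank] maps variable name -> value.
World = List[Dict[str, float]]


def simulated_broadcast(world: World, name: str, root_rank: int) -> float:
    """Pretend collective: every caller receives root_rank's value for `name`."""
    return world[root_rank][name]


def broadcast_group(world: World, names: List[str], root_rank: int = 0) -> None:
    """Like broadcast_group: var.assign(broadcast(var, root_rank)) for every variable."""
    for rank_vars in world:
        for name in names:
            rank_vars[name] = simulated_broadcast(world, name, root_rank)


world = [
    {"w": 1.0, "b": 0.5},   # rank 0 holds the values to keep
    {"w": 9.0, "b": 9.0},   # rank 1 is stale
    {"w": -3.0, "b": 2.0},  # rank 2 is stale
]
broadcast_group(world, ["w", "b"], root_rank=0)
print(world)  # every rank now holds {'w': 1.0, 'b': 0.5}
```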

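2.1.4 Broadcasting Objects

Object broadcasting sends an object from the root rank (rank 0 here) to all other processes. The difference from variable broadcasting is that an object has to be serialized before it is sent and deserialized after it is received.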

def broadcast_object(obj, root_rank=0, session=None, name=None):
    """
    Serializes and broadcasts an object from root rank to all other processes.

    Arguments:
        obj: An object capable of being serialized without losing any context.
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
        session: Session for TensorFlow v1 compatibility.
        name: Optional name to use during broadcast, will default to the class
              type.
    Returns:
        The object that was broadcast from the `root_rank`.
    """
    if name is None:
        name = type(obj).__name__

    def to_numpy(v):  # TF1 graph mode needs a session run; eager tensors convert directly
        if not _executing_eagerly():
            sess = session or ops.get_default_session()
            return sess.run(v)
        else:
            return v.numpy()

    if rank() == root_rank:
        b = io.BytesIO()  # BytesIO reads and writes bytes in memory
        cloudpickle.dump(obj, b)  # serialize the object into bytes
        t = tf.convert_to_tensor(bytearray(b.getvalue()), dtype=tf.uint8)
        sz = tf.convert_to_tensor([t.shape[0]], dtype=tf.int32)  # payload length in bytes
        to_numpy(broadcast(sz, root_rank, name + '.sz'))  # broadcast the length first
    else:
        sz = tf.convert_to_tensor([0], dtype=tf.int32)
        sz = to_numpy(broadcast(sz, root_rank, name + '.sz'))  # receive the length
        t = tf.zeros(sz.tolist()[0], dtype=tf.uint8)  # allocate a buffer of that length

    t = to_numpy(broadcast(t, root_rank, name + '.t'))  # broadcast the serialized payload

    if rank() != root_rank:
        buf = io.BytesIO(t.tobytes())
        obj = cloudpickle.load(buf)  # deserialize back into the original object

    return obj
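
The two-round protocol in broadcast_object (send the length first, then the bytes) can be mimicked in a single process. This sketch makes several assumptions: it uses the standard-library `pickle` instead of `cloudpickle`, byte arrays instead of uint8 tensors, and a hypothetical `broadcast_object_sim` that plays every rank in turn. The length round exists because receivers must allocate a correctly sized buffer before the payload broadcast, whose shape has to match on all ranks.

```python
import io
import pickle
from typing import Any, List


def broadcast_object_sim(objs: List[Any], root_rank: int = 0) -> List[Any]:
    """Single-process simulation; objs[r] is what rank r passes in, result[r] what it gets back."""
    # Root: serialize the object to bytes (the real code uses cloudpickle and a uint8 tensor).
    buf = io.BytesIO()
    pickle.dump(objs[root_rank], buf)
    payload = bytearray(buf.getvalue())

    # Round 1: broadcast the payload length ("<name>.sz" in the real code).
    size = len(payload)

    results: List[Any] = []
    for rank, obj in enumerate(objs):
        if rank == root_rank:
            results.append(obj)  # the root returns its own object unchanged
            continue
        recv = bytearray(size)   # like tf.zeros(sz, dtype=tf.uint8): right-sized, all zeros
        recv[:] = payload        # Round 2: the payload broadcast ("<name>.t") fills the buffer
        results.append(pickle.load(io.BytesIO(bytes(recv))))  # deserialize on the receiver
    return results


out = broadcast_object_sim([{"epoch": 3, "batch": 120}, None, "stale"], root_rank=0)
print(out)  # [{'epoch': 3, 'batch': 120}, {'epoch': 3, 'batch': 120}, {'epoch': 3, 'batch': 120}]
```

Receivers get a deserialized copy, not the root's object itself, which matches what happens across real process boundaries.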

2.1.5 HVD C++

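Underneath, broadcast calls into Horovod's C++ library, which performs the actual collective. The walkthrough below follows the MPI implementation; with the Gloo controller the same request is carried out by Gloo's broadcast instead.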

def broadcast(tensor, root_rank, name=None, ignore_name_scope=False):
    """An op which broadcasts the input tensor on root rank to the same input tensor
    on all other Horovod processes.

    The broadcast operation is keyed by the name of the op. The tensor type and
    shape must be the same on all Horovod processes for a given name. The broadcast
    will not start until all processes are ready to send and receive the tensor.

    Returns:
      A tensor of the same shape and type as `tensor`, with the value broadcasted
      from root rank.
    """
    if name is None and not _executing_eagerly():
        name = 'HorovodBroadcast_%s' % _normalize_name(tensor.name)
    return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank,
                                     ignore_name_scope=ignore_name_scope)

2.1.6 MPI

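MPI_Bcast sends a message from the process whose rank is root to every process in the group, including root itself.

Because root_rank is specified, every worker runs exactly the same code, yet only the contents of root_rank's buffer are copied into the buffers of all the other processes.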

void MPIController::Bcast(void* buffer, size_t size, int root_rank,
                          Communicator communicator) {
  MPI_Comm comm = mpi_ctx_.GetMPICommunicator(communicator);
  int ret_code = MPI_Bcast(buffer, size, MPI_BYTE, root_rank, comm);
  if (ret_code != MPI_SUCCESS) {
    throw std::runtime_error(
        "MPI_Broadcast failed, see MPI output for details.");
  }
}
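
The "everyone calls the same function, only the root's buffer survives" semantics of MPI_Bcast can be demonstrated with threads standing in for processes. This is not MPI: `FakeComm` is a hypothetical in-process communicator built on `threading.Barrier`, used only to show that each rank passes its own buffer and that non-root buffers are overwritten in place.

```python
import threading
from typing import List, Optional


class FakeComm:
    """Hypothetical in-process communicator; threads play the role of ranks."""

    def __init__(self, world_size: int):
        self._barrier = threading.Barrier(world_size)
        self._slot: Optional[bytes] = None

    def bcast(self, buffer: bytearray, root_rank: int, rank: int) -> None:
        if rank == root_rank:
            self._slot = bytes(buffer)  # the root publishes its buffer contents
        self._barrier.wait()            # nobody reads until the root has published
        if rank != root_rank:
            buffer[:] = self._slot      # non-root ranks overwrite their buffer in place
        self._barrier.wait()            # all copies finish before bcast returns


world_size = 4
comm = FakeComm(world_size)
# Every rank passes a 4-byte buffer; only rank 0's contents are meaningful.
buffers: List[bytearray] = [bytearray(b"ROOT") if r == 0 else bytearray(4) for r in range(world_size)]
threads = [
    threading.Thread(target=comm.bcast, args=(buffers[r], 0, r)) for r in range(world_size)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print([bytes(b) for b in buffers])  # [b'ROOT', b'ROOT', b'ROOT', b'ROOT']
```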

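2.1.7 Summary

To summarize the functions:

  • _bcast_model broadcasts the model;
  • bcast_object broadcasts objects;
  • broadcast_variables broadcasts variables;
  • Broadcasting an object differs from broadcasting variables in that the object must be serialized and deserialized;
  • _broadcast_model calls broadcast_variables to broadcast the model parameters;
  • broadcast_variables calls broadcast_group; in graph mode, broadcast_group uses tf.group() to combine the per-variable broadcast ops into a single op (in eager mode it wraps a loop in tf.function instead).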

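2.2 Usage

2.2.1 HorovodInternalError

When a HorovodInternalError is caught, the state is restored and then synchronized by broadcast: once the new communicator has been built, the worker with rank 0 broadcasts its model to the other workers.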

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync()  # broadcast sync happens here, i.e. TensorFlowKerasState.sync

                try:
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore()  # exception caught: roll back, then continue the while loop
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync

                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper
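
The control flow of run_fn can be exercised on its own. In the sketch below, `HorovodInternalError` and `HostsUpdatedInterrupt` are local stand-ins (not imported from Horovod), `FakeState` only records which hooks were called, and the notification-manager calls are dropped. The training function fails once with each exception and then succeeds, which makes the order of sync, restore and reset visible.

```python
import functools
from typing import List


class HorovodInternalError(RuntimeError):
    """Local stand-in for the Horovod exception of the same name."""


class HostsUpdatedInterrupt(RuntimeError):
    """Local stand-in; skip_sync has the meaning used by run_fn."""

    def __init__(self, skip_sync: bool):
        super().__init__(skip_sync)
        self.skip_sync = skip_sync


class FakeState:
    def __init__(self):
        self.calls: List[str] = []

    def sync(self):
        self.calls.append("sync")      # real code: broadcast from rank 0

    def restore(self):
        self.calls.append("restore")   # real code: roll back to the last commit

    def on_reset(self):
        self.calls.append("on_reset")


def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        skip_sync = False
        while True:
            if not skip_sync:
                state.sync()
            try:
                return func(state, *args, **kwargs)
            except HorovodInternalError:
                state.restore()            # live state may be corrupt: go back to the commit
                skip_sync = False
            except HostsUpdatedInterrupt as e:
                skip_sync = e.skip_sync    # live state is valid: keep it
            reset()                        # real code: re-initialize the Horovod context
            state.on_reset()
    return wrapper


failures = [HorovodInternalError(), HostsUpdatedInterrupt(skip_sync=True)]
resets: List[str] = []


def train(state):
    state.calls.append("train")
    if failures:
        raise failures.pop(0)
    return "done"


train = run_fn(train, reset=lambda: resets.append("reset"))
state = FakeState()
result = train(state)
print(result, state.calls)
```

After the HostsUpdatedInterrupt with skip_sync=True, the loop goes straight back into training without calling sync, which is why "sync" appears only twice in the recorded calls.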

The flow is as follows:

  Worker rank 0                               Worker rank n
        +                                         +
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 Catch HorovodInternalError                       |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
       sync                                       |
        |                                         |
        |                                         |    
        v                                         |
_broadcast_model(model)                           |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 broadcast_variables(model.variables)             |
                                                  |
 broadcast_variables(optimizer.variables)         |
                                                  |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
  broadcast_group                                 |
        +                                         |
        |                                         |
        |                                         |
        |                                         |
        v                                         |
 MPI_LIB.horovod_broadcast  +-------------------> |
        +                                         |
        |                                         |
        |                                         |
        v                                         v

2.2.2 HostsUpdatedInterrupt

Here the object broadcast synchronizes state across all workers so that they raise HostsUpdatedInterrupt at the same time.

How is it used?

In WorkerNotificationService._handle, self._manager.handle_hosts_updated(req.timestamp, req.res) is called to deliver the update.

In WorkerNotificationManager.handle_hosts_updated, each registered state is notified in turn.

def handle_hosts_updated(self, timestamp, update_res):
    for listener in self._listeners:
        listener.on_hosts_updated(timestamp, update_res)

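The State side can be seen in several of its methods:

  • on_hosts_updated: called when hosts change; it puts a message onto the _host_messages queue;
  • commit: called periodically by the user; it saves the state and checks for host changes;
  • check_host_updates: reads messages from _host_messages and accumulates the updates. As the comment in the method says, it synchronizes this state across all workers so that they raise the exception at the same time. The synchronization uses _bcast_object.

The code of check_host_updates is: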

def check_host_updates(self):
    """Checks that a notification has been sent indicating that hosts can be added or will be removed.

    Raises a `HostsUpdatedInterrupt` if such a notification has been received.
    """
    # Iterate through the update messages sent from the server. If the update timestamp
    # is greater than the last update timestamp, then trigger a HostsUpdatedException.
    last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
    all_update = HostUpdateResult.no_update
    while not self._host_messages.empty():
        timestamp, update = self._host_messages.get()
        if timestamp > last_updated_timestamp:
            last_updated_timestamp = timestamp
            all_update |= update

    # In order to ensure all workers raise the exception at the same time, we need to sync
    # the updated state across all the workers.
    # TODO(travis): this should be a max allreduce to account for changes in rank 0
    # The broadcast happens here: every rank adopts rank 0's values
    prev_timestamp, self._last_updated_timestamp, all_update = \
        self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

    # At this point, updated state is globally consistent across all ranks.
    if self._last_updated_timestamp > prev_timestamp:
        raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
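
Why the broadcast matters can be seen in a small simulation. For illustration only, it assumes a local `HostUpdateResult` IntFlag (no_update=0, removed=1, added=2, mixed=3) and a local `HostsUpdatedInterrupt`. check_host_updates is split into a local `drain` step and an `apply` step, and the "broadcast" is simply every rank calling `apply` with rank 0's tuple. Only rank 0 has received messages, yet all three ranks raise together because they adopt rank 0's view.

```python
import queue
from enum import IntFlag
from typing import List, Tuple


class HostUpdateResult(IntFlag):
    """Local stand-in; values assumed for illustration."""
    no_update = 0
    removed = 1
    added = 2
    mixed = 3  # removed | added


class HostsUpdatedInterrupt(Exception):
    def __init__(self, skip_sync: bool):
        super().__init__(skip_sync)
        self.skip_sync = skip_sync


class WorkerState:
    def __init__(self):
        self._host_messages: "queue.Queue[Tuple[int, HostUpdateResult]]" = queue.Queue()
        self._last_updated_timestamp = 0

    def drain(self):
        """Local half of check_host_updates: fold queued messages into one summary."""
        prev = last = self._last_updated_timestamp
        all_update = HostUpdateResult.no_update
        while not self._host_messages.empty():
            timestamp, update = self._host_messages.get()
            if timestamp > last:
                last = timestamp
                all_update |= update
        return prev, last, all_update

    def apply(self, synced):
        """Second half: adopt the broadcast tuple, raise if it carries a newer update."""
        prev, self._last_updated_timestamp, all_update = synced
        if self._last_updated_timestamp > prev:
            raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)


def check_all(states: List[WorkerState]) -> List[Tuple[int, bool]]:
    views = [s.drain() for s in states]   # every rank computes its own view...
    root_view = views[0]                  # ...but only rank 0's survives the broadcast
    raised = []
    for rank, s in enumerate(states):
        try:
            s.apply(root_view)
        except HostsUpdatedInterrupt as e:
            raised.append((rank, e.skip_sync))
    return raised


states = [WorkerState() for _ in range(3)]
states[0]._host_messages.put((10, HostUpdateResult.added))
states[0]._host_messages.put((12, HostUpdateResult.removed))
print(check_all(states))   # [(0, False), (1, False), (2, False)]
print(check_all(states))   # [] : nothing new since timestamp 12
```

skip_sync is True only when the accumulated update is exactly "removed"; here hosts were both added and removed, so every rank gets skip_sync=False.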

The flow is as follows:

+---------------------------+      +--------------+            +-------------+
|Catch HostsUpdatedInterrupt|      | Worker rank 1|            |Worker rank n|
+---------+-----------------+      +-------+------+            +----+--------+
          |                                |                        |
          |                                |                        |
          |                                |                        |
          v                                |                        |
                                           |                        |
 WorkerNotificationService                 |                        |
          +                                |                        |
          |                                |                        |
          |                                |                        |
          |                                |                        |
          v                                |                        |
                                           |                        |
manager.handle_hosts_updated+------------> |                        |
                                           |                        |
                                           |                        |
                                           v                        |
                                                                    |
                                   on_hosts_updated                 |
                                           +                        |
                                           |                        |
                                           |                        |
                                           |                        |
                                   check_host_updates               |
                                           |                        |
                                           |                        |
                                           |                        |
                                           |                        |
                                           v                        |
                                                                    |
                                   broadcast_object                 |
                                           +                        |
                                           |                        |
                                           |                        |
                                           |                        |
                                           |                        |
                                           v                        |
                                                                    |
                                   MPI_LIB.horovod_broadcast +----> |
                                           +                        |
                                           |                        |
                                           |                        |
                                           v                        v

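0x03 Notification Mechanism

The diagram above uses manager.handle_hosts_updated, where manager is the WorkerNotificationManager.

So let's continue with WorkerNotificationManager, Horovod's notification mechanism.

3.1 Creating WorkerNotificationManager

Each worker process has exactly one WorkerNotificationManager and one WorkerNotificationService: the manager is a module-level instance, and it registers itself under its hostname and local_rank.

Note: ElasticDriver acts as the client. It sends messages to these WorkerNotificationService instances, which triggers the corresponding actions in WorkerNotificationManager.

The instance is created by the following code in horovod/common/elastic.py: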

notification_manager = WorkerNotificationManager()

WorkerNotificationManager is defined as follows:

class WorkerNotificationManager(object):
    def __init__(self):
        self._lock = threading.Lock()
        self._service = None  # created lazily by init()
        self._listeners = set()

3.2 Initialization

WorkerNotificationManager is initialized before the user's code starts.

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        # Initialize WorkerNotificationManager
        notification_manager.init()
        # Register this worker's state with notification_manager
        notification_manager.register_listener(state)
        # ... (rest of wrapper as in 2.2.1)

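WorkerNotificationManager.init works as follows:

  • If _service already exists, return immediately; this guarantees a single WorkerNotificationService per worker process;
  • Read the rendezvous information (address, port, etc.) from environment variables;
  • Create a WorkerNotificationService and assign it to _service;
  • Use put_data_into_kvstore to send this worker's service addresses and secret key to the rendezvous server, keyed by hostname and local_rank (this is what later allows a WorkerNotificationClient to be created for it);
  • Note: the rendezvous server stores each worker's address and the rank assigned to it in the logical communication ring; worker processes use the rendezvous to build a new communicator.
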
def init(self, rendezvous_addr=None, rendezvous_port=None,
         nic=None, hostname=None, local_rank=None):
    with self._lock:
        if self._service:
            return

        # Read the rendezvous information (address, port, etc.) from environment variables
        rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
        rendezvous_port = rendezvous_port if rendezvous_port is not None else \
            int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
        nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
        hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
        local_rank = local_rank if local_rank is not None else \
            int(os.environ.get(HOROVOD_LOCAL_RANK))

        secret_key = secret.make_secret_key()
        self._service = WorkerNotificationService(secret_key, nic, self)

        value = (self._service.addresses(), secret_key)
        # Register this worker's service addresses and secret key with the rendezvous server
        put_data_into_kvstore(rendezvous_addr,
                              rendezvous_port,
                              PUT_WORKER_ADDRESSES,
                              self._create_id(hostname, local_rank),
                              value)
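
The "return early if _service already exists" guard, taken under a lock, is what makes init safe to call repeatedly. The sketch below isolates that pattern. `NotificationManagerSketch`, `FakeService` and the `registrations` list (standing in for the put_data_into_kvstore call) are hypothetical names, and the `hostname:local_rank` key format is assumed from the way the rendezvous handler splits the key.

```python
import threading
from typing import List, Optional


class FakeService:
    """Hypothetical stand-in for WorkerNotificationService; counts constructions."""
    created = 0

    def __init__(self):
        FakeService.created += 1


class NotificationManagerSketch:
    def __init__(self):
        self._lock = threading.Lock()
        self._service: Optional[FakeService] = None  # nothing starts at import time
        self.registrations: List[str] = []           # stands in for put_data_into_kvstore

    def init(self, hostname: str, local_rank: int) -> None:
        with self._lock:
            if self._service:        # already initialized in this process: do nothing
                return
            self._service = FakeService()
            self.registrations.append(f"{hostname}:{local_rank}")


manager = NotificationManagerSketch()
threads = [threading.Thread(target=manager.init, args=("host-a", 0)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
manager.init("host-a", 0)  # e.g. a second call to the decorated training function
print(FakeService.created, manager.registrations)  # 1 ['host-a:0']
```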

put_data_into_kvstore is implemented as follows.

def put_data_into_kvstore(addr, port, scope, key, value):
    try:
        url = "http://{addr}:{port}/{scope}/{key}".format(
            addr=addr, port=str(port), scope=scope, key=key
        )
        req = Request(url, data=codec.dumps_base64(value, to_ascii=False))
        req.get_method = lambda: "PUT"  # for urllib2 compatibility
        urlopen(req)
    except (HTTPError, URLError) as e:
        raise RuntimeError("Put data input KVStore server failed.", e)
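
The PUT request issued by put_data_into_kvstore can be tried against a throwaway server. In this sketch the server is a toy `KVHandler` built on `http.server`, not Horovod's rendezvous server; the codec is assumed to be pickle + base64 to mirror `codec.dumps_base64`, and the scope string `worker_addresses` and the address values are made up. It shows the `/{scope}/{key}` URL layout and that the stored body decodes back to the original value.

```python
import base64
import pickle
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.request import ProxyHandler, Request, build_opener

STORE = {}  # (scope, key) -> raw request body


class KVHandler(BaseHTTPRequestHandler):
    """Toy key-value store: PUT /<scope>/<key> stores the body."""

    def do_PUT(self):
        _, scope, key = self.path.split("/", 2)
        length = int(self.headers["Content-Length"])
        STORE[(scope, key)] = self.rfile.read(length)
        self.send_response(200)
        self.end_headers()

    def log_message(self, format, *args):  # silence per-request logging
        pass


def put_data(addr, port, scope, key, value):
    url = "http://{addr}:{port}/{scope}/{key}".format(addr=addr, port=port, scope=scope, key=key)
    body = base64.b64encode(pickle.dumps(value))  # assumed codec, see the lead-in
    req = Request(url, data=body, method="PUT")   # method= replaces the get_method override
    opener = build_opener(ProxyHandler({}))       # ignore *_proxy env vars for this local test
    with opener.open(req) as resp:
        return resp.status


server = HTTPServer(("127.0.0.1", 0), KVHandler)  # port 0: the OS picks a free port
threading.Thread(target=server.serve_forever, daemon=True).start()
port = server.server_address[1]

status = put_data("127.0.0.1", port, "worker_addresses", "host-a:0",
                  (["10.0.0.1:5000"], b"secret"))
server.shutdown()
server.server_close()

stored = pickle.loads(base64.b64decode(STORE[("worker_addresses", "host-a:0")]))
print(status, stored)  # 200 (['10.0.0.1:5000'], b'secret')
```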

3.3 Registering the State

Before the user's code starts, the worker also registers its state with notification_manager.

def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        # Initialize WorkerNotificationManager
        notification_manager.init()
        # Register this worker's state with notification_manager
        notification_manager.register_listener(state)

The code is:

def register_listener(self, listener):
    self._listeners.add(listener)

def remove_listener(self, listener):
    self._listeners.remove(listener)

3.4 WorkerNotificationService

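There is likewise only one WorkerNotificationService per worker process. It receives the HostsUpdatedRequest messages sent by its client and processes them. It inherits from network.BasicService, which means WorkerNotificationService is itself a server that can talk to its clients; recalling the various driver/client pairs covered earlier makes the mechanism easy to follow.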

class WorkerNotificationService(network.BasicService):
    NAME = 'worker notification service'

    def __init__(self, key, nic, manager):
        super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
                                                        key,
                                                        nic)
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return network.AckResponse()

        return super(WorkerNotificationService, self)._handle(req, client_address)
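
_handle is a type-based dispatcher: requests it recognizes are handled and acknowledged, and everything else is passed to the base class. A minimal sketch of that pattern, using local stand-ins for `HostsUpdatedRequest` and `AckResponse` and a base class that, as an assumption of this sketch only, simply rejects unknown requests:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class HostsUpdatedRequest:  # local stand-in carrying the two fields read by _handle
    timestamp: int
    res: int


class AckResponse:
    pass


class BaseServiceSketch:
    def _handle(self, req, client_address):
        raise NotImplementedError("unsupported request: " + type(req).__name__)


class NotificationServiceSketch(BaseServiceSketch):
    def __init__(self, manager):
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return AckResponse()
        return super()._handle(req, client_address)  # unknown types go to the base class


class RecordingManager:
    def __init__(self):
        self.seen: List[Tuple[int, int]] = []

    def handle_hosts_updated(self, timestamp, update_res):
        self.seen.append((timestamp, update_res))


manager = RecordingManager()
service = NotificationServiceSketch(manager)
resp = service._handle(HostsUpdatedRequest(timestamp=42, res=2), ("10.0.0.9", 5000))
print(type(resp).__name__, manager.seen)  # AckResponse [(42, 2)]
```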

The logic is as follows:

 +-------------------------------+                          +---------------------------+
 | WorkerNotificationManager     |                          | rendezvous                |
 |                               +------------------------> |                           |
 |                               |  put_data_into_kvstore   |                           |
 |                               |                          |                           |
 |                               |                          +---------------------------+
 | _listeners                    |
 |      +                        |                          +---------------------------+
 |      |         _service  +-----------------------------> | WorkerNotificationService |
 |      |                        |                          |                           |
 +-----------------------+-------+                          |                           |
        |                ^                                  |                           |
        |                |                                  |                           |
        |                |                                  |                           |
        |                +----------------------------------------+ _manager            |
        |                                                   |                           |
        v                                                   |                           |
                                                            +---------------------------+
[State 1, State 2, ......, State n]

3.5 WorkerNotificationClient

WorkerNotificationClient is the interface used to send messages to WorkerNotificationService.

ElasticDriver creates one WorkerNotificationClient per worker to deliver notifications.

class WorkerNotificationClient(network.BasicClient):
    def __init__(self, addresses, key, verbose, match_intf=False):
        super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
                                                       addresses,
                                                       key,
                                                       verbose,
                                                       match_intf=match_intf)

    def notify_hosts_updated(self, timestamp, update_res):
        self._send(HostsUpdatedRequest(timestamp, update_res))

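3.6 Creating the Client

3.6.1 When Registration Happens

Recall that WorkerNotificationManager.init sends a PUT request to the rendezvous server to register itself.

The registered information is what the client is later built from.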

put_data_into_kvstore(rendezvous_addr,
                      rendezvous_port,
                      PUT_WORKER_ADDRESSES,
                      self._create_id(hostname, local_rank),
                      value)

3.6.2 Registering the Worker

ElasticRendezvousHandler handles PUT_WORKER_ADDRESSES in _put_value and hands it to the driver.

# Note: this runs inside the rendezvous server
def _put_value(self, scope, key, value):
    if scope == PUT_WORKER_ADDRESSES:
        host, local_rank = key.split(':')
        addresses, secret_key = codec.loads_base64(value)
        self._put_worker_addresses(host, int(local_rank), addresses, secret_key)

    super(RendezvousHandler, self)._put_value(scope, key, value)

def _put_worker_addresses(self, host, local_rank, addresses, secret_key):
    # Hand the registration off to the driver
    driver.register_worker_server(host, local_rank, addresses, secret_key)

3.6.3 生成 WorkerNotificationClient

ElasticDriver creates one WorkerNotificationClient per worker to deliver notifications.

Note that ElasticDriver is the user of these WorkerNotificationClient objects: when it needs to notify workers, it calls them to send messages to the WorkerNotificationService on the corresponding host, which in turn makes the WorkerNotificationManager there act.

# This is inside ElasticDriver
def register_worker_server(self, host, slot, addresses, secret_key):
    self._worker_clients[(host, slot)] = WorkerNotificationClient(
        addresses, secret_key, self._verbose)

The logic is as follows:

 +-------------------------------+
 | WorkerNotificationManager     |                          +---------------------------+      +----------------------------+
 |                               |                          | rendezvous                |      | ElasticRendezvousHandler   |
 |                 init  +--------------------------------> |                        +-------> |                            |
 |                               |  1 put_data_into_kvstore |                           |      |                            |
 |                               |                          |                           |      |                            |
 |                               |                          +---------------------------+      +------------------+---------+
 | _listeners                    |                                                                                |
 |      +                        |                          +---------------------------+                         |
 |      |         _service  +-----------------------------> | WorkerNotificationService |                         |
 |      |                        |                          |                           |                         |
 +-----------------------+-------+                          |                           |                         |
        |                ^                                  |                           |                         |
        |                |                                  |                           |                         |
        |                |                                  |                           |                         |
        |                +----------------------------------------+ _manager            |                         |
        |                                                   |                           |                         |
        v                                                   |                           |                         |
                                                            +---------------------------+                         |
[State 1, State 2, ......, State n]                                                                               |
                                                                                                                  |
                      +-------------------------------------------------------------------------------------------+
                      |                             2 register_worker_server
                      |
                      |
                      v
                                                        3 new instance
 +-------------------------------+
 |ElasticDriver                  |             +----------------------------+     +---------------------------+
 |                               |             | WorkerNotificationClient 1 |     |WorkerNotificationClient n |
 |                               |             |                            |     |                           |
 |                               |             |                            |     |                           |
 |         _worker_clients  +--------------->  |     (host 1, slot 1)       | ... |     (host n, slot n)      |
 |                               |             |      For worker 1          |     |        For worker n       |
 |                               |             |                            |     |                           |
 +-------------------------------+             +----------------------------+     +---------------------------+


3.7 Usage

3.7.1 Detecting Updates

When ElasticDriver._discovery_thread detects a host change, it calls self._notify_workers_host_changes to send a notification.

def _notify_workers_host_changes(self, current_hosts, update_res):
    next_host_assignments = {}
    if current_hosts.count_available_slots() >= self._min_np:
        # Assignments are required to be stable via contract
        next_host_assignments, _ = self._get_host_assignments(current_hosts)

    if next_host_assignments == self.host_assignments:
        # Skip notifying workers when host changes would not result in changes of host assignments
        return

    coordinator_slot_info = self.get_coordinator_info()
    coordinator_client = self.get_worker_client(coordinator_slot_info)

    timestamp = _epoch_time_s()
    coordinator_client.notify_hosts_updated(timestamp, update_res)
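
The decision logic in _notify_workers_host_changes can be isolated as follows. This is a sketch, not the driver's code: `compute_assignments` is a hypothetical callable standing in for _get_host_assignments, `RecordingClient` stands in for the coordinator's WorkerNotificationClient, and `int(time.time())` replaces _epoch_time_s. It makes two points visible: no message is sent when the assignment would not change, and when fewer than min_np slots are available the next assignment stays empty, which differs from any non-empty current assignment, so a notification is sent.

```python
import time
from typing import Callable, Dict, List, Tuple

Assignments = Dict[str, List[int]]  # hypothetical shape: host -> local ranks


class RecordingClient:
    def __init__(self):
        self.sent: List[Tuple[int, int]] = []

    def notify_hosts_updated(self, timestamp, update_res):
        self.sent.append((timestamp, update_res))


def notify_workers_host_changes(available_slots: int, min_np: int,
                                compute_assignments: Callable[[], Assignments],
                                current: Assignments, coordinator: RecordingClient,
                                update_res: int) -> bool:
    next_assignments: Assignments = {}
    if available_slots >= min_np:
        next_assignments = compute_assignments()
    if next_assignments == current:
        return False  # workers would see no change: stay silent
    # Only the coordinator's client is called, as in the excerpt above.
    coordinator.notify_hosts_updated(int(time.time()), update_res)
    return True


current = {"host-a": [0, 1]}
client = RecordingClient()
print(notify_workers_host_changes(2, 2, lambda: {"host-a": [0, 1]}, current, client, 2))  # False
print(notify_workers_host_changes(4, 2, lambda: {"host-a": [0, 1], "host-b": [0, 1]},
                                  current, client, 2))                                    # True
print(notify_workers_host_changes(1, 2, lambda: {"host-a": [0]}, current, client, 1))     # True
print(len(client.sent))  # 2
```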

3.7.2 Getting the Client

get_worker_client returns the WorkerNotificationClient of a given worker, looked up by its host and slot (local_rank).

def get_worker_client(self, slot_info):
    return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))
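
register_worker_server and get_worker_client together form a registry keyed by (hostname, local_rank). The sketch below reproduces only that bookkeeping, with hypothetical `ClientSketch` and `SlotInfo` classes (the real client talks over the network; this one just stores its arguments). A worker that registers again from the same slot replaces its old entry, which is plain dict semantics.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class SlotInfo:  # stand-in with only the fields get_worker_client reads
    hostname: str
    local_rank: int


@dataclass
class ClientSketch:  # stand-in for WorkerNotificationClient: just remembers its arguments
    addresses: dict
    key: bytes


class DriverRegistrySketch:
    def __init__(self):
        self._worker_clients: Dict[Tuple[str, int], ClientSketch] = {}

    def register_worker_server(self, host, slot, addresses, secret_key):
        # One client per (host, slot); registering the same slot again replaces the entry.
        self._worker_clients[(host, slot)] = ClientSketch(addresses, secret_key)

    def get_worker_client(self, slot_info) -> Optional[ClientSketch]:
        return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))


driver = DriverRegistrySketch()
driver.register_worker_server("host-a", 0, {"eth0": [("10.0.0.1", 5000)]}, b"k1")
driver.register_worker_server("host-a", 1, {"eth0": [("10.0.0.1", 5001)]}, b"k2")
driver.register_worker_server("host-a", 0, {"eth0": [("10.0.0.1", 6000)]}, b"k3")  # re-register

print(driver.get_worker_client(SlotInfo("host-a", 0)).key)  # b'k3'
print(driver.get_worker_client(SlotInfo("host-b", 0)))      # None
```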

3.7.3 Sending HostsUpdatedRequest

notify_hosts_updated sends a HostsUpdatedRequest.

class WorkerNotificationClient(network.BasicClient):
    def __init__(self, addresses, key, verbose, match_intf=False):
        super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
                                                       addresses,
                                                       key,
                                                       verbose,
                                                       match_intf=match_intf)

    def notify_hosts_updated(self, timestamp, update_res):
        self._send(HostsUpdatedRequest(timestamp, update_res))

3.7.4 Handling HostsUpdatedRequest

WorkerNotificationService handles HostsUpdatedRequest by calling into WorkerNotificationManager.

class WorkerNotificationService(network.BasicService):
    NAME = 'worker notification service'

    def __init__(self, key, nic, manager):
        super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
                                                        key,
                                                        nic)
        self._manager = manager

    def _handle(self, req, client_address):
        if isinstance(req, HostsUpdatedRequest):
            self._manager.handle_hosts_updated(req.timestamp, req.res)
            return network.AckResponse()

        return super(WorkerNotificationService, self)._handle(req, client_address)

3.7.5 WorkerNotificationManager

So when hosts are updated, WorkerNotificationManager.handle_hosts_updated runs as shown below and ultimately calls each state's on_hosts_updated.

def handle_hosts_updated(self, timestamp, update_res):
    for listener in self._listeners:  # iterate over the registered states
        listener.on_hosts_updated(timestamp, update_res)

The State implementation is:

def on_hosts_updated(self, timestamp, update_res):
    self._host_messages.put((timestamp, update_res))
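
The hand-off from the notification service to the training loop is a plain observer combined with a thread-safe queue: handle_hosts_updated fans the update out to every registered state, and on_hosts_updated only enqueues it, leaving the real work to check_host_updates on the training thread. A minimal sketch with hypothetical `ManagerSketch` and `StateSketch` classes, where a separate thread plays the role of the notification service:

```python
import queue
import threading
from typing import Set


class StateSketch:
    def __init__(self):
        self._host_messages: queue.Queue = queue.Queue()  # thread-safe hand-off

    def on_hosts_updated(self, timestamp, update_res):
        self._host_messages.put((timestamp, update_res))  # cheap: only enqueue


class ManagerSketch:
    def __init__(self):
        self._listeners: Set[StateSketch] = set()

    def register_listener(self, listener):
        self._listeners.add(listener)

    def remove_listener(self, listener):
        self._listeners.remove(listener)

    def handle_hosts_updated(self, timestamp, update_res):
        for listener in self._listeners:  # fan out to every registered state
            listener.on_hosts_updated(timestamp, update_res)


manager = ManagerSketch()
state = StateSketch()
manager.register_listener(state)

# A separate thread stands in for the service that receives HostsUpdatedRequest.
t = threading.Thread(target=manager.handle_hosts_updated, args=(100, 2))
t.start()
t.join()

# Later, on the training thread, check_host_updates would drain the queue.
print(state._host_messages.get_nowait())  # (100, 2)
manager.remove_listener(state)
```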

The logic is shown below:

                                                         +-----------------------------v
                                                         ^        thread loop          |
                                                         |                             |
                                        +----------------+----------------------+      |
                                        |  ElasticDriver._discovery_thread      |      |
       1 _notify_workers_host_changes   |                                       |      |
                                        |                                       |      |
                     +------------------+                                       |      |
                     |                  |                                       |      |
                     |                  |   HostManager.update_available_hosts  |      |
                     |                  |                                       |      |
                     |                  +-----------------+---------------------+      |
                     |                                    ^                            |
                     |                                    |                            |
                     |                                    |                            |
                     |                                    +----------<---------------+ v
                     v

+---------------------------+ 2 HostsUpdatedRequest  +----------------------------+ handle_hosts_updated +----------------------------+
|                           |                        |                            |                      |                            |
| WorkerNotificationClient  +----------------------> |  WorkerNotificationService | +------------------> |  WorkerNotificationManager |
|                           |                        |                            |                      |                            |
+---------------------------+                        +----------------------------+                      +------+---------------------+
                                                                                                                |
                                                                                                                |
                                                                                                                | on_hosts_updated
                                                                                                                |
                                                                                                                v
                                                                                                  +-----------------------+
                                                                                                  |  State      |         |
                                                                                                  |             | put     |
                                                                                                  |             v         |
                                                                                                  |     _host_messages    |
                                                                                                  +-----------------------+


3.7.6 Processing the update

Updates are checked only when the user calls commit, which in turn calls check_host_updates:

def commit(self):
    self.save()
    self.check_host_updates()

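check_host_updates drains _host_messages looking for new messages. It keeps only messages newer than the last processed timestamp, broadcasts the result from rank 0 so that every worker reaches the same decision, and raises a HostsUpdatedInterrupt exception if a host change was found.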

def check_host_updates(self):
    # Iterate through the update messages sent from the server. If the update timestamp
    # is greater than the last update timestamp, then trigger a HostsUpdatedException.
    last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
    all_update = HostUpdateResult.no_update
    while not self._host_messages.empty():
        timestamp, update = self._host_messages.get()
        if timestamp > last_updated_timestamp:
            last_updated_timestamp = timestamp
            all_update |= update

    # In order to ensure all workers raise the exception at the same time, we need to sync
    # the updated state across all the workers.
    # TODO(travis): this should be a max allreduce to account for changes in rank 0
    prev_timestamp, self._last_updated_timestamp, all_update = \
        self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

    # At this point, updated state is globally consistent across all ranks.
    if self._last_updated_timestamp > prev_timestamp:
        raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
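Why the broadcast matters can be shown with a small single-process simulation. Assumptions: `HostUpdateResult` below mirrors the bit-flag layout Horovod uses (`removed` and `added` as separate bits, so `|=` accumulates both kinds of change); `WorkerSketch` is hypothetical and splits check_host_updates into a local step and an apply step; and `bcast_from_rank0` replaces the real `_bcast_object` collective by simply handing rank 0's tuple to every worker. Worker 1 has not yet received the notifications that worker 0 has, yet both reach the same decision.

```python
import queue
from enum import IntFlag


class HostUpdateResult(IntFlag):
    no_update = 0
    removed = 1
    added = 2
    mixed = removed | added


class HostsUpdatedInterrupt(Exception):
    def __init__(self, skip_sync):
        super().__init__(skip_sync)
        self.skip_sync = skip_sync


class WorkerSketch:
    def __init__(self, rank):
        self.rank = rank
        self._last_updated_timestamp = 0
        self._host_messages = queue.Queue()

    def local_view(self):
        # Same accumulation as check_host_updates: keep the newest timestamp
        # and OR together the kinds of every newer update.
        last = prev = self._last_updated_timestamp
        all_update = HostUpdateResult.no_update
        while not self._host_messages.empty():
            timestamp, update = self._host_messages.get()
            if timestamp > last:
                last = timestamp
                all_update |= update
        return prev, last, all_update

    def apply(self, synced):
        prev, self._last_updated_timestamp, all_update = synced
        if self._last_updated_timestamp > prev:
            raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)


def bcast_from_rank0(views):
    # Stand-in for _bcast_object: every rank receives rank 0's value.
    return [views[0]] * len(views)


workers = [WorkerSketch(0), WorkerSketch(1)]
workers[0]._host_messages.put((5, HostUpdateResult.added))
workers[0]._host_messages.put((6, HostUpdateResult.removed))
# workers[1] has received nothing yet.

synced = bcast_from_rank0([w.local_view() for w in workers])
outcomes = []
for worker, value in zip(workers, synced):
    try:
        worker.apply(value)
        outcomes.append("continue")
    except HostsUpdatedInterrupt as e:
        outcomes.append(f"interrupt(skip_sync={e.skip_sync})")
print(outcomes)
# ['interrupt(skip_sync=False)', 'interrupt(skip_sync=False)']
```

Without the broadcast, worker 1 would apply its own view `(0, 0, no_update)` and keep training while worker 0 raises, which would likely leave worker 1 stuck in its next collective. Because an addition and a removal were both seen, `all_update` is `mixed`, so the interrupt's flag is `False`.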

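When a worker process hits a HorovodInternalError, or a HostsUpdatedInterrupt because hosts were added or removed, it catches the exception and calls reset to recover. This links the earlier and later parts of the flow together.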

[Figure: the complete broadcast and notification flow in detail]
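How the two exceptions feed into reset can be sketched as a retry loop. This is a simplified, hypothetical version: `elastic_run_sketch`, `RecordingState`, and the two exception classes are stand-ins, not Horovod's code. It encodes the behavior described in section 0x01: on HorovodInternalError the state is rolled back to the last commit; on HostsUpdatedInterrupt the live parameters are kept; in both cases reset() rebuilds the communication context before rank 0's state is broadcast again. Skipping the re-broadcast when `skip_sync` is true is an assumption inferred from the `all_update == HostUpdateResult.removed` argument in check_host_updates: if hosts were only removed, the surviving workers should already hold identical live state.

```python
class HorovodInternalError(RuntimeError):
    """Stand-in for Horovod's HorovodInternalError."""


class HostsUpdatedInterrupt(RuntimeError):
    """Stand-in for Horovod's HostsUpdatedInterrupt."""

    def __init__(self, skip_sync):
        super().__init__(skip_sync)
        self.skip_sync = skip_sync


class RecordingState:
    """Hypothetical State that logs which recovery steps ran."""

    def __init__(self):
        self.log = []

    def sync(self):
        self.log.append("sync")     # rank 0 broadcasts its state to all ranks

    def restore(self):
        self.log.append("restore")  # roll back to the last committed state


def elastic_run_sketch(train, state, reset):
    skip_sync = False
    while True:
        if not skip_sync:
            state.sync()
        try:
            return train(state)
        except HorovodInternalError:
            # A collective failed mid-step: only the committed copy is trusted.
            state.restore()
            skip_sync = False
        except HostsUpdatedInterrupt as e:
            # Hosts changed between steps: keep the live parameters.
            skip_sync = e.skip_sync
        reset()  # re-run rendezvous and rebuild the communication context


failures = iter([HostsUpdatedInterrupt(skip_sync=True), HorovodInternalError()])


def train(state):
    err = next(failures, None)
    if err is not None:
        raise err
    return "done"


state = RecordingState()
result = elastic_run_sketch(train, state, reset=lambda: state.log.append("reset"))
print(result, state.log)
# done ['sync', 'reset', 'restore', 'reset', 'sync']
```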

This concludes our walkthrough of the broadcast and notification mechanism. The next article covers how workers operate.

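0xEE Personal Information

★★★★★★ Thoughts on Life and Technology ★★★★★★

WeChat official account: 羅西的思考

If you would like timely notifications of new articles, or want to see the technical materials I recommend, please follow the account.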
