[源碼解析] 深度學習分佈式訓練框架 horovod (4) --- 網絡基礎 & Driver

[源碼解析] 深度學習分佈式訓練框架 horovod (4) --- 網絡基礎 & Driver

0x00 摘要

Horovod 是Uber於2017年發佈的一個易於使用的高性能的分佈式訓練框架,在業界得到了廣泛應用。

本系列將通過源碼分析來帶領大家瞭解 Horovod。本文是系列第四篇,看看如何獲取 host 之間的路由等網絡信息。

前面幾篇鏈接如下:

[源碼解析] 深度學習分佈式訓練框架 Horovod (1) --- 基礎知識

[源碼解析] 深度學習分佈式訓練框架 horovod (2) --- 從使用者角度切入

[源碼解析] 深度學習分佈式訓練框架 horovod (3) --- Horovodrun背後做了什麼

0x01 引子

在 horovod/runner/launch.py 文件中,_run_static 函數中使用 driver_service.get_common_interfaces 來獲取路由信息等。

def _run_static(args):
    nics = driver_service.get_common_interfaces(settings, all_host_names,
                                                remote_host_names, fn_cache)

因爲這部分比較複雜( Driver 的概念很類似 Spark 之中 Driver 的概念),所以本文我們單獨來分析。

本文的分析問題點是:

  • 爲什麼要知道路由信息?
  • 當有多個host時候,horovod如何處理?
  • 如何找到路由信息?
  • 怎麼互相交互?
  • (後文會詳細分析)SparkDriverService,SparkTaskService,ElasticDriver, Worker 都有什麼區別和聯繫?

本文重點分析 HorovodRunDriverService 和 HorovodRunTaskService 相關。

先給出一個圖例,大家可以有些概念。

0x02 總體架構

從註釋可知,get_common_interfaces 完成了獲得路由信息(所有host之間的共有路由接口集合)的功能,主要是調用 _driver_fn 來完成相關工作。

def get_common_interfaces(settings, all_host_names, remote_host_names=None, fn_cache=None):
    '''
    Find the set of common and routed interfaces on all the hosts.
    '''

    # 得到遠端host地址
    if remote_host_names is None:
        remote_host_names = network.filter_local_addresses(all_host_names)

    if len(remote_host_names) > 0:
        if settings.nics: # 如果參數有設定網絡接口,就使用
            # If args.nics is provided, we will use those interfaces. All the workers
            # must have at least one of those interfaces available.
            nics = settings.nics
        else:
            # Find the set of common, routed interfaces on all the hosts (remote
            # and local) and specify it in the args to be used by NCCL. It is
            # expected that the following function will find at least one interface
            # otherwise, it will raise an exception.

            local_host_names = set(all_host_names) - set(remote_host_names)
            # 獲取其他host的網絡接口
            nics = _driver_fn(all_host_names, local_host_names, settings, fn_cache=fn_cache)

    else:
        nics = get_local_interfaces(settings) # 獲取本地的網絡接口
    return nics

2.1 get_local_interfaces

此函數比較簡單,目的是獲取本地的網絡接口。

def get_local_interfaces(settings):
    # If all the given hosts are local, find the interfaces with address
    # 127.0.0.1
    nics = set()
    for iface, addrs in net_if_addrs().items():
        if settings.nics and iface not in settings.nics:
            continue
        for addr in addrs:
            if addr.family == AF_INET and addr.address == '127.0.0.1':
                nics.add(iface)
                break

    return nics

2.2 _driver_fn

這是本文重點,獲取其他host 的網絡接口,_driver_fn 的作用是:

  • 啓動 service 服務;
  • 使用 driver.addresses() 獲取 Driver 服務的地址(使用self._addresses = self._get_local_addresses()完成);
  • 使用 _launch_task_servers(利用 Driver 服務的地址)在每個 worker 之中啓動 task 服務,然後 task 服務會在 service 服務中註冊;
  • 因爲是一個環形,每個 worker 會探測 worker index + 1 的所有網絡接口;
  • 最後 _run_probe 返回一個所有 workers 上的所有路由接口的交集;

代碼如下:

這裏需要注意的一點是:@cache.use_cache() 的使用:當第一次使用過之後,會把結果放入緩存。

@cache.use_cache()
def _driver_fn(all_host_names, local_host_names, settings):
    """
    launches the service service, launches the task service on each worker and
    have them register with the service service. Each worker probes all the
    interfaces of the worker index + 1 (in a ring manner) and only keeps the
    routed interfaces. Function returns the intersection of the set of all the
    routed interfaces on all the workers.
    :param all_host_names: list of addresses. for example,
        ['worker-0','worker-1']
        ['10.11.11.11', '10.11.11.12']
    :type all_host_names: list(string)
    :param local_host_names: host names that resolve into a local addresses.
    :type local_host_names: set
    :param settings: the object that contains the setting for running horovod
    :type settings: horovod.runner.common.util.settings.Settings
    :return: example: ['eth0', 'eth1']
    :rtype: list[string]
    """
    # Launch a TCP server called service service on the host running horovod
    # 啓動 service 服務
    num_hosts = len(all_host_names)
    driver = HorovodRunDriverService(num_hosts, settings.key, settings.nics)

    # Have all the workers register themselves with the service service.
    #(利用 Driver 服務的地址)在每個worker之中啓動 task 服務,然後task服務會在 service 服務中註冊
    _launch_task_servers(all_host_names, local_host_names,
                         driver.addresses(), settings)
    try:
        # 返回一個所有 workers 上的所有路由接口的交集
        return _run_probe(driver, settings, num_hosts)
    finally:
        driver.shutdown()

2.3 獲取路由接口

我們對 _run_probe 函數做進一步分析。

2.3.1 probe邏輯

_run_probe 函數就是當 所有 task 都啓動,註冊,probe 環中下一個worker 鄰居完成 之後,得到 接口集合。

  • 利用 wait_for_initial_registration 等待所有 task 完成註冊;
  • 對於所有 task,完成 task.notify_initial_registration_complete 通知;
  • 利用 driver.wait_for_task_to_task_address_updates 等待 每一個 worker probe 完成;
  • 利用 nics.intersection_update 得到接口集合;
def _run_probe(driver, settings, num_hosts):
       # wait for all the hosts to register with the service service.

    driver.wait_for_initial_registration(settings.start_timeout)
    tasks = [
        task_service.HorovodRunTaskClient(
            index,
            driver.task_addresses_for_driver(index),
            settings.key,
            settings.verbose) for index in range(
            num_hosts)]
    # Notify all the drivers that the initial registration is complete.
    for task in tasks:
        task.notify_initial_registration_complete()

    # Each worker should probe the interfaces of the next worker in a ring
    # manner and filter only the routed ones -- it should filter out
    # interfaces that are not really connected to any external networks
    # such as lo0 with address 127.0.0.1.
    driver.wait_for_task_to_task_address_updates(settings.start_timeout)

    # Determine a set of common interfaces for task-to-task communication.
    nics = set(driver.task_addresses_for_tasks(0).keys())
    for index in range(1, num_hosts):
        nics.intersection_update(
            driver.task_addresses_for_tasks(index).keys())

    return nics

2.3.2 等待函數

probe 利用 wait_for_initial_registration 等待所有 task 完成註冊,具體等待函數如下:

def wait_for_initial_registration(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._all_task_addresses) < self._num_proc:
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('tasks to start')
    finally:
        self._wait_cond.release()

def wait_for_task_to_task_address_updates(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._task_addresses_for_tasks) < self._num_proc:
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for(
                'tasks to update task-to-task addresses')
    finally:
        self._wait_cond.release()

0x03 基礎網絡服務

前面提到,Horovod Driver 的概念很類似 Spark 之中 Driver 的概念。Spark應用程序運行時主要分爲 Driver 和 Executor,Driver負載總體調度及UI展示,Executor負責Task運行。用戶的Spark應用程序運行在Driver上(某種程度上說,用戶的程序就是Spark Driver程序),經過Spark調度封裝成一個個Task,再將這些Task信息發給Executor執行,Task信息包括代碼邏輯以及數據信息,Executor不直接運行用戶的代碼。

對於 Horovod 來說:

  • HorovodRunDriverService 就是 Driver 的實現類。
  • HorovodRunTaskService 提供了 Task 部分服務功能,這些 task 需要註冊到 HorovodRunDriverService 之中。
  • 這套 driver & task 機制的底層由 "基礎網絡服務" 支撐。

所以我們就仔細分析下基礎網絡服務。

3.1 繼承關係

首先給出繼承關係,我們下面講解的 Driver 服務由 HorovodRunDriverService 提供,Task 服務由HorovodRunTaskService 提供。

這兩個類最終都繼承了 network.BasicService。

                            network.BasicService

                                  ^    ^
                                  |    |
              +-------------------+    +-------------+
              |                                      |
              +                                      +
driver_service.BasicDriverService       task_service.BasicTaskService
              ^                                      ^
              |                                      |
              |                                      |
              |                                      |
              +                                      +
    HorovodRunDriverService                HorovodRunTaskService

3.2 network.BasicService

BasicService 提供了一個網絡服務器功能。即通過find_port函數構建了一個ThreadingTCPServer,對外提供服務。

class BasicService(object):
    def __init__(self, service_name, key, nics):
        self._service_name = service_name
        self._wire = Wire(key)
        self._nics = nics
        self._server, _ = find_port(
            lambda addr: socketserver.ThreadingTCPServer(
                addr, self._make_handler()))
        self._server._block_on_close = True
        self._port = self._server.socket.getsockname()[1]
        self._addresses = self._get_local_addresses()
        self._thread = in_thread(target=self._server.serve_forever)

3.2.1 創建Server

創建服務器代碼如下,這裏是搜索一個隨機端口,然後設置:

def find_port(server_factory):
    min_port = 1024
    max_port = 65536
    num_ports = max_port - min_port
    start_port = random.randrange(0, num_ports)
    
    for port_offset in range(num_ports):
        try:
            port = min_port + (start_port + port_offset) % num_ports
            addr = ('', port)
            server = server_factory(addr)
            return server, port
        except Exception as e:
            pass

    raise Exception('Unable to find a port to bind to.')

3.2.2 Server功能

服務器就是基本的功能,比如獲取本server地址,處理 ping,網絡交互等。

def _make_handler(self):
    server = self

    class _Handler(socketserver.StreamRequestHandler):
        def handle(self):
            try:
                req = server._wire.read(self.rfile)
                resp = server._handle(req, self.client_address)

                # A tuple is the usual response object followed by a utf8 text stream
                if type(resp) == tuple:
                    (resp, stream) = resp
                    server._wire.write(resp, self.wfile)
                    server._wire.stream(stream, self.wfile)
                else:
                    server._wire.write(resp, self.wfile)
            except (EOFError, BrokenPipeError):
                # Happens when client is abruptly terminated, don't want to pollute the logs.
                pass

    return _Handler

def _handle(self, req, client_address):
    if isinstance(req, PingRequest):
        return PingResponse(self._service_name, client_address[0])

    raise NotImplementedError(req)

def _get_local_addresses(self):
    result = {}
    for intf, intf_addresses in psutil.net_if_addrs().items():
        if self._nics and intf not in self._nics:
            continue
        for addr in intf_addresses:
            if addr.family == socket.AF_INET:
                if intf not in result:
                    result[intf] = []
                result[intf].append((addr.address, self._port))
    return result

def addresses(self):
    return self._addresses.copy()

def shutdown(self):
    self._server.shutdown()
    self._server.server_close()
    self._thread.join()

def get_port(self):
    return self._port

3.3 network.BasicClient

HorovodRunDriverClient 和 HorovodRunTaskClient 這兩個類都繼承了network.BasicClient。

network.BasicClient 的作用就是連接 network.BasicService,與其交互。即 network.BasicClient 是一個操作接口

                             network.BasicClient

                                ^            ^
                                |            |
             +------------------+            +---------------+
             |                                               |
             +                                               |
                                                             +
driver_service.BasicDriverClient               task_service.BasicTaskClient

             ^                                               ^
             |                                               |
             |                                               |
             +                                               +
   HorovodRunDriverClient                           HorovodRunTaskClient

兩個主要 API 如下:

3.3.1 _probe

_probe 獲取 server 的網絡接口。

def _probe(self, addresses):
    result_queue = queue.Queue()
    threads = []
    for intf, intf_addresses in addresses.items():
        for addr in intf_addresses:
            thread = in_thread(target=self._probe_one, args=(intf, addr, result_queue))
            threads.append(thread)
    for t in threads:
        t.join()

    result = {}
    while not result_queue.empty():
        intf, addr = result_queue.get()
        if intf not in result:
            result[intf] = []
        result[intf].append(addr)
    return result

3.3.2 發送消息

_send 的作用是給server發送消息。

def _send(self, req, stream=None):
    """
    Sends the request and returns the response object.
    Streaming data response is transferred to the optional stream parameter.
    """
    # Since all the addresses were vetted, use the first one.
    addr = list(self._addresses.values())[0][0]
    return self._send_one(addr, req, stream)

3.4 總結

我們可以看到,network.BasicService 會提供了一個server,這個 Service 都是通過 network.BasicClient 來訪問。基於此,Horovod 的HorovodRunDriverService 和 HorovodRunTaskService 這兩個類就可以互相交互,進行溝通。

0x04 Driver 服務

Driver 服務由 HorovodRunDriverService 提供,其功能主要是維護維護各種 task 地址以及相應關係。具體各種 task 地址 就是 Task 服務 來註冊的

需要注意的是:HorovodRunDriverService 和 HorovodRunTaskService 都最終繼承了 network.BasicService,他們之間可以是異地運行交互

4.1 HorovodRunDriverService

HorovodRunDriverService 是對 BasicDriverService 的封裝。

HorovodRunDriverClient 是 其 訪問接口。

class HorovodRunDriverService(driver_service.BasicDriverService):
    NAME = 'horovod driver service'

    def __init__(self, num_hosts, key, nics):
        super(HorovodRunDriverService, self).__init__(num_hosts,
                                                      HorovodRunDriverService.NAME,
                                                      key, nics)

class HorovodRunDriverClient(driver_service.BasicDriverClient):
    def __init__(self, driver_addresses, key, verbose, match_intf=False):
        super(HorovodRunDriverClient, self).__init__(
            HorovodRunDriverService.NAME,
            driver_addresses,
            key,
            verbose,
            match_intf=match_intf)

4.2 BasicDriverService

BasicDriverService基類 主要就是 維護各種 task 地址以及相應關係

class BasicDriverService(network.BasicService):
    def __init__(self, num_proc, name, key, nics):
        super(BasicDriverService, self).__init__(name, key, nics)
        self._num_proc = num_proc
        self._all_task_addresses = {}
        self._task_addresses_for_driver = {}
        self._task_addresses_for_tasks = {}
        self._task_index_host_hash = {}
        self._task_host_hash_indices = {}
        self._wait_cond = threading.Condition()

這裏的各種 task 地址就是 Task 服務 註冊到 Driver 的數值。

可以看到裏面有各種關於地址的變量,爲了讓大家理解這些變量的作用,對於每一個變量我們舉例如下(這裏有些變量是專門爲 spark 設計,都放到基類裏面有點奇怪):

4.2.1 _all_task_addresses

本變量是記錄了所有 task 的地址,變量舉例如下:

self._all_task_addresses = { 
  1: { 
    'lo' : [('1.1.1.1', 12345)],
		'eth0' : [('10.10.10.01', 12345)]
	},
  0: { 
    'lo' : [('2.2.2.2', 54321)],
		'eth0' : [('10.10.10.02', 54321)]
	}  
}

本變量由 task 調用 RegisterTaskRequest 來註冊。

if isinstance(req, RegisterTaskRequest):
    self._wait_cond.acquire()
    try:
        assert 0 <= req.index < self._num_proc
        self._all_task_addresses[req.index] = req.task_addresses

使用時候,有幾個方式,舉例如下:

比如 all_task_addresses 獲取 self._all_task_addresses[index].copy() 來決定 ping /check 的下一個跳轉。

4.2.2 _task_addresses_for_driver

本變量是記錄了所有 task 的地址,但是網卡接口有多種,這裏選擇與 本 driver 地址匹配的地址

變量舉例如下:

self._task_addresses_for_driver = { 
  1: { 
		'eth0' : [('10.10.10.01', 12345)]
	},
  0: { 
		'eth0' : [('10.10.10.02', 54321)]
	} 
}

本變量由 task 調用 RegisterTaskRequest 來註冊。

# Just use source address for service for fast probing.
self._task_addresses_for_driver[req.index] = \
    self._filter_by_ip(req.task_addresses, client_address[0])

具體使用舉例如下:

def task_addresses_for_driver(self, index):
    self._wait_cond.acquire()
    try:
        return self._task_addresses_for_driver[index].copy()
    finally:
        self._wait_cond.release()

driver用這個地址來生成 其內部 task 變量

tasks = [
    task_service.HorovodRunTaskClient(
        index,
        driver.task_addresses_for_driver(index),
        settings.key,
        settings.verbose) for index in range(
        num_hosts)]

4.2.3 _task_addresses_for_tasks

該變量舉例如下:

self._task_addresses_for_tasks = { 
  1: { 
		'eth0' : [('10.10.10.01', 12345)]
	},
  0: { 
		'eth0' : [('10.10.10.02', 54321)]
	} 
}

本變量由RegisterTaskToTaskAddressesRequest註冊。

if isinstance(req, RegisterTaskToTaskAddressesRequest):
    self.register_task_to_task_addresses(req.index, req.task_addresses)
    return network.AckResponse()
  
def register_task_to_task_addresses(self, index, task_addresses):
    self._wait_cond.acquire()
    try:
        assert 0 <= index < self._num_proc
        self._task_addresses_for_tasks[index] = task_addresses # 這裏賦值
    finally:
        self._wait_cond.notify_all()
        self._wait_cond.release()  

該變量被 task 用來獲取 某個 task 的一套網絡接口,比如:

# Determine a set of common interfaces for task-to-task communication.
nics = set(driver.task_addresses_for_tasks(0).keys())

4.2.4 _task_index_host_hash

每一個 task 有一個對應的 host hash,該數值被 MPI 作爲 host name 來操作。

self._task_index_host_hash = { 
  1: { 
		'ip-10-10-10-01-dfdsfdsfdsfdsf2'
	},
  0: { 
		'ip-10-10-10-02-treterwrtqwer'
	} 
}

具體使用如下。這個函數是 spark 相關會使用,具體是逐一通知 spark task 進入下一階段

def task_indices(self):
    self._wait_cond.acquire()
    try:
        return list(self._task_index_host_hash.keys())
    finally:
        self._wait_cond.release()

或者使用如下,是爲了獲取某一個 host 對應的 host hash name

def task_index_host_hash(self, index):
    self._wait_cond.acquire()
    try:
        assert 0 <= index < self._num_proc
        return self._task_index_host_hash[index]
    finally:
        self._wait_cond.release()

4.2.5 _task_host_hash_indices

該變量舉例如下:

self._task_host_hash_indices = { 
  { 
		'ip-10-10-10-01-dfdsfdsfdsfdsf2' : [1]
	},
  { 
		'ip-10-10-10-02-treterwrtqwer' : [0]
	} 
}

具體是在註冊 RegisterTaskRequest 時候生成。

self._task_host_hash_indices[req.host_hash].append(req.index)

使用具體代碼是:

def task_host_hash_indices(self):
    self._wait_cond.acquire()
    try:
        return self._task_host_hash_indices.copy()
    finally:
        self._wait_cond.release()

具體是被 rsh 使用。rsh 就是在某一個 host 上,讓某一個 horovod rank 啓動。具體邏輯是:

  • 獲取某一個 host 上所有的 task indices ;
  • 利用 task_host_hash_indices 取出本進程 local rank 對應的 task index;
  • 取出在 driver 中 task index 對應保持的 task address;
  • 最後依據這個 task addresses 生成一個 SparkTaskClient,進行後續操作。
driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
task_indices = driver_client.task_host_hash_indices(host_hash)
task_index = task_indices[local_rank]
task_addresses = driver_client.all_task_addresses(task_index)
task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
task_client.stream_command_output(stdout, stderr)
task_client.run_command(command, env,
                        capture_stdout=stdout is not None,
                        capture_stderr=stderr is not None,
                        prefix_output_with_timestamp=prefix_output_with_timestamp)

4.3 總體邏輯

總體邏輯如下:

                               network.BasicService

                                     ^    ^
                                     |    |
                 +-------------------+    +-------------+
                 |                                      |
                 +                                      +
   driver_service.BasicDriverService       task_service.BasicTaskService
                 ^                                      ^
                 |                                      |
                 |                                      |
                 |                                      |
                 |                                      +
+----------------+------------------+         HorovodRunTaskService
| HorovodRunDriverService           |
|                                   |
|                                   |
|        _all_task_addresses        |
|                                   |
|    _task_addresses_for_driver     |
|                                   |
|       _task_addresses_for_tasks   |
|                                   |
|       _task_index_host_hash       |
|                                   |
|     _task_host_hash_indices       |
|                                   |
+-----------------------------------+

0x05 Task 服務

HorovodRunTaskService 提供了 Task 部分服務功能。整體邏輯是由幾個函數共同完成。

5.1 啓動具體服務

_launch_task_servers 用來啓動具體服務,其主要作用是:多線程運行,在每一個線程中,遠程運行 horovod.runner.task_fn

其中:

  • 傳入參數中,all_host_names 就是程序啓動時候配置的所有host,比如 ["1.1.1.1", "2.2.2.2"];
  • 使用了我們之前提到的 safe_shell_exec.execute 完成了安全運行保證;
  • 使用我們前文提到的 get_remote_command 完成了遠程命令的獲取,即在命令之前加上了 ssh -o PasswordAuthentication=no -o StrictHostKeyChecking=no等等配置;
  • 最終每個啓動的命令舉例如下: ssh -o PasswordAuthentication=no -o StrictHostKeyChecking=no 1.1.1.1 python -m horovod.runner.task_fn xxxxxxx
  • 使用 execute_function_multithreaded 在每一個 host 上運行,啓動 task 服務;

具體代碼如下:

def _launch_task_servers(all_host_names, local_host_names, driver_addresses,
                         settings):
    """
    Executes the task server and service client task for registration on the
    hosts.
    :param all_host_names: list of addresses. for example,
        ['worker-0','worker-1']
        ['10.11.11.11', '10.11.11.12']
    :type all_host_names: list(string)
    :param local_host_names: names that are resolved to one of the addresses
    of local hosts interfaces. For example,
        set(['localhost', '127.0.0.1'])
    :type local_host_names: set
    :param driver_addresses: map of interfaces and their address and port for
    the service. For example:
        {
            'lo': [('127.0.0.1', 34588)],
            'docker0': [('172.122.10.1', 34588)],
            'eth0': [('11.111.33.73', 34588)]
        }
    :type driver_addresses: map
    :param settings: the object that contains the setting for running horovod
    :type settings: horovod.runner.common.util.settings.Settings
    :return:
    :rtype:
    """

    def _exec_command(command):
        host_output = io.StringIO()
        try:
            # 完成了安全運行保證
            exit_code = safe_shell_exec.execute(command,
                                                stdout=host_output,
                                                stderr=host_output)
        finally:
            host_output.close()
        return exit_code

    args_list = []
    num_hosts = len(all_host_names)
    for index in range(num_hosts):
        host_name = all_host_names[index] # all_host_names 就是程序啓動時候配置的所有host,比如 ["1.1.1.1", "2.2.2.2"]
        command = \
            '{python} -m horovod.runner.task_fn {index} {num_hosts} ' \
            '{driver_addresses} {settings}' \
            .format(python=sys.executable,
                    index=codec.dumps_base64(index),
                    num_hosts=codec.dumps_base64(num_hosts),
                    driver_addresses=codec.dumps_base64(driver_addresses),
                    settings=codec.dumps_base64(settings))
        if host_name not in local_host_names:
            # 完成了遠程命令的獲取,即在命令之前加上了 `ssh -o PasswordAuthentication=no -o StrictHostKeyChecking=no`等等配置
            command = get_remote_command(command,
                                         host=host_name,
                                         port=settings.ssh_port,
                                         identity_file=settings.ssh_identity_file)

        args_list.append([command])
        
    # Each thread will use ssh command to launch the server on one task. If an
    # error occurs in one thread, entire process will be terminated. Otherwise,
    # threads will keep running and ssh session -- and the the task server --
    # will be bound to the thread. In case, the horovod process dies, all
    # the ssh sessions and all the task servers will die as well.
    
    # 使用 execute_function_multithreaded 在每一個 host 上運行,啓動 task 服務
    threads.execute_function_multithreaded(_exec_command,
                                           args_list,
                                           block_until_all_done=False)

5.2 具體服務邏輯

上段有:{python} -m horovod.runner.task_fn {index} {num_hosts} {driver_addresses} {settings}執行具體服務邏輯,所以我們介紹下 horovod.runner.task_fn

_task_fn 函數完成了

  • 生成了 HorovodRunTaskService 實例,賦值給 task;
  • 使用 HorovodRunDriverClient . register_task 來向 Driver 服務註冊task(自己)的地址;
  • 使用 HorovodRunDriverClient . register_task_to_task_addresses 來向 Driver 服務註冊自己在Ring上 下一個鄰居的地址;
  • 每一個 task 都做這個操作,最後就得到了在這個 ring cluster 之中的一個路由接口;

具體代碼如下:

def _task_fn(index, num_hosts, driver_addresses, settings):
  
    task = task_service.HorovodRunTaskService(index, settings.key, settings.nics)
    try:
        driver = driver_service.HorovodRunDriverClient(
            driver_addresses, settings.key, settings.verbose)
        # 向 Driver 服務註冊task(自己)的地址
        driver.register_task(index,
                             task.addresses(),
                             host_hash.host_hash())
        task.wait_for_initial_registration(settings.start_timeout)
        # Tasks ping each other in a circular fashion to determine interfaces
        # reachable within the cluster.
        next_task_index = (index + 1) % num_hosts
        next_task_addresses = driver.all_task_addresses(next_task_index)
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task = task_service.HorovodRunTaskClient(
            next_task_index,
            next_task_addresses,
            settings.key,
            settings.verbose,
            match_intf=True,
            attempts=10)
        # 向 Driver 服務註冊自己在Ring上 下一個鄰居的地址
        driver.register_task_to_task_addresses(next_task_index,
                                               next_task.addresses())
        # Notify the next task that the address checks are completed.
        next_task.task_to_task_address_check_completed()
        # Wait to get a notification from previous task that its address checks
        # are completed as well.
        task.wait_for_task_to_task_address_check_finish_signal(settings.start_timeout)

    finally:
        task.shutdown()


if __name__ == '__main__':
    index = codec.loads_base64(sys.argv[1])
    num_hosts = codec.loads_base64(sys.argv[2])
    driver_addresses = codec.loads_base64(sys.argv[3])
    settings = codec.loads_base64(sys.argv[4])

    _task_fn(index, num_hosts, driver_addresses, settings)

5.3 HorovodRunTaskService

HorovodRunTaskService 主要的作用是提供了兩個等待函數。因爲具體路由操作是需要彼此通知,所以需要互相等待

class HorovodRunTaskService(task_service.BasicTaskService):
    NAME_FORMAT = 'horovod task service #%d'

    def __init__(self, index, key, nics):
        super(HorovodRunTaskService, self).__init__(
            HorovodRunTaskService.NAME_FORMAT % index,
            index, key, nics)
        self.index = index
        self._task_to_task_address_check_completed = False

    def _handle(self, req, client_address):

        if isinstance(req, TaskToTaskAddressCheckFinishedSignal):
            self._wait_cond.acquire()
            try:
                self._task_to_task_address_check_completed = True
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()

            return TaskToTaskAddressCheckFinishedSignalResponse(self.index)

        return super(HorovodRunTaskService, self)._handle(req, client_address)

    def wait_for_task_to_task_address_check_finish_signal(self, timeout):
        self._wait_cond.acquire()
        try:
            while not self._task_to_task_address_check_completed:
                self._wait_cond.wait(timeout.remaining())
                timeout.check_time_out_for('Task to task address check')
        finally:
            self._wait_cond.release()


class HorovodRunTaskClient(task_service.BasicTaskClient):

    def __init__(self, index, task_addresses, key, verbose, match_intf=False, attempts=3):
        super(HorovodRunTaskClient, self).__init__(
            HorovodRunTaskService.NAME_FORMAT % index,
            task_addresses, key, verbose,
            match_intf=match_intf,
            attempts=attempts)
        self.index = index

    def task_to_task_address_check_completed(self):
        resp = self._send(TaskToTaskAddressCheckFinishedSignal(self.index))
        return resp.index

邏輯如下:

                                                         _driver_fn
                                                            +
                                                            |
                                                            |
                    +---------------------------------------+-------------------------------------v
                    |                                                                             |
                    |                                                                             v
                    |                                                                   _launch_task_servers
                    v                                                                             +
     driver = HorovodRunDriverService                                                             |
                    +                                                              +--------------+-------------------+
                    |                                                              |                                  |
                    |                                                              |                                  |
                    v                                                              v                                  v
+-------------------+---------------+                                    horovod.runner.task_fn    ......     horovod.runner.task_fn
| HorovodRunDriverService           |                                              +                                  +
|                                   |                                              |                                  |
|                                   |                                              |                                  |
|        _all_task_addresses        |                                              |                                  |
|                                   |                                              v                                  v
|    _task_addresses_for_driver     |          register_task           +-----------+---------------+          +-------+--------------------+
|                                   |                                  | HorovodRunTaskService     |          |  HorovodRunTaskService     |
|       _task_addresses_for_tasks   | <--------------------------------+                           |          |                            |
|                                   |                                  |                           |   wait   |                            |
|       _task_index_host_hash       |                                  |                           | <------> |                            |
|                                   | <--------------------------------+                           |          |                            |
|     _task_host_hash_indices       |  register_task_to_task_addresses |                           |          |                            |
|                                   |                                  +---------------------------+          +----------------------------+
+-----------------------------------+                                                  `

手機如下:

0x06 總結

本文總結如下:

  • 因爲 Horovod 分佈式訓練 涉及到多個 hosts,所以如果要彼此訪問,需要知道路由信息;
  • 當所有 task 都啓動,註冊,probe 環中下一個worker 鄰居完成 之後,DriverService 會得到路由信息(所有host之間的共有路由接口集合),返回給 Horovod 主體部分使用;
  • network.BasicService 提供了網絡服務功能;
  • XXXService 都是通過 XXXClient作爲接口才能訪問;
  • HorovodRunDriverService 和 HorovodRunTaskService 都最終繼承了 network.BasicService,他們之間可以是異地運行交互
  • HorovodRunTaskService 提供了 Task 部分服務功能,這些 task 需要註冊到 Driver 之中(和Spark思路類似)。
  • HorovodRunDriverService 是對 BasicDriverService 的封裝。BasicDriverService 就是 維護各種 task 地址以及相應關係,比如:
    • _all_task_addresses :記錄了所有 task 的地址;
    • _task_addresses_for_driver :記錄了所有 task 的地址,但是因爲網卡接口有多種,這裏選擇與 本driver 地址匹配的地址;
    • _task_addresses_for_tasks :用來給某一個 task 分配一個地址,同時獲取本 task 的一套網絡接口;
    • _task_index_host_hash :每一個 task 有一個對應的 host hash。這個函數是 spark 相關會使用,具體是逐一通知 spark task 進入下一階段。或者是爲了獲取某一個 host 對應的 host hash name
    • _task_host_hash_indices :具體是被 rsh 使用,由 rank 得到 在 driver 中 task index 對應保持的 task address;
  • SparkDriverService,SparkTaskService,ElasticDriver, Worker 都有什麼區別和聯繫?
    • HorovodRunDriverService 這裏只是用來得到路由信息,記錄各種 Task 地址;
    • SparkDriverService 除了記錄路由和地址之外,還提交執行任務(Command),因爲具體在哪一個Spark Executor啓動之後,SparkDriverService 就需要知道 對應 SparkTaskService 的地址,這樣才能知道提交到哪裏;
    • SparkTaskService 負責執行命令(拋棄了Spark Executor的邏輯,自己搞了一套),就是從 SparkDriverService 那裏獲得訓練函數,然後啓動 python 進程來執行;
    • ElasticDriver 做得更多,因爲還有彈性,需要容錯;

0xEE 個人信息

★★★★★★關於生活和技術的思考★★★★★★

微信公衆賬號:羅西的思考

如果您想及時得到個人撰寫文章的消息推送,或者想看看個人推薦的技術資料,敬請關注。

在這裏插入圖片描述

0xFF 參考

[源碼解析] 深度學習分佈式訓練框架 Horovod (1) --- 基礎知識

[源碼解析] 深度學習分佈式訓練框架 horovod (2) --- 從使用者角度切入

[源碼解析] 深度學習分佈式訓練框架 horovod (3) --- Horovodrun背後做了什麼

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章