[源碼解析] 深度學習分佈式訓練框架 horovod (11) --- on spark --- GLOO 方案

[源碼解析] 深度學習分佈式訓練框架 horovod (11) --- on spark --- GLOO 方案

0x00 摘要

Horovod 是Uber於2017年發佈的一個易於使用的高性能的分佈式訓練框架,在業界得到了廣泛應用。

本系列將通過源碼分析來帶領大家瞭解 Horovod。本文是系列第十一篇,看看horovod 如何運行在 spark 之上(GLOO實現)。

Horovod on Spark 具體有兩種底層實現:MPI,GLOO。因爲篇幅所限,本文介紹 GLOO 實現。爲了單篇可以成文,所以本文和上文有部分重複,望諒解。

本系列其他文章如下:

[源碼解析] 深度學習分佈式訓練框架 Horovod (1) --- 基礎知識

[源碼解析] 深度學習分佈式訓練框架 horovod (2) --- 從使用者角度切入

[源碼解析] 深度學習分佈式訓練框架 horovod (3) --- Horovodrun背後做了什麼

[源碼解析] 深度學習分佈式訓練框架 horovod (4) --- 網絡基礎 & Driver

[源碼解析] 深度學習分佈式訓練框架 horovod (5) --- 融合框架

[源碼解析] 深度學習分佈式訓練框架 horovod (6) --- 後臺線程架構

[源碼解析] 深度學習分佈式訓練框架 horovod (7) --- DistributedOptimizer

[源碼解析] 深度學習分佈式訓練框架 horovod (8) --- on spark

[源碼解析] 深度學習分佈式訓練框架 horovod (9) --- 啓動 on spark

[源碼解析] 深度學習分佈式訓練框架 horovod (10) --- run on spark

0x01 回顧

1.1 總體序列圖

我們首先要回顧下 Horovod on Spark 的總體序列圖,需要注意的是:這個總體序列圖之中,從mpi_run開始,是 mpi 相關的實現,但本文是Gloo方案,所以會從 mpi_run 那裏開始不同

img

1.2 總體邏輯

總體來說,Horovod on Spark 的總體邏輯分爲以下階段:

  • 啓動 SparkDriverService 服務,利用 _make_spark_thread 啓動 Spark task,然後 horovod 會等待啓動結束;
  • 多線程在 spark executor 之中啓動 spark task,每個task之中運行一個 SparkTaskService,SparkTaskService 會向 hovorod 主進程中的 SparkDriverTask 進行註冊,並且等待下一步運行啓動的指令;
  • Horovod 收到所有 task 結束的信息之後,通知各個 task,進入下一階段;
  • Horovod 調用 mpi_run (又利用到 mpirun_rsh.py)在每一個 spark executor 上啓動 orted(這裏是通過 SparkTaskService 來啓動 orted),以啓動 MPI cluster;
  • orted 在每一個 spark executor 之上運行訓練代碼;

前文已經分析了前面三個階段,本文繼續後面兩個階段的分析。

0x02 第四階段 : 啓動 Job

下面我們看看第四階段,就是如何運行訓練 job。

2.1 GLOO VS MPI

本文的問題點就是:Gloo 與 MPI 實現有何不同

2.1.1 MPI 麻煩之處

MPI 麻煩之處是因爲:

  • 通常 MPI 會通過 SSH 來連接 hosts,但是這種方式無法在 Spark Executor 之中啓動 Python funtion。
  • Orted 需要運行在 Spark Executor 之中,但是 mpirun 在啓動時候,沒辦法知道 Spark Executor 的 IP : PORT 這個組合,所以沒法直接啓動。
  • 因此 MPI 使用RPC 來啓動用戶代碼:
    • 通過 SparkDriverService 和 SparkTaskService 等交互纔可以知道這個 IP : PORT 組合信息,即,在 Spark Executor 之中啓動 SparkTaskService ,然後把 SparkTaskService 的 IP : PORT 註冊到 Horovod 主進程的 SparkDriverService 之中。
    • 使用 horovod.spark.driver.mpirun_rsh 來連接每個 Executor,然後 “remote shell” 到這些 executors 之中。
    • 直接使用 SparkTaskService 來啓動 orted。

2.1.2 Gloo關鍵點

我們看看Gloo的關鍵點,在普通模式下,Gloo方案會:

  • 會創建一個帶有 KVStore 的 RendezvousServer,driver 會將參與通信的 worker 的 ip 等信息存入 KVstore 中。
  • 然後 worker 就可以調用 gloo 來訪問 RendezvousServer 構造通信環了。

Horovod on Spark 之中,關鍵點就是:

  • 如何構造RendezvousServer,RendezvousServer如何知道Executor(或者類似實體)的 ip:port?
  • Executor上的 SparkTaskService 如何與 RendezvousServer 溝通,從而知道自己和鄰居的網絡信息?

讓我們從代碼中尋求下答案。

2.2 回顧啓動過程

我們首先要回顧下之前的啓動過程。

Horovod.spark.run 的邏輯是:

  • 處理各種配置,比如timeout,nice...;
  • 獲取 spark 信息,比如從 pyspark 之中獲取SparkContext;
  • 構建驅動 SparkDriverService(Spark driver service);
  • 利用 _make_spark_thread 來啓動 spark executor(以及在每一個 spark executor 之中啓動一個SparkTaskService),這樣就構建了 cluster;
  • 每個 SparkTaskService 會通過 driver_service.SparkDriverClient.register_task 來向 horovod 中的 Driver 註冊這就是關鍵之處,通過這裏 RendezvousServer 就可以知道 SparkTaskService 的 IP :PORT
  • 利用 _notify_and_register_task_addresses 等待所有 spark task 都結束;
  • 利用 _launch_job 啓動訓練
  • 利用 spark_thread.join 來收集訓練結果;

以上關鍵點是:SparkTaskService 本身內部有一個 http server,會把自己的IP:PORT 信息註冊到Driver之中。

2.3 _launch_job

我們從_launch_job 開始分析。

_launch_job 很簡單:

  • 首先 driver.get_common_interfaces 獲取網絡路由信息,這個網絡路由信息就將被 RendezvousServer 記錄下來,最終將被 Executor上的 SparkTaskService 利用;
  • 其次 調用 run_contoller 來啓動 job;
def _launch_job(use_mpi, use_gloo, settings, driver, env, stdout=None, stderr=None):
    nics = driver.get_common_interfaces()
    # 在 gloo_run 調用時候傳輸網絡路由信息。
    run_controller(use_gloo, lambda: gloo_run(settings, nics, driver, env, stdout, stderr),
                   use_mpi, lambda: mpi_run(settings, nics, driver, env, stdout, stderr),
                   False, lambda: None,
                   settings.verbose)

2.3 獲取路由信息

Driver 的 get_common_interfaces 與普通模式下的 get_common_interfaces 不同。因爲此時,Spark Executor 之中的 SparkTaskService 的信息已經保存在 Driver 之中,直接獲取即可

def get_common_interfaces(self):
    if self._nics is not None:
        return self._nics

    nics = None
    if len(self._task_addresses_for_tasks) > 0:
        # in Elastic Horovod on Spark with auto-scaling
        # keys in task_addresses are in range(max_np or proc_num)
        # but not all keys may exist, so we don't do for index in range(proc_num)
        indices = list(self._task_addresses_for_tasks.keys())
        nics = set(self._task_addresses_for_tasks[indices[0]].keys())
        for index in indices[1:]:
            nics.intersection_update(self._task_addresses_for_tasks[index].keys())

    return nics

2.4 run_controller

就是依據配置和編譯情況來進行處理,選擇 gloo,js,還是 mpi。

def run_controller(use_gloo, gloo_run, use_mpi, mpi_run, use_jsrun, js_run, verbosity):
    if use_gloo:
        gloo_run()
    elif use_mpi:
        mpi_run()
    elif use_jsrun:
        js_run()
    else:
        if mpi_built(verbose=verbose):
            if lsf.LSFUtils.using_lsf() and is_jsrun_installed():
                js_run()
            else:
                mpi_run()
        elif gloo_built(verbose=verbose):
            gloo_run()

所以我們開始啓動 job,下面就 GLOO進行分析。

0x03 Gloo 實現

相比 MPI,Gloo 這部分就比較清晰了。

3.1 gloo_run

回到 2.3 run_controller

就是依據配置和編譯情況來進行處理,選擇 gloo,js,還是 mpi。

def run_controller(use_gloo, gloo_run, use_mpi, mpi_run, use_jsrun, js_run, verbosity):
    if use_gloo:
        gloo_run() # 本文調用到這裏
    elif use_mpi:
        mpi_run() # mpi會調用到這裏
    elif use_jsrun:
        js_run()
    else:
        if mpi_built(verbose=verbose):
            if lsf.LSFUtils.using_lsf() and is_jsrun_installed():
                js_run()
            else:
                mpi_run() # mpi會調用到這裏
        elif gloo_built(verbose=verbose):
            gloo_run() # 本文調用到這裏

如果是配置了gloo,則我們用到了 gloo_run:

def gloo_run(settings, nics, driver, env, stdout=None, stderr=None):
    """
    Run distributed gloo jobs.

    :param settings: Settings for running the distributed jobs.
                     Note: settings.num_proc and settings.hosts must not be None.
    :param nics: Interfaces to use by gloo.
    :param driver: The Spark driver service that tasks are connected to.
    :param env: Environment dictionary to use for running gloo jobs.  Can be None.
    :param stdout: Horovod stdout is redirected to this stream.
    :param stderr: Horovod stderr is redirected to this stream.
    """
    if env is None:
        env = {}

    # we don't want the key to be serialized along with settings from here on
    key = settings.key
    settings.key = None

    # Each thread will use SparkTaskClient to launch the job on each remote host. If an
    # error occurs in one thread, entire process will be terminated. Otherwise,
    # threads will keep running and ssh session.
    iface = list(nics)[0]
    server_ip = driver.addresses()[iface][0][0]
    # 這裏構建了需要執行的命令
    command = (sys.executable,
               '-m', 'horovod.spark.task.gloo_exec_fn', # 這個就是在task裏面運行的代碼
               codec.dumps_base64(driver.addresses()),
               codec.dumps_base64(settings))

    # 可以認爲_exec_command_fn這裏是一種執行命令的能力
    exec_command = _exec_command_fn(driver, key, settings, env,
                                    stdout, stderr, settings.prefix_output_with_timestamp)
    # 這裏傳入了路由信息
    launch_gloo(command, exec_command, settings, nics, {}, server_ip)

需要注意的是,這裏的 _exec_command_fn 如下,可以認爲_exec_command_fn這裏是一種執行命令的能力:

def _exec_command_fn(driver, key, settings, env, stdout, stderr, prefix_output_with_timestamp):
    def _exec_command(command, slot_info, events):
        host = slot_info.hostname #host名字
        local_rank = slot_info.local_rank # 本地rank
        verbose = settings.verbose
        # 用rsh封裝的運行能力
        result = rsh(driver.addresses(), key, host, command, env, local_rank, verbose,
                     stdout, stderr, prefix_output_with_timestamp, False, events)
        return result, time.time()
    return _exec_command

即調用了 from horovod.spark.driver.rsh import rsh。這裏是關鍵

3.2 launch_gloo

這裏主要是:

  • 首先,要注意,參數中,
    • command 大致爲:'python','-m','horovod.spark.task.gloo_exec_fn';
    • exec_command 大致爲:rsh xxxx。因爲exec_command可以認爲是一種利用rsh執行command的能力,所以這裏的xxx對應本文就是 “python -m horovod.spark.task.gloo_exec_fn”;
  • 建立了 RendezvousServer;
  • 構建了 slot_info_to_command,這裏指定了在哪一個slot上面運行;
  • 調用 execute_function_multithreaded 來使用多線程來運行命令;
def launch_gloo(command, exec_command, settings, nics, env, server_ip):
    """
    Launches the given command multiple times using gloo.
    Each command is launched via exec_command.

    :param command: command to launch
    :param exec_command: means to execute a single command
    :param settings: settings for the distribution
    :param nics: common interfaces
    :param env: environment to use
    :param server_ip: ip to use for rendezvous server
    """
		......
    
    # start global rendezvous server and get port that it is listening on
    # 建立 RendezvousServer,這個會被底層 Gloo C++ 環境使用到
    rendezvous = RendezvousServer(settings.verbose)

    # allocate processes into slots
    # 來根據host進行分配slot,就是horovod的哪個rank應該在哪個host上的哪個slot之上運行
    hosts = parse_hosts(settings.hosts)
    host_alloc_plan = get_host_assignments(hosts, settings.num_proc)

    # start global rendezvous server and get port that it is listening on
    global_rendezv_port = rendezvous.start()
    rendezvous.init(host_alloc_plan)
    # 獲取到可執行命令
    run_command = get_run_command(command, server_ip, nics, global_rendezv_port)

    # 得到在slot之上可執行的 slot command
    slot_info_to_command = _slot_info_to_command_fn(run_command, env)
    event = register_shutdown_event()
    # 依據 slot_info_to_command_fn 構建 args_list,這個 list 之中,每一個arg就是一個 slot command
    args_list = [[slot_info_to_command(slot_info), slot_info, [event]]
                 for slot_info in host_alloc_plan]

    # If an error occurs in one thread, entire process will be terminated.
    # Otherwise, threads will keep running.
    # 多線程執行,在每一個 exec_command 之上執行每一個 arg(slot command),args_list 包括 HOROVOD_GLOO_RENDEZVOUS_ADDR 等信息
    res = threads.execute_function_multithreaded(exec_command,
                                                 args_list,
                                                 block_until_all_done=True)

    ......

具體如下圖所示:

               launch_gloo( command ='python','+m','horovod.spark.task.gloo_exec_fn'
                    +       exec_command = rsh xxxx)
                    |
                    |
                    |
                    |
                    |
                    v
               RendezvousServer
                    +
                    |
                    |   get_run_command
                    |
                    |
                    v
 run_command = HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222
               HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo ......
               python +m horovod.spark.task.gloo_exec_fn
 exec_command = rsh xxxx

                    +
                    |
                    |   _slot_info_to_command_fn
                    |
                    v

slot_info_to_command = rank=0,local_rank=0,socket+ifname=eth0,cpu_operations=gloo......
                       HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222
                       HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo ......
                       python -m horovod.spark.task.gloo_exec_fn
        exec_command = rsh xxxx
                    +
                    |
                    |
                    |
                    v
               threads.execute_function_multithreaded
                    +
                    |
                    |
                    v

手機如下:

3.2.1 get_run_command

launch_gloo 代碼之中所用到的get_run_command十分關鍵,它會調用 create_run_env_vars 得到gloo需要信息,並據此構建 run_command,其格式如下:

HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222 HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo HOROVOD_CONTROLLER=gloo python train.py

代碼如下:

def get_run_command(command, server_ip, nics, port, elastic=False):
    env_vars = create_run_env_vars(server_ip, nics, port, elastic)
    env_string = " ".join(
        [f"{k}={str(v)}" for k, v in env_vars.items()])
    run_command = (
        '{env_string} '
        '{command}'  # expect a lot of environment variables
        .format(env_string=env_string,
                command=' '.join(quote(par) for par in command)))
    return run_command

3.2.2 create_run_env_vars

create_run_env_vars 函數會把 gloo 運行的相關信息構建出來,這些信息最後會傳給 Spark Executor。

def create_run_env_vars(server_ip, nics, port, elastic=False):
    run_envs = {
        'HOROVOD_GLOO_RENDEZVOUS_ADDR': server_ip,
        'HOROVOD_GLOO_RENDEZVOUS_PORT': port,
        'HOROVOD_CONTROLLER': "gloo",
        'HOROVOD_CPU_OPERATIONS': "gloo",
        'HOROVOD_GLOO_IFACE': list(nics)[0],   # TODO: add multiple ifaces in future
        'NCCL_SOCKET_IFNAME': ','.join(nics), # 這裏就是構建環需要的網絡路由信息
    }
    if elastic:
        run_envs["HOROVOD_ELASTIC"] = "1"
    return run_envs

3.3 rsh

在 execute_function_multithreaded 之中,調用了 rsh,並最終與 Spark Executor 交互。

具體會:

  • 獲取到 driver handle;
  • 利用driver handle調用 SparkDriverClient 獲取 task 相關信息;
  • 獲取 task handle;
  • 調用 SparkTaskClient 的 run_command 方法 來進行發送命令給 Spark Executor,這裏的參數 command 內容大致爲 “'python -m horovod.spark.task.gloo_exec_fn”;
  • 等待運行結果;

在調用 rsh 時候,command 會包括 類似 HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222 HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo 等信息,這樣 SparkDriverService 就知道如何構建 Ring 路由了

def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    """
    Method to run a command remotely given a host hash, local rank and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with given local rank
    of that host hash and invoke the command there.

    The method returns immediately after launching the command if background is True (default).
    When background is set to False, this method waits for command termination and returns
    command's result. If there is an exception while waiting for the result (i.e. connection reset)
    it returns -1.

    :param driver_addresses: driver's addresses
    :param key: used for encryption of parameters passed across the hosts
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    :param env: environment to use
    :param local_rank: local rank on the host of task to run the command in
    :param verbose: verbosity level
    :param stdout: Task stdout is redirected to this stream.
    :param stderr: Task stderr is redirected to this stream.
    :param prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver if True
    :param background: run command in background if True, returns command result otherwise
    :param events: events to abort the command, only if background is True
    :return exit code if background is False
    """
    if ':' in host_hash:
        raise Exception('Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    # 獲取到 driver handle    
    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    # 利用配置確定是哪一個task來運行
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    # 獲取task handle
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    # 要求task運行命令,command就是 python -m horovod.spark.task.gloo_exec_fn
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

    if not background:
        events = events or []
        stop = threading.Event()
        for event in events:
            on_event(event, task_client.abort_command, stop=stop)
        try:
            exit_code = task_client.wait_for_command_exit_code()
            return exit_code
        except:
            traceback.print_exc()
            return -1
        finally:
            stop.set()

所以,此時邏輯如下,最終在spark executor 運行python -m horovod.spark.task.gloo_exec_fn

                                                                                                          Horovod Job    +    Spark Host
                                                                                                                         |
SparkDriverService                           horovod.spark.run                                                           |                    SparkTaskService
         +                                        +                                                                      |                           +
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                                                                                               |                           |
         |                                   launch_gloo( command ='python','+m','horovod.spark.task.gloo_exec_fn'       |                           |
         |                                        +       exec_command = rsh xxxx)                                       |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                   RendezvousServer                                                            |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |   get_run_command                                                    |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                    run_command = HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222       |                           |
         |                                  HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo ......                     |                           |
         |                                  python +m horovod.spark.task.gloo_exec_fn                                    |                           |
         |                    exec_command = rsh xxxx                                                                    |                           |
         |                                                                                                               |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |   _slot_info_to_command_fn                                           |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                                                                                               |                           |
         |                    slot_info_to_command = rank=0,local_rank=0,socket+ifname=eth0,cpu_operations=gloo......    |                           |
         |                                        HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222 |                           |
         |                                        HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo ......               |                           |
         |                                        python +m horovod.spark.task.gloo_exec_fn                              |                           |
         |                            exec_command = rsh xxxx                                                            |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                   threads.execute_function_multithreaded                                      |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                       rsh                                                                     |                           |
         |                                        +                                                                      |                           |
         |  <----------------------------------+  |                                                                      |                           |
         |      task_host_hash_indices            |                                                                      |                           |
         |                                        |                                                                      |                           |
         |  <----------------------------------+  |                     run_command(command, env)                        |    RunCommandRequest      |
         |      all_task_addresses                |                                                                      |                           |
         |                                        | +--------------------------------------------------------------------------------------------->  |
         |                                        |                                                                      |                           +
         |                                        |                                                                      |                      run command
         |                                        |                                                                      |                           +
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         v                                        v                                                                      |                           |
                                                                                                                         +                           v

手機如下:

3.4 gloo_exec_fn

注意,此時已經在 Spark Host 上的 Executor 中運行了。

gloo_exec_fn 就對應了前面 mpi版本的 mpirun_exec_fn

spark 在 Executor 上運行 horovod.spark.task.gloo_exec_fn

horovod.spark.task.gloo_exec_fn 內容如下:

from horovod.spark.task import task_exec
from horovod.runner.common.util import codec

def main(driver_addresses, settings):
    task_exec(driver_addresses, settings, 'HOROVOD_RANK', 'HOROVOD_LOCAL_RANK')

if __name__ == '__main__':
    if len(sys.argv) != 3:
        print('Usage: %s <driver addresses> <settings>' % sys.argv[0])
        sys.exit(1)
    main(codec.loads_base64(sys.argv[1]), codec.loads_base64(sys.argv[2]))

0x04 第五階段 : 運行用戶代碼

task_exec 函數就是運行用戶代碼進行訓練。

task_exec 位於:horovod/spark/task/__init__.py

具體會:

  • 調用 SparkDriverClient 獲取 task 相關信息;
  • 調用 SparkTaskClient 來進行獲取 用戶代碼;
  • 執行用戶代碼等等。
def task_exec(driver_addresses, settings, rank_env, local_rank_env):
    # Die if parent process terminates
    in_thread(target=_parent_process_monitor, args=(os.getppid(),))

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ[rank_env])
    local_rank = int(os.environ[local_rank_env])
    driver_client = driver_service.SparkDriverClient(driver_addresses, key,
                                                     verbose=settings.verbose)

    # tell driver about local rank and rank
    # in elastic mode the driver already knows this mapping
    # for simplicity we keep code paths the same for elastic and static mode
    host_hash = os.environ['HOROVOD_HOSTNAME']
    task_index = driver_client.set_local_rank_to_rank(host_hash, local_rank, rank)

    # gather available resources from task service
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key,
                                               verbose=settings.verbose)
    task_info.set_resources(task_client.resources())

    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)

最終代碼如下:

                                                                                                          Horovod Job    +    Spark Host
                                                                                                                         |
SparkDriverService                           horovod.spark.run                                                           |                    SparkTaskService
         +                                        +                                                                      |                           +
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                                                                                               |                           |
         |                                   launch_gloo( command ='python','+m','horovod.spark.task.gloo_exec_fn'       |                           |
         |                                        +       exec_command = rsh xxxx)                                       |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                   RendezvousServer                                                            |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |   get_run_command                                                    |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                     run_command = rendevous_addr, rendevous_port python -m horovod.spark.task.gloo_exec_fn    |                           |
         |                    exec_command = rsh xxxx                                                                    |                           |
         |                                                                                                               |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |   _slot_info_to_command_fn                                           |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                                                                                               |                           |
         |                    slot_info_to_command = rank=0,local_rank=0,socket+ifname=eth0,cpu_operations=gloo......    |                           |
         |                                     rendevous_addr, rendevous_port python -m horovod.spark.task.gloo_exec_fn  |                           |
         |                            exec_command = rsh xxxx                                                            |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                   threads.execute_function_multithreaded                                      |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                       rsh                                                                     |                           |
         |                                        |                                                                      |                           |
         |  <----------------------------------+  |                                                                      |                           |
         |      task_host_hash_indices            |                                                                      |                           |
         |                                        |                                                                      |                           |
         |  <----------------------------------+  |                     run_command(command, env)                        |    RunCommandRequest      |
         |      all_task_addresses                |                                                                      |                           |
         |                                        | +--------------------------------------------------------------------------------------------->  |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                      run command
         |                                        |                                                                      |                           +
         |                                        |                                                                      |      code()               |
         |  <-------------------------------------------------------------------------------------------------------------------------------------+  |
         |                                        |                                                                      |                           |
         |  +------------------------------------------------------------------------------------------------------------------------------------->  |
         |                                        |                                                                      |  code  of gloo_exec_fn    |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                     gloo_exec_fn
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                      task_exec
         v                                        |                                                                      |                           |
                                                  v                                                                      |                           |
                                                                                                                         +                           v

手機如下:

0x05 總結

在普通模式下,Gloo方案會:

  • 創建一個帶有 KVStore 的 RendezvousServer,driver 會將參與通信的 worker 的 ip 等信息存入 KVstore 中。
  • 然後 worker 就可以調用 gloo 來訪問 RendezvousServer 構造通信環了。

Horovod on Spark via GLOO 之中,關鍵點就是:

  • 如何構造RendezvousServer,RendezvousServer如何知道Executor的 ip:port?
    • 答案爲:
      • 在 Horovod 的 driver 之中,會創建RendezvousServer。
      • 在之前的初始化過程中,每個 SparkTaskService 會通過 driver_service.SparkDriverClient.register_task 來向 horovod 中的 Driver 註冊這就是關鍵之處,通過這裏 RendezvousServer 就可以知道 SparkTaskService 的 IP :PORT
  • Executor上的 SparkTaskService 如何與 RendezvousServer 溝通,從而知道自己和鄰居的網絡信息?
    • 答案爲:
      • 在 execute_function_multithreaded 之中,調用了 rsh,並最終與 Spark Executor 交互。
      • 在調用 rsh 時候,會把類似 HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222 HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo 信息傳遞過去,此信息中包括了 RendezvousServer 的地址,這樣 Spark Executor 中的 SparkTaskService 就知道了如何找到RendezvousServer,進而就會知道如何構建 ring。

至此,Horovod on spark解析完畢,從下一篇開始解析彈性訓練。

0xEE 個人信息

★★★★★★關於生活和技術的思考★★★★★★

微信公衆賬號:羅西的思考

如果您想及時得到個人撰寫文章的消息推送,或者想看看個人推薦的技術資料,敬請關注。

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章