[源碼解析] 深度學習分佈式訓練框架 horovod (10) --- run on spark

[源碼解析] 深度學習分佈式訓練框架 horovod (10) --- run on spark

0x00 摘要

Horovod 是Uber於2017年發佈的一個易於使用的高性能的分佈式訓練框架,在業界得到了廣泛應用。

本系列將通過源碼分析來帶領大家瞭解 Horovod。本文是系列第十篇,看看horovod 如何運行在 spark 之上。

Horovod on Spark 具體有兩種底層實現:MPI,GLOO。因爲篇幅所限,本文介紹 MPI 實現,下一篇介紹GLOO實現。

本系列其他文章如下:

[源碼解析] 深度學習分佈式訓練框架 Horovod (1) --- 基礎知識

[源碼解析] 深度學習分佈式訓練框架 horovod (2) --- 從使用者角度切入

[源碼解析] 深度學習分佈式訓練框架 horovod (3) --- Horovodrun背後做了什麼

[源碼解析] 深度學習分佈式訓練框架 horovod (4) --- 網絡基礎 & Driver

[源碼解析] 深度學習分佈式訓練框架 horovod (5) --- 融合框架

[源碼解析] 深度學習分佈式訓練框架 horovod (6) --- 後臺線程架構

[源碼解析] 深度學習分佈式訓練框架 horovod (7) --- DistributedOptimizer

[源碼解析] 深度學習分佈式訓練框架 horovod (8) --- on spark

[源碼解析] 深度學習分佈式訓練框架 horovod (9) --- 啓動 on spark

0x01 回顧

1.1 總體序列圖

接上文,我們首先要回顧下 Horovod on Spark 的總體序列圖,這樣腦子裏有一個全景,溫故而知新。

img

1.2 總體邏輯

總體來說,Horovod on Spark 的總體邏輯分爲以下階段:

  • 啓動 SparkDriverService 服務,利用 _make_spark_thread 啓動 Spark task,然後 horovod 會等待啓動結束;
  • 多線程在 spark executor 之中啓動 spark task,每個task之中運行一個 SparkTaskService,SparkTaskService 會向 hovorod 主進程中的 SparkDriverTask 進行註冊,並且等待下一步運行啓動的指令;
  • Horovod 收到所有 task 結束的信息之後,通知各個 task,進入下一階段;
  • Horovod 調用 mpi_run (又利用到 mpirun_rsh.py)在每一個 spark executor 上啓動 orted(這裏是通過 SparkTaskService 來啓動 orted),以啓動 MPI cluster;
  • orted 在每一個 executor 之上運行訓練代碼;

前文已經分析了前面三個階段,本文繼續後面兩個階段的分析。

1.3 問題

結合上面的流程圖,這裏就有一個問題會令人疑惑。

Horovod 按說可以直接調用 mpirun 來在遠端啓動 orted(orted 就是 mpi 可執行程序。mpirun 是 orterun 的別名,而 ortedrun 會最終調用到 orted)。但是爲什麼流程圖上不是直接調用,而是通過 mpirun_rsh.py,進而通過 SparkTaskService 來啓動 orted?

原因應該是:

  • 通常 MPI 會通過 SSH 來連接 hosts,但是這種方式無法在 Spark Executor 之中啓動 Python function。
  • Orted 需要運行在 Spark Executor 之中,但是 mpirun 在啓動時候,沒辦法知道 Spark Executor 的 IP : PORT 這個組合,所以沒法直接啓動。
  • 因此 MPI 使用RPC 來啓動用戶代碼:
    • 通過 SparkDriverService 和 SparkTaskService 等交互纔可以知道這個 IP : PORT 組合信息。
    • 使用 horovod.spark.driver.mpirun_rsh 來連接每個 Executor,然後 "remote shell" 到這些 executors 之中。
    • 直接使用 SparkTaskService 來啓動 orted。

0x02 第四階段 : 啓動 Job

下面我們看看第四階段,就是如何運行 訓練 job。

2.1 _launch_job

_launch_job 很簡單:

  • 首先 driver.get_common_interfaces 獲取網絡路由信息;
  • 其次 調用 run_contoller 來啓動 job;
def _launch_job(use_mpi, use_gloo, settings, driver, env, stdout=None, stderr=None):
    nics = driver.get_common_interfaces()
    run_controller(use_gloo, lambda: gloo_run(settings, nics, driver, env, stdout, stderr),
                   use_mpi, lambda: mpi_run(settings, nics, driver, env, stdout, stderr),
                   False, lambda: None,
                   settings.verbose)

2.2 獲取路由信息

get_common_interfaces 與普通模式下的 get_common_interfaces 不同。因爲此時,Spark Executor 之中的 SparkTaskService 的信息已經保存在 Driver 之中,直接獲取即可。

def get_common_interfaces(self):
    if self._nics is not None:
        return self._nics

    nics = None
    if len(self._task_addresses_for_tasks) > 0:
        # in Elastic Horovod on Spark with auto-scaling
        # keys in task_addresses are in range(max_np or proc_num)
        # but not all keys may exist, so we don't do for index in range(proc_num)
        indices = list(self._task_addresses_for_tasks.keys())
        nics = set(self._task_addresses_for_tasks[indices[0]].keys()) # 直接獲取
        for index in indices[1:]:
            nics.intersection_update(self._task_addresses_for_tasks[index].keys())

    return nics

2.3 run_controller

就是依據配置和編譯情況來進行處理,選擇 gloo,js,還是 mpi。

def run_controller(use_gloo, gloo_run, use_mpi, mpi_run, use_jsrun, js_run, verbosity):
    if use_gloo:
        gloo_run()
    elif use_mpi:
        mpi_run()
    elif use_jsrun:
        js_run()
    else:
        if mpi_built(verbose=verbose):
            if lsf.LSFUtils.using_lsf() and is_jsrun_installed():
                js_run()
            else:
                mpi_run()
        elif gloo_built(verbose=verbose):
            gloo_run()

所以我們開始啓動 job,具體我們分爲 MPI,GLOO兩種進行分析。

0x03 MPI 實驗

我們首先要做一些 MPI 相關實驗,其原因是因爲:

  • MPI 的調用之中有些看起來很奇怪的行爲,或者說是一些 trick。
  • 這些 trick 對於 "horovod on spark" 基於 MPI 的實現是很有幫助,但是對於理解代碼卻是一個極大的干擾
  • 我們暫時沒有時間和精力去研究 MPI 的源碼是如何實現的,因爲已經超出了本文範疇。

所以我們只能針對某些奇怪的行爲,對 MPI 的相關實現機制做一些假設和估計。然後通過一個簡單的實驗來驗證我們的假設。

3.1 問題點

我們執行的 mpi 命令格式如下,這個命令格式就是爲了模擬Horovod的 MPI 命令:

mpirun --allow-run-as-root -n 4 --hostfile ./remote_hosts -mca plm_rsh_agent "python rsh.py" python user_function.py

問題點就是:

  • plm_rsh_agent "python rsh.py" 的作用是什麼?
  • rsh.py 之中,有哪些 trick?如何調用遠程 mpi 程序?
  • python user_function.py 是在 rsh.py 之後運行嗎?

3.2 名詞解釋

3.2.1 orterun & orted

最開始在看到這個命令時候,容易讓人很暈。因爲代碼中沒有任何提及。

其實,outed 就是 mpi 可執行程序。

mpirun 是 orterun 的別名,而 ortedrun 會最終調用到 orted

具體解釋如下,信息來源爲 http://cn.voidcc.com/question/p-wkloammx-bha.html:

mpirunmpiexec基本上是相同的 - 許多MPI實現中的進程啓動器的名稱。 MPI標準沒有提到如何啓動和控制等級,但它建議(儘管不要求),如果有任何類型的啓動器,它應該被命名爲mpiexec。一些MPI實現以mpirun開始,然後採用mpiexec以實現兼容性。其他實現則相反。最後,大多數實現都使用兩個名稱來提供它們的啓動器。在實踐中,mpirunmpiexec所做的事情應該沒有什麼不同。

不同的MPI實現有不同的啓動和控制過程的方法。 MPICH從一個名爲MPD(多用途守護進程或其他)的基礎架構開始。然後切換到新的Hydra流程管理器。由於Hydra的功能與MPD不同,因此基於Hydra的mpiexec採用的命令行參數不同於基於MPD的命令行參數,並且使用戶可以明確選擇基於Hydra的命令行參數,因此它可用作mpiexec.hydra。舊的稱爲mpiexec.mpd。可能有一個基於MPICH的MPI庫只提供Hydra啓動程序,然後mpiexecmpiexec.hydra將是相同的可執行文件。英特爾MPI基於MPICH,其新版本使用Hydra進程管理器。

Open MPI建立在開放運行環境(ORTE)的基礎上,其自身的進程啓動器被稱爲orterun。爲了兼容,orterun也符號鏈接爲mpirunmpiexec

總結:

  • mpiexec.something是MPI進程啓動的給定實現的特定版本
  • mpiexecmpirun是通用名稱的符號鏈接到實際發射通常副本或
  • mpiexecmpirun應該這樣做
  • 某些實現命名他們的發射器mpiexec,有些人命名它mpirun,有人將其命名爲兩者,當系統路徑中同時有多個MPI實現可用時,這通常是混淆的來源(例如,當從發行版安裝時)

3.2.2 mpi orterun 源碼

mpi之中 orterun 對應的源碼如下,最主要是調用了 orte_submit_job 提交 job。

int orterun(int argc, char *argv[])
{
    orte_submit_status_t launchst, completest;

    /* orte_submit_init() will also check if the user is running as
       root (and may issue a warning/exit). */
    if (ORTE_SUCCESS != orte_submit_init(argc, argv, NULL)) {
        exit(1);
    }

    /* setup to listen for commands sent specifically to me, even though I would probably
     * be the one sending them! Unfortunately, since I am a participating daemon,
     * there are times I need to send a command to "all daemons", and that means *I* have
     * to receive it too
     */
    orte_rml.recv_buffer_nb(ORTE_NAME_WILDCARD, ORTE_RML_TAG_DAEMON,
                            ORTE_RML_PERSISTENT, orte_daemon_recv, NULL);

    /* if the user just wants us to terminate a DVM, then do so */
    if (orte_cmd_options.terminate_dvm) {
        // 省略部分代碼
    } else {
        /* spawn the job and its daemons */
        memset(&launchst, 0, sizeof(launchst));
        memset(&completest, 0, sizeof(completest));
        launchst.active = true;
        completest.active = true;
      
        // 在這裏進行提交 job
        if (ORTE_SUCCESS != orte_submit_job(argv, NULL,
                                            launched, &launchst,
                                            completed, &completest)) {
            ORTE_UPDATE_EXIT_STATUS(1);
            goto DONE;
        }
    }

    // wait for response and unpack the status, jobid
    // 省略部分代碼
}

3.3 實驗設計

3.3.1 組件

有如下幾個組件,其作用分別如下:

  • host 文件。作用是指定本次運行有哪些host,以及host之上運行幾個MPI進程。
  • rsh.py。作用是作爲 rsh agent 來給遠端機器下達命令。
    • MPI 用戶也可以通過其他方式給遠程機器下發命令。
    • 用戶可以對每個主機使用遠程 shell(sshrsh)而無需登錄主機。默認情況下,mpirun 使用 ssh
    • 如果 mpirun 使用 ssh 出現問題,可以嘗試在 mpirun 命令中使用 --mca plm_rsh_agent rsh 選項,以使用 rsh 命令進行連接。
  • user_function.py。就是用戶希望執行的函數。

3.3.2 host 文件 remote_hosts

remote_hosts 文件內容如下:

1.1.1.1:2
2.2.2.2:2

其意義是:

  • 1.1.1.1 這個 ip 運行 2 個 slot,即兩個 MPI 進程。
  • 2.2.2.2 這個 ip 運行 2 個 slot,即兩個 MPI 進程。

3.3.3 rsh.py

rsh.py 內容如下,作用就是打印 MPI 傳入的 command,然後在遠端host之上啓動的 MPI 進程中運行新命令:

import os
import sys
import subprocess

if __name__ == '__main__':
  command = " ".join(sys.argv[0:])
  print(command)
  new_command = " ".join(sys.argv[2:])
  print(new_command)
  subprocess.Popen(new_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

3.3.4 user_function.py

內容如下,就是爲了測試,打印一條語句。

print('hello world')

3.4 實驗結果

我們在 1.1.1.1 之上運行 mpi 命令。

mpirun --allow-run-as-root -n 4 --hostfile ./remote_hosts -mca plm_rsh_agent "python rsh.py" python user_function.py

結果如下:

# 以下是 command 內容,就是 MPI 傳遞給 rsh.py 的內容,這裏居然有 plm_rsh_agent "python rsh.py" 
rsh.py 1.1.1.1 orted -mca ess "env" -mca ess_base_jobid "41114481152" -mca ess_base_vpid 1 -mca ess_base_num_procs "4" -mca ored_node_regex "ip-[2:1]-1-1-1,[2:1]2.2.2@0(2)" -mca orted_hnp_uri "41114481152.0,tcp://1.1.1.1:53405" -mca plm "rsh" --tree-spawn -mca orte_parent_uri "41114481152.0,tcp://1.1.1.1:53405"  -mca plm_rsh_agent "python rsh.py" -mca pmix "^s1,s2,cray,isolated"

# 以下是 new_command 內容,就是 在遠端host上執行 用戶代碼 的方法,這裏居然有 plm_rsh_agent "python rsh.py" 
orted -mca ess "env" -mca ess_base_jobid "41114481152" -mca ess_base_vpid 1 -mca ess_base_num_procs "4" -mca ored_node_regex "ip-[2:1]-1-1-1,[2:1]2.2.2@0(2)" -mca orted_hnp_uri "41114481152.0,tcp://1.1.1.1:53405" -mca plm "rsh" --tree-spawn -mca orte_parent_uri "41114481152.0,tcp://1.1.1.1:53405"  -mca plm_rsh_agent "python rsh.py" -mca pmix "^s1,s2,cray,isolated"

# 以下是 user_function.py 的執行內容
hello world

因此我們知道

  • plm_rsh_agent "python rsh.py" 的作用是在遠端運行 MPI orted。
  • python user_function.py 是在 rsh 之後運行的,而且是在遠端的 orted 之中運行。
  • 在 rsh.py 執行過程中,其接受到的命令內容有些奇怪

3.5 運行過程

運行過程如下:

  1. mpirun 運行 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py,此時在遠端會運行一個 MPI daemon,用來響應處理;
  2. mpirun 調用了 rsh.py;
  3. rsh.py 使用 subprocess(orted -mca plm_rsh_agent "python rsh.py") 在遠端啓動 orted(會與 daemon 溝通),運行用戶代碼;

具體如下圖:

                                                         1.1.1.1        +          2.2.2.2
                                                                        |
                                                                        |
                                                                        |  1      +---------------+
mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py  +----------->  |  MPI deamon   |
                                                                        |         +-------+-------+
             +                                                          |                 |
             |                                                          |                 |
             | 2                                                        |                 |
             |                                                          |                 |  3
             |                                                          |                 |
             |  rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py" |                 |
             |                                                          |                 v
             |                                                          |         +-------+--------------------------+
             |                                                          |         | orted                            |
             |                                                          |         |                                  |
             v                                                          |         |                                  |
+------------+------------------------------------------------------+   |         |   +---------------------------+  |
| rsh.py                                                            |   |         |   | user_function.py          |  |
|                                                                   |   |         |   |                           |  |
|    rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py"        |   |         |   |                           |  |
|                                                                   |   |   3     |   |      print('hello world') |  |
|    subprocess(orted -mca plm_rsh_agent "python rsh.py") +-------------------->  |   |                           |  |
|                                                                   |   |         |   +---------------------------+  |
+-------------------------------------------------------------------+   +         +----------------------------------+

手機如下:

3.6 Trick 分析

我們發現有幾個奇怪的點:

  • mpirun 運行 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py
  • mpirun 調用了 rsh.py,但是在 rsh.py 收到的 argv 中,居然也有 plm_rsh_agent "python rsh.py" 。按說這時候不應該有這個參數了,因爲 rsh.py 已經調用了,就不應該再有這個參數
  • rsh.py 運行遠端 MPI,使用的是 orted -mca plm_rsh_agent "python rsh.py",這裏居然還有 plm_rsh_agent "python rsh.py" 這個參數。這時候也不應該,因爲 orted 已經運行在遠端了,這時候也傳入一個用來遠端控制的 rsh agent 參數,太奇怪了

就是說plm_rsh_agent "python rsh.py" 這個參數居然被 MPI 傳遞到各個階段,無論是 rsh agent 或者 遠端 mpi

rsh agent 就是 trick。不知道 MPI 爲什麼要把 plm_rsh_agent "python rsh.py" 在各個階段傳遞的意圖,可能是爲了更好的控制。

因爲沒有精力來分析 MPI 源碼,所以初步判斷,遠端 MPI daemon 在運行 orted -mca plm_rsh_agent "python rsh.py"時候,會判斷是否已經是遠端,如果是遠端,就不再運行 rsh agent 了。

所以,我們在後面分析中,在 Spark task 之中 發現 類似 plm_rsh_agent "python rsh.py" ,就不用再疑惑了

0x04 MPI 實現

一般來說,Horovod on Spark 是以 MPI 模式來運行,所以我們重點看這裏。

4.1 mpi_run in spark

mpi_run 代碼位於:horovod/spark/mpi_run.py,作用是:

  • 依據各種配置生成remote shell的agent;
  • 依據各種配置生成可執行命令;
  • 調用hr_mpi_run(horovod.runner.mpi_run 就是普通模式下的 mpi_run)運行命令;

比如得到 rsh_agent 大致如下:

("/usr/bin/python", "-m", "horovod.spark.driver.mpirun_rsh", "xxxxx", "yyy")

得到 command 大致如下:

("/usr/bin/python", "-m", "horovod.spark.task.mpirun_exec_fn", "xxxxx", "yyy")

具體代碼如下:

from horovod.runner.mpi_run import mpi_run as hr_mpi_run

def mpi_run(settings, nics, driver, env, stdout=None, stderr=None):
    """
    Runs mpirun.

    :param settings: Settings for running MPI.
                     Note: settings.num_proc and settings.hosts must not be None.
    :param nics: Interfaces to include by MPI.
    :param driver: The Spark driver service that tasks are connected to.
    :param env: Environment dictionary to use for running MPI.  Can be None.
    :param stdout: Stdout of the mpi process.
                   Only used when settings.run_func_mode is True.
    :param stderr: Stderr of the mpi process.
                   Only used when settings.run_func_mode is True.
    """
    env = {} if env is None else copy.copy(env)  # copy env so we do not leak env modifications

    # Pass secret key through the environment variables.
    env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)
    # we don't want the key to be serialized along with settings from here on
    settings.key = None

    # 拼接出rsh_agent
    rsh_agent = (sys.executable,
                 '-m', 'horovod.spark.driver.mpirun_rsh',
                 codec.dumps_base64(driver.addresses()),
                 codec.dumps_base64(settings))
    settings.extra_mpi_args = ('{extra_mpi_args} -x NCCL_DEBUG=INFO -mca plm_rsh_agent "{rsh_agent}"'
                               .format(extra_mpi_args=settings.extra_mpi_args if settings.extra_mpi_args else '',
                                       rsh_agent=' '.join(rsh_agent)))
    # 拼接出command
    command = (sys.executable,
               '-m', 'horovod.spark.task.mpirun_exec_fn',
               codec.dumps_base64(driver.addresses()),
               codec.dumps_base64(settings))
    hr_mpi_run(settings, nics, env, command, stdout=stdout, stderr=stderr)

4.2 mpi_run in normal

上面代碼最後是運行 hr_mpi_run,其實 hr_mpi_run 是 horovod.runner.mpi_run,就是普通模式下的 mpi_run。

horovod.runner.mpi_run 首先 就是依據各種配置以及參數來構建 mpirun 命令的所有參數,比如 ssh 的參數,mpi 參數,nccl 參數等等。

得到了 command 大致如下:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

具體代碼如下:

def mpi_run(settings, nics, env, command, stdout=None, stderr=None):
    """
    Runs mpi_run.

    Args:
        settings: Settings for running MPI.
                  Note: settings.num_proc and settings.hosts must not be None.
        nics: Interfaces to include by MPI.
        env: Environment dictionary to use for running command.
        command: Command and arguments to run as a list of string.
        stdout: Stdout of the mpi process.
                Only used when settings.run_func_mode is True.
        stderr: Stderr of the mpi process.
                Only used when settings.run_func_mode is True.
    """

    # 獲取mpi相關配置
    mpi_impl_flags, impl_binding_args, mpi = _get_mpi_implementation_flags(settings.tcp_flag, env=env)
    impi = _IMPI_IMPL == mpi

    # 獲取ssh配置
    ssh_args = []
    if settings.ssh_port:
        ssh_args += [f'-p {settings.ssh_port}']
    if settings.ssh_identity_file:
        ssh_args += [f'-i {settings.ssh_identity_file}']

    mpi_ssh_args = ''
    if ssh_args:
        joined_ssh_args = ' '.join(ssh_args)
        mpi_ssh_args = f'-bootstrap=ssh -bootstrap-exec-args \"{joined_ssh_args}\"' if impi else f'-mca plm_rsh_args \"{joined_ssh_args}\"'

    # 網卡相關信息
    tcp_intf_arg = '-mca btl_tcp_if_include {nics}'.format(
        nics=','.join(nics)) if nics and not impi else ''
    nccl_socket_intf_arg = '-{opt} NCCL_SOCKET_IFNAME={nics}'.format(
        opt='genv' if impi else 'x',
        nics=','.join(nics)) if nics else ''

    # On large cluster runs (e.g. Summit), we need extra settings to work around OpenMPI issues
    host_names, host_to_slots = hosts.parse_hosts_and_slots(settings.hosts)
    if not impi and host_names and len(host_names) >= _LARGE_CLUSTER_THRESHOLD:
        mpi_impl_flags.append('-mca plm_rsh_no_tree_spawn true')
        mpi_impl_flags.append('-mca plm_rsh_num_concurrent {}'.format(len(host_names)))

    # if user does not specify any hosts, mpirun by default uses local host.
    # There is no need to specify localhost.
    hosts_arg = '-{opt} {hosts}'.format(opt='hosts' if impi else 'H',
                hosts=','.join(host_names) if host_names and impi else settings.hosts)

    ppn_arg = ' '
    if host_to_slots and impi:
        ppn = host_to_slots[host_names[0]]
        for h_name in host_names[1:]:
            if ppn != host_to_slots[h_name]:
                raise Exception('''Different slots in -hosts parameter are not supported in Intel(R) MPI.
                                 Use -machinefile <machine_file> for this purpose.''')
        ppn_arg = ' -ppn {} '.format(ppn)

    if settings.prefix_output_with_timestamp and not impi:
        mpi_impl_flags.append('--timestamp-output')

    binding_args = settings.binding_args if settings.binding_args and not impi else ' '.join(impl_binding_args)

    basic_args = '-l' if impi else '--allow-run-as-root --tag-output'

    output = []
    if settings.output_filename:
        output.append('-outfile-pattern' if impi else '--output-filename')
        output.append(settings.output_filename)

    env_list = '' if impi else ' '.join(
                    '-x %s' % key for key in sorted(env.keys()) if env_util.is_exportable(key))

    # Pass all the env variables to the mpirun command.
    mpirun_command = (
        'mpirun {basic_args} '
        '-np {num_proc}{ppn_arg}{hosts_arg} '
        '{binding_args} '
        '{mpi_args} '
        '{mpi_ssh_args} '
        '{tcp_intf_arg} '
        '{nccl_socket_intf_arg} '
        '{output_filename_arg} '
        '{env} {extra_mpi_args} {command}'  # expect a lot of environment variables
        .format(basic_args=basic_args,
                num_proc=settings.num_proc,
                ppn_arg=ppn_arg,
                hosts_arg=hosts_arg,
                binding_args=binding_args,
                mpi_args=' '.join(mpi_impl_flags),
                tcp_intf_arg=tcp_intf_arg,
                nccl_socket_intf_arg=nccl_socket_intf_arg,
                mpi_ssh_args=mpi_ssh_args,
                output_filename_arg=' '.join(output),
                env=env_list,
                extra_mpi_args=settings.extra_mpi_args if settings.extra_mpi_args else '',
                command=' '.join(quote(par) for par in command))
    )

    # we need the driver's PATH and PYTHONPATH in env to run mpirun,
    # env for mpirun is different to env encoded in mpirun_command
    for var in ['PATH', 'PYTHONPATH']:
        if var not in env and var in os.environ:
            # copy env so we do not leak env modifications
            env = copy.copy(env)
            # copy var over from os.environ
            env[var] = os.environ[var]

    # Execute the mpirun command.
    if settings.run_func_mode:
        exit_code = safe_shell_exec.execute(mpirun_command, env=env, stdout=stdout, stderr=stderr)
    else:
        os.execve('/bin/sh', ['/bin/sh', '-c', mpirun_command], env)

4.3 執行命令

目前得到的命令是:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

所以我們接着分析。

當 mpi_run 準備好命令之後,他調用 safe_shell_exec.execute 或者 bin/sh 執行命令。對於 safe_shell_exec.execute 來說,它需要執行的命令是:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

這樣,就是先調用 safe_shell_exec.execute 或者 bin/sh 執行 "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx",然後執行 horovod.spark.task.mpurun_exec_fn xxxxx。

4.3.1 mpi 參數

對於 mpirun 來說,參數 --mca pls_rsh_agent rsh 告訴節點間通訊用rsh。

這樣我們就知道 horovod.spark.driver.mpirun_rsh 就是在節點通訊時候,首先執行的腳本。

就是說,當 mpirun 想在異地節點運行一個 程序(horovod.spark.task.mpurun_exec_fn) 時候,首先運行 horovod.spark.driver.mpirun_rsh 從而在異地節點上啓動一個 orted,其次在這個 異地 orted 之上運行 horovod.spark.task.mpurun_exec_fn

4.3.3 mpirun_rsh.py

所以,horovod.spark.driver.mpirun_rsh 會最先運行,我們需要首先看看,就是下圖中最下面部分

mpirun_rsh.py 的作用如其註釋所言,目的是被 MPI 調用以便連接到一個 host,並且執行指定的命令

命令通常是 orted ,用來創建 MPI cluster。orted 進程然後被用來啓動遠端進程(Horovod 用戶的 Python方法)。 orted 進程將運行在最低index的 task上,同一個host 的其他task將執行 no-op 並且等待 orted task 結束

Method run by MPI to connect to a host hash and execute the given command.

The command is usually `orted` to setup the MPI cluster. That `orted` process
is then used to spin-up the actual remote process, the Horovod user's Python method.
The `orted` process will run on the lowest task index and all other tasks with the
same host hash are expected to no-op (see `horovod.spark._task_fn`)
and wait for the first task to terminate.

但是實際上代碼其實很簡單,就是直接調用了 rsh,所以我們還得接着看。

if len(sys.argv) < 5:
    print('Usage: %s <service addresses> <settings> <host hash> '
          '<command...>' % sys.argv[0])
    sys.exit(1)

addresses = codec.loads_base64(sys.argv[1])
key = codec.loads_base64(os.environ.get(secret.HOROVOD_SECRET_KEY))
settings = codec.loads_base64(sys.argv[2])
host_hash = sys.argv[3]
command = " ".join(sys.argv[4:])
env = {}  # orted does not need any env vars, the target training code gets env from mpirun

# Since tasks with the same host hash have shared memory,
# we will run only one orted process on the first task.
rsh(addresses, key, host_hash, command, env, 0, settings.verbose) # 直接調用

4.3.4 rsh

這裏纔是上述邏輯的具體實現,所以rsh 的作用就是:

  • 與在 Spark Driver 上運行的 SparkDriverService 進行交互,從 SparkDriverService 獲取需要運行 task 的所需信息;
  • 與 Spark Executor 中的 SparkTaskService 交互,運行 command;

具體到代碼就是:

  • 利用 driver_client.task_host_hash_indices(host_hash) 從在 Spark Driver 上運行的 SparkDriverService 獲取某一個 host 上的所有 task;
  • 利用 task_indices[local_rank] 獲取到對應的 task;
  • 利用 driver_client.all_task_addresses(task_index) 獲取 task 的地址;
  • 利用 task_service.SparkTaskClient.run_command 來運行 command;

command 舉例如下,此時 command 已經被 mpirun 處理轉義

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

具體代碼是:

def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    """
    Method to run a command remotely given a host hash, local rank and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with given local rank
    of that host hash and invoke the command there.

    The method returns immediately after launching the command if background is True (default).
    When background is set to False, this method waits for command termination and returns
    command's result. If there is an exception while waiting for the result (i.e. connection reset)
    it returns -1.

    :param driver_addresses: driver's addresses
    :param key: used for encryption of parameters passed across the hosts
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    :param env: environment to use
    :param local_rank: local rank on the host of task to run the command in
    :param verbose: verbosity level
    :param stdout: Task stdout is redirected to this stream.
    :param stderr: Task stderr is redirected to this stream.
    :param prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver if True
    :param background: run command in background if True, returns command result otherwise
    :param events: events to abort the command, only if background is True
    :return exit code if background is False
    """
    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

    if not background:
        events = events or []
        stop = threading.Event()
        for event in events:
            on_event(event, task_client.abort_command, stop=stop)

        try:
            exit_code = task_client.wait_for_command_exit_code()
            return exit_code
        except:
            traceback.print_exc()
            return -1
        finally:
            stop.set()

4.3.5 發送命令

具體 run_command 就是向 SparkTaskService 發送 RunCommandRequest。

class BasicTaskClient(network.BasicClient):

  def run_command(self, command, env,
                    capture_stdout=False, capture_stderr=False,
                    prefix_output_with_timestamp=False):
        self._send(RunCommandRequest(command, env,
                                     capture_stdout, capture_stderr,
                                     prefix_output_with_timestamp))

具體如下圖邏輯所示:

與之前的測試代碼對比如下:

                                                   Our test code     +    Horovod on spark
                                                                     |
                                                                     |
 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py   |    mpirun pls_rsh_agent "python mpirun_rsh" python -m mpurun_exec_fn
                                                                     |
          +                                                          |           +
          |                                                          |           |
          |  rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py" |           |    orted -mca pls_rsh_agent "python -m mpirun_rsh"
          |                                                          |           |
          v                                                                      v
+----------------------------------------------------------------+   |    +------+---------------------------------------------------+
| rsh.py (via SSH)                                               |   |    | mpirun_rsh                                               |
|                                                                |   |    |                                                          |
|    rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py"     |   |    +------+---------------------------------------------------+
|                                                                |   |           |
|                                                                |   |           |
|                                                                |   |           v
|                                                                |   |    +----------------------------------------------------------+
|                                                                |   |    | rsh (via RPC)                                            |
|                                                                |   |    |                                                          |
|    subprocess(orted -mca plm_rsh_agent "python rsh.py")        |   |    |                                                          |
|                                                                |   |    |  task_client = task_service.SparkTaskClient              |
|                                                                |   |    |                                                          |
|                                                                |   |    |  task_client.run_command(                                |
|                                                                |   |    |       orted -mca pls_rsh_agent "python -m mpirun_rsh"    |
|                                                                |   |    |  )                                                       |
+---------+------------------------------------------------------+   |    +------+---------------------------------------------------+
          |                                                          |           |
          |                                                          |           |
          v                                                          |           v
+---------+------------------------------------------------------+   |    +------+---------------------------------------------------+
| user_function.py                                               |   |    | mpirun_exec_fn.py                                        |
|                                                                |   |    |                                                          |
|    print('hello world')                                        |   |    |              task_exec +--------> user_function          |
|                                                                |   |    |                                                          |
+----------------------------------------------------------------+   |    +----------------------------------------------------------+
                                                                     +

手機如下:

因此,下面就會進入到 spark executor 去運行

4.4 Run in Spark Executor

再次注意,這裏已經是 遠端的 Spark Executor 了

上節提到,系統會利用 task_service.SparkTaskClient.run_command 來運行command;

command 舉例如下,此時 command 已經被 mpirun 處理轉義:

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

需要依據上圖留意一點:系統在 Spark Executor 上執行 command 之後,會接着運行 mpirun_exec_fn

我們接下來就看看如何處理 RunCommandRequest。具體都是在 BasicTaskService 之中完成。

4.4.1 RunCommandRequest

可以看到,接受到消息之後,是調用 _run_command 完成。

def _handle(self, req, client_address):
    if isinstance(req, RunCommandRequest):
        self._wait_cond.acquire()
        try:
            if self._command_thread is None:
                # we add req.env to _command_env and make this available to the executed command
                if self._command_env:
                    env = self._command_env.copy()
                    self._add_envs(env, req.env)
                    req.env = env

                # We only permit executing exactly one command, so this is idempotent.
                self._command_abort = threading.Event()
                self._command_stdout = Pipe() if req.capture_stdout else None
                self._command_stderr = Pipe() if req.capture_stderr else None
                args = (req.command, req.env, self._command_abort,
                        self._command_stdout, self._command_stderr,
                        self._index,
                        req.prefix_output_with_timestamp)
                self._command_thread = in_thread(self._run_command, args)
        finally:
            self._wait_cond.notify_all()
            self._wait_cond.release()
        return network.AckResponse()

4.4.2 _run_command

_run_command 就是調用 safe_shell_exec.execute 直接運行。

def _run_command(self, command, env, event,
                 stdout=None, stderr=None, index=None,
                 prefix_output_with_timestamp=False):
    self._command_exit_code = safe_shell_exec.execute(
        command,
        env=env,
        stdout=stdout, stderr=stderr,
        index=index,
        prefix_output_with_timestamp=prefix_output_with_timestamp,
        events=[event])
    if stdout:
        stdout.close()
    if stderr:
        stderr.close()

因此,接下來就是在 Spark Executor 之中,開始執行

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

注意,此時是在 Spark Executor 之中,所以接下來行爲會和之前不同。

4.4.3 mpirun_rsh

mpirun_rsh 依然是調用 rsh。

addresses = codec.loads_base64(sys.argv[1])
key = codec.loads_base64(os.environ.get(secret.HOROVOD_SECRET_KEY))
settings = codec.loads_base64(sys.argv[2])
host_hash = sys.argv[3]
command = " ".join(sys.argv[4:])
env = {}  # orted does not need any env vars, the target training code gets env from mpirun

# Since tasks with the same host hash have shared memory,
# we will run only one orted process on the first task.
rsh(addresses, key, host_hash, command, env, 0, settings.verbose)

4.4.4 rsh

代碼如下:

def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

但是此時運行就出現了與之前的不同之處

此時在 Spark Executor 再次調用

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

回憶一下 0x03 MPI 實驗 的結果,我們知道,pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" 這部分在遠端上其實不會有實際效果遠端 orted 轉而會繼續運行傳過來的 mpirun_exec_fn

如果哪位朋友對 MPI 有深入瞭解,還望賜教

4.4.5 mpirun_exec_fn

代碼位於:horovod/spark/task/mpirun_exec_fn.py。

就是調用到了task_exec。

def main(driver_addresses, settings):
    # prepend HOROVOD_SPARK_PYTHONPATH to PYTHONPATH
    if 'HOROVOD_SPARK_PYTHONPATH' in os.environ:
        ppath = os.environ['HOROVOD_SPARK_PYTHONPATH']

        # add injected HOROVOD_SPARK_PYTHONPATH to sys.path
        for p in reversed(ppath.split(os.pathsep)):
            sys.path.insert(1, p)  # don't put it in front which is usually .

        if 'PYTHONPATH' in os.environ:
            ppath = os.pathsep.join([ppath, os.environ['PYTHONPATH']])
        os.environ['PYTHONPATH'] = ppath

    # change current working dir to where the Spark worker runs
    # because orted runs this script where mpirun was executed
    # this env var is injected by the Spark task service
    work_dir = os.environ.get('HOROVOD_SPARK_WORK_DIR')
    if work_dir:
        os.chdir(work_dir)

    task_exec(driver_addresses, settings, 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK')

0x05 第五階段 : 運行用戶代碼

5.1 task_exec

task_exec 就是用來運行用戶代碼。

可以看到,是從 Driver 之中取出之前存儲的代碼,然後運行。

def task_exec(driver_addresses, settings, rank_env, local_rank_env):
    # Die if parent process terminates
    in_thread(target=_parent_process_monitor, args=(os.getppid(),))

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ[rank_env])
    local_rank = int(os.environ[local_rank_env])
    driver_client = driver_service.SparkDriverClient(driver_addresses, key,
                                                     verbose=settings.verbose)

    # tell driver about local rank and rank
    # in elastic mode the driver already knows this mapping
    # for simplicity we keep code paths the same for elastic and static mode
    host_hash = os.environ['HOROVOD_HOSTNAME']
    task_index = driver_client.set_local_rank_to_rank(host_hash, local_rank, rank)

    # gather available resources from task service
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key,
                                               verbose=settings.verbose)
    task_info.set_resources(task_client.resources())

    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)

5.2 獲取訓練代碼

在MapReduce之中,則是把Jar包(二進制庫)分發到各個節點,然後各個節點執行jar包之中相應的代碼。其實這樣很不方便。

Spark提出了函數序列化功能,可以很好的解決這個問題,這是Spark對分佈式編程的一個貢獻。Spark系統會把你寫的那些自定義函數(你的業務功能)自動序列化到各個節點去執行。函數序列化發送功能給Spark帶來的另外好處是:用戶可以使用spark-shell在命令行直接寫分佈式代碼,實時操作,實時得到結果。

比如,初始化/協調等工作是在Driver程序中進行,但是代碼實際執行是在Worker節點中的Executor中進行。當Executor端執行時需要用到Driver端封裝的class對象時,Driver端就需要把Driver端的class對象通過序列化傳輸到Executor端,這個class對象則需要實現Serializable序列化方法。

Horovod on spark 這裏就是直接傳輸 python 訓練代碼原始文本,因爲 pyton 的腳本特性,所以可以直接運行代碼原始文本

獲取訓練代碼的函數如下,在 SparkDriverClient 類之中就是給 Driver 發送 CodeRequest 請求:

def code(self):
    resp = self._send(CodeRequest())
    return resp.fn, resp.args, resp.kwargs

在 SparkDriverService 之中,收到 CodeRequest 請求之後,會進行處理。

if isinstance(req, CodeRequest):
    return CodeResponse(self._fn, self._args, self._kwargs)

就是把 SparkDriverService 之前存儲的訓練代碼 _fn 以及其參數一起發給 SparkTaskService。

class CodeResponse(object):
    def __init__(self, fn, args, kwargs):
        self.fn = fn
        """Function."""

        self.args = args
        """Function args."""

        self.kwargs = kwargs
        """Function kwargs."""

最終邏輯大致如下:

+---------------------------------+                     +---------------------------------+
| Horovod Main thread             |                     | Spark Executor                  |
|                                 |                     |                                 |
|                                 |                     |                                 |
|  +-------------------------+    |       1 register    |        +----------------------+ |
|  |     SparkDriverService  | <---------------------------------+  SparkTaskService    | |
|  |                         |    |                     |        |                      | |
|  |                         |    |      2 notify start |        |                      | |
|  |                         | +-------------------------------> |                      | |
|  |                         |    |                     |        |                      | |
|  |                         |    |                     |        |                      | |
|  |                         |    | 3 RunCommandRequest |        |                      | |
|  |                         | +--------------------------------------> orted mpirun_rsh| |
|  |                         |    |                     |        |        +             | |
|  |                         |    |                     |        |        | 4           | |
|  |                         |    |                     |        |        |             | |
|  |                         |    |                     |        |        v             | |
|  |                         |    |                     |        |      task_exec       | |
|  |                         |    |                     |        |        +             | |
|  |                         |    |                     |        |        | 5           | |
|  |                         |    |                     +        |        |             | |
|  |                         |    |6 set_local_rank_to_rank      |        v             | |
|  |                         | +------------------------+---------> SparkTaskClient     | |
|  |                         |    |                     |        |                      | |
|  |                         |    |    7 code()         |        |                      | |
|  |                         | +---------------------------------------> 8 fn()         | |
|  |                         |    |                     |        |                      | |
|  +-------------------------+    |                     |        +----------------------+ |
+---------------------------------+                     +---------------------------------+

手機如下:

至此,spark on MPI 分析結束,我們下文介紹 spark on GLOO。

0xEE 個人信息

★★★★★★關於生活和技術的思考★★★★★★

微信公衆賬號:羅西的思考

如果您想及時得到個人撰寫文章的消息推送,或者想看看個人推薦的技術資料,敬請關注。

在這裏插入圖片描述

0xFF

mpirun,mpiexec和mpiexec.hydra有什麼區別和關係?

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章