[源碼解析] PyTorch 分佈式之彈性訓練(2)---啓動&單節點流程

[源碼解析] PyTorch 分佈式之彈性訓練(2)---啓動&單節點流程

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分佈式的基本模塊,介紹了官方的幾個例子,我們接下來會介紹PyTorch的彈性訓練,本文是第二篇,重點關注的是如何啓動彈性訓練,並且可以對系統總體架構有所瞭解。

彈性訓練系列文章如下:

[源碼解析] PyTorch 分佈式之彈性訓練(1) --- 總體思路

0x01 重要概念

爲了更好的說明(這個說明可能在後面文章也會出現,因爲太重要了),我們先總述一下TE 最重要的 Agent 和 Rendezvous 兩個概念。

  • Agent :Agent是運行在單節點上的獨立後臺進程,可以認爲是 worker manager 或者 process supervisor,其負責啓動worker,監控 worker 運行,捕獲woker異常,通過 rendezvous 實現 worker 間的相互發現(比如把狀態上報到KVStore),成員變動時候基於 rendezvous 進行變更同步等等。
  • Rendezvous :爲了實現彈性訓練,需要有一個節點/進程之間彼此發現的機制。Rendezvous就是這個發現機制或者說同步組件。當系統啓動或者成員變更時候,所有worker會(重新)集合(rendezvous)以建立一個新的進程組。

我們從源碼中取出示意圖看看,大家先有一個總體概念。

0x02 分佈式運行

2.1 方式改變

2.1.1 原有方式

我們知道,PET是從 PyTorch v1.9 合併進來的,因爲合併了彈性訓練,所以分佈式啓動的方式有了很大的改變。

V1.9 之前是使用 torch/distributed/launch.py 進行啓動,比如:

python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
           --nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
           --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
           and all other arguments of your training script)

此處參數含義是:

  • nnodes :是參與訓練的節點數目。
  • nproc_per_node :每個節點上運行的進程數目。
  • node_rank :當前節點標識符。
  • master_addrmaster_port 是 master 監聽的地址和端口。

當運行時,torch.distributed.launch 會設置一些環境變量,包括 world_sizemaster_addrmaster_port 等等。然後在當前機器上創建 nproc_per_node 個進程,這些進程構成了一個本地組。如果一共有 NODE_SIZE 個機器參與訓練,則一共有 NODE_SIZE * TRAINERS_PER_NODE 個進程。如果想啓動一個分佈式訓練任務,則需要在所有的機器上執行相關命令。

2.1.2 目前方式

PyTorch 1.9 使用 torch/distributed/run.py 進行啓動。如果依然採用 torch/distributed/launch.py,其實其內部已經透傳給 run.py,具體參見代碼:

def main(args=None):
    logger.warn(
        "The module torch.distributed.launch is deprecated "
        "and going to be removed in future."
        "Migrate to torch.distributed.run"
    )
    args = parse_args(args)
    run(args)

torch.distributed.run是之前torch.distributed.launch的一個超集,提供如下新功能:

  • 容錯:通過重新啓動所有workers,可以優雅地處理worker故障。
  • 自動:Worker 的RANKWORLD_SIZE 是自動分配的
  • 彈性:允許在最小值和最大值(彈性)之間更改節點數。

爲了使用彈性訓練,用戶代碼也需要做一些修改,如果用戶的訓練腳本已經支持 torch.distributed.launch ,則只需要修改幾處就可以使用torch.distributed.run

  • 無需手動傳遞RANK , WORLD_SIZE , MASTER_ADDR 和 MASTER_PORT。
  • 必須提供rdzv_backendrdzv_endpoint。對於大多數用戶來說,這其實就是“c10d”(參見“rendezvous“)。其實這就替代了之前的MASTER_ADDR 和 MASTER_PORT。
  • use_env 參數已被刪除。請從 LOCAL_RANK 環境變量中獲取local_rank (例如,os.environ["LOCAL_RANK"])。
  • 用戶需要確保腳本中有 load_checkpoint(path)save_checkpoint(path) 邏輯,即手動處理Checkpoint。因爲當worker失敗時,我們將使用最近的checkpoint來恢復現場,重啓所有worker。

下面是一個訓練腳本的示例,該腳本在每個epoch上設置檢查點,因此在失敗時最差也只是會丟失一個epoch的訓練成果。

  def main():
       args = parse_args(sys.argv[1:])
       state = load_checkpoint(args.checkpoint_path)
       initialize(state)

       # torch.distributed.run ensure that this will work
       # by exporting all the env vars needed to initialize the process group
       torch.distributed.init_process_group(backend=args.backend)

       for i in range(state.epoch, state.total_num_epochs)
            for batch in iter(state.dataset)
                train(batch, state.model)

            state.epoch += 1
            save_checkpoint(state)

所以,我們接下來看看在新模式之下,如何分佈式啓動。

2.2 部署

部署一般按照如下方式。

  1. (C10d後端不需要)啓動 rendezvous 後端服務器,並獲取端點(作爲--rdzv_endpoint傳遞給啓動程序腳本)
  2. 單節點多 worker:在主機上啓動 launcher 以啓動代理進程,代理會創建並監視本地工作組。
  3. 多節點多 worker:在所有節點上使用相同的參數啓動 launcher 參加訓練。

當使用作業/羣集管理器時,多節點作業的入口點命令應爲 launcher。

2.3 示例

我們首先通過幾個例子來看看如何啓動分佈式訓練。

2.3.1 單節點多worker啓動

單節點多worker的啓動方式如下,其實就是Standalone 模式,這是分佈式模式的一種特例,具體就是針對單機多 Worker 提供了一些便利設置。

python -m torch.distributed.run
        --standalone
        --nnodes=1
        --nproc_per_node=$NUM_TRAINERS
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

2.3.2 容錯方式啓動

如下是容錯方式啓動,固定數目workers,沒有彈性訓練。 --nproc_per_node=$NUM_TRAINERS 一般是 單節點上GPU 個數。

python -m torch.distributed.run
        --nnodes=$NUM_NODES
        --nproc_per_node=$NUM_TRAINERS
        --rdzv_id=$JOB_ID
        --rdzv_backend=c10d
        --rdzv_endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

HOST_NODE_ADDR, 的格式是: [:] ,指定了 C10d rendezvous 後端所運行的節點地址和端口,這個節點可以是訓練集羣中任意節點,但是最好找一個高帶寬的節點。

2.3.3 彈性方式啓動

下面是彈性訓練,彈性區間爲 (min=1, max=4)。通過指定rdzv參數,可以實現多機訓練,具備容錯與彈性能力

在多臺機器上分別執行以下命令啓動:最小節點數爲MIN_SIZE,最大爲MAX_SIZE,利用etcd服務實現一致性和信息同步。

python -m torch.distributed.run
        --nnodes=1:4
        --nproc_per_node=$NUM_TRAINERS
        --rdzv_id=$JOB_ID
        --rdzv_backend=c10d
        --rdzv_endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

HOST_NODE_ADDR, 的格式是: [:] ,指定了 C10d rendezvous 後端所運行的節點地址和端口,這個節點可以是訓練集羣中任意節點,但是最好找一個高帶寬的節點。

關於 rendezvous backend,有幾點說明:

對於多節點訓練,需要指定:

  • --rdzv_id: 一個唯一的 job id,在參與job的所有節點之間共享。
  • --rdzv_backend: torch.distributed.elastic.rendezvous.RendezvousHandler 的一個實現。 (--rdzv_backend默認是static模式,不支持容錯和彈性伸縮)
  • --rdzv_endpoint: rendezvous backend 所運行的 endpoint,通常格式爲:host:port。就是取代了之前的 master address / port 設置。

目前,以下幾種後端可以直接使用,c10d (推薦), etcd-v2, and etcd (legacy) 。爲了使用 etcd-v2 或者 etcd,需要搭建一個 v2 api開啓的 etcd server (即. --enable-v2)。

0x03 啓動腳本

既然以上啓動都是用 torch/distributed/run.py,所以我們仔細分析一下這個腳本,該腳本提供三個功能:

  • 依靠"重啓所有 workers"來處理 worker 失敗;

  • 自動分配 worker 的RANK and WORLD_SIZE

  • 彈性訓練,即 node 數目允許在minimum和maximum之間改變;

3.1 參數定義

啓動腳本中,一些參數定義如下:

  • Node - 物理實例或容器;映射到與 job manager 所協調的單元。
  • Worker - 分佈式訓練環境中的worker。
  • WorkerGroup - 執行相同功能的一組worker(例如trainers)。
  • LocalWorkerGroup - 在同一節點上運行的工作組中的workers子集。
    • 一個節點運行 LOCAL_WORLD_SIZE個workers,這些 workers 組成LocalWorkerGroup
    • 節點上所有LocalWorkerGroups組成WorkerGroups
  • RANK - 工作組中worker的rank,是全局rank,可以認爲是一個全局GPU資源列表。
    • Rank是不穩定的,在重啓之間,本地Workers 會被分配到不同的ranks,所以不要在代碼中對RANKLOCAL_RANK的穩定性做任何假設和依賴編碼。
    • rendezvous完成後,其所有成員將對工作成員資格以及每個人在其中的角色(role)達成共識。此角色(role)使用一個介於 0 ~ world size 之間的整型來表示,被稱之爲rank。
  • LOCAL_RANK - 本地工作組中,某個worker 的 rank,可以認爲是當前節點上的GPU資源列表。
  • GROUP_RANK - worker group的rank。介於0和“最大節點數”之間的數字。如果每個節點運行一個單一工作組,那GROUP_RANK就是這個節點的rank。
  • ROLE_RANK - 對於具有相同角色worker來說,他們之間共享的rank,角色在“WorkerSpec”中被指定。
  • WORLD_SIZE - 工作組中worker的總數。因爲節點會加入/離開,所以WORLD_SIZE會變化,不能依賴 WORLD_SIZE的穩定性進行編碼。
  • LOCAL_WORLD_SIZE - 本地工作組的大小,即本地運行的worker數目,等於在torch.distributed.run運行時候指定的--nproc_per_node。目前,torch/distributed/run.py 僅支持同構的 LOCAL_WORLD_SIZE。也就是說,假設所有節點運行相同數量的本地工作者(每個角色)。
  • ROLE_WORLD_SIZE - 具有同樣角色的workers總數,在 WorkerSpec之中被指定。
  • rdzv_id - 用戶定義的id,用於唯一標識作業的工作組。這個id在每個節點加入特定工作組時候使用。
  • rdzv_backend-rendezvous 的後端(例如“c10d”)。這通常是一個強一致性的鍵值存儲。
  • rdzv_endpoint - rendezvous 後端端點;通常以“<host>:<port>”的形式出現。
  • run_id: 用戶定義的id,它唯一地標識分佈式應用程序的一個實例。它通常映射到作業id並用於允許節點加入正確的分佈式應用程序。
  • TORCHELASTIC_RUN_ID - 與 rendezvous run_id 相等,即唯一的job id。
  • TORCHELASTIC_RESTART_COUNT - 迄今爲止,工作組重啓的次數。
  • TORCHELASTIC_MAX_RESTARTS - 配置的最大重啓數目。

3.2 相關函數/變量

爲了更好的理解上面的參數,我們選取部分相關函數/變量看看。

world_size,rank

這兩個變量是動態生成的,所以從 state 之中取出。

rank, world_size = self._get_world()
    
def _get_world(self) -> Tuple[int, int]:
	state = self._state_holder.state
	return state.participants[self._this_node], len(state.participants)

_pg_group_ranks

該全局變量存儲了每個 group 的 global rank 到 local rank 映射信息。

# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}

其賦值舉例如下:

# Create the global rank to group rank mapping
_pg_group_ranks[pg] = {
    global_rank: group_rank
    for group_rank, global_rank in enumerate(ranks)
}

group_rank

我們可以利用 global rank 從 _pg_group_ranks 之中提取對應的 local rank。

def _get_group_rank(group: ProcessGroup, rank):
    """
    Helper that gets a given group's local rank in the group from a given global
    rank.
    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    try:
        group_rank = _pg_group_ranks[group][rank]
    except KeyError:
        raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
    return group_rank

global_rank

我們可以利用一個 group 的 local rank 獲取到其 gloabl rank。

def _get_global_rank(group, group_rank):
    """
    Helper that gets a given group's global rank from a given local rank in the
    group.
    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    group_rank_map = _pg_group_ranks[group]
    for rank, grp_rank in group_rank_map.items():
        if grp_rank == group_rank:
            return rank
    raise RuntimeError("The group rank is not part of the group")

group_size

我們可以 _get_group_size 獲取到某一個group 的大小。

def _get_group_size(group):
    """
    Helper that gets a given group's world size.
    """
    if group is GroupMember.WORLD or group is None:
        default_pg = _get_default_group()
        return default_pg.size()
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    return len(_pg_group_ranks[group])

nproc_per_node

這個變量可以得到每個node之上支持多少個進程。

def determine_local_world_size(nproc_per_node: str):
    try:
        logging.info(f"Using nproc_per_node={nproc_per_node}.")
        return int(nproc_per_node)
    except ValueError:
        if nproc_per_node == "cpu":
            num_proc = os.cpu_count()
            device_type = "cpu"
        elif nproc_per_node == "gpu":
            if not torch.cuda.is_available():
                raise ValueError("Cuda is not available.")
            device_type = "gpu"
            num_proc = torch.cuda.device_count()
        elif nproc_per_node == "auto":
            if torch.cuda.is_available():
                num_proc = torch.cuda.device_count()
                device_type = "gpu"
            else:
                num_proc = os.cpu_count()
                device_type = "cpu"
        else:
            raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}")
        )
        return num_proc

3.3 腳本入口

腳本入口主要代碼如下,可以看到,其調用到了 elastic_launch 來完成功能,所以我們下一節就要順藤摸瓜來看看這個函數。

from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def run(args):
    if args.standalone: # 有兩種模式:Standalone 模式和分佈式模式,這裏要判斷一下
        args.rdzv_backend = "c10d"
        args.rdzv_endpoint = "localhost:29400"
        args.rdzv_id = str(uuid.uuid4())
        log.info(
            f"\n**************************************\n"
            f"Rendezvous info:\n"
            f"--rdzv_backend={args.rdzv_backend} "
            f"--rdzv_endpoint={args.rdzv_endpoint} "
            f"--rdzv_id={args.rdzv_id}\n"
            f"**************************************\n"
        )

    config, cmd, cmd_args = config_from_args(args)
    elastic_launch(
        config=config,
        entrypoint=cmd,
    )(*cmd_args)


def main(args=None):
    args = parse_args(args)
    run(args)


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="[%(levelname)s] %(asctime)s %(module)s: %(message)s"
    )
    main()

0x04 單體總體流程

我們下面就從 elastic_launch 開始,看看在單節點上如何啓動運行。我們首先給出一個總體示意圖,圖上是兩個節點,每個節點有一個 agent,agent下面是一個 worker group,組下面是4個worker。

4.1 小例子

我們再從源碼中找一個例子來看看,這裏只是設置了兩個workers。

import uuid
import torch
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def worker_fn(t1, t2):
    return torch.add(t1, t2)

def main():
    t1 = torch.rand((3,3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)

    config = LaunchConfig(
        min_nodes=2,
        max_nodes=4,
        nproc_per_node=1,
        run_id=str(uuid.uuid4()),
        role="trainer",
        rdzv_endpoint="localhost:29400",
        rdzv_backend="c10d",
        max_restarts=1,
        monitor_interval=1,
        start_method="spawn",
    )

    outputs = elastic_launch(config, worker_fn)(t1, t2)

if __name__ == '__main__':
    main()

輸出如下,可以看到有兩個 worker 進程 和一個 agent 進程。

{"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 0, "group_rank": 0, "worker_id": "12172", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [0], \"role_rank\": [0], \"role_world_size\": [2]}", "agent_restarts": 0}}

{"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 1, "group_rank": 0, "worker_id": "3276", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [1], \"role_rank\": [1], \"role_world_size\": [2]}", "agent_restarts": 0}}

{"name": "torchelastic.worker.status.SUCCEEDED", "source": "AGENT", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": null, "group_rank": 0, "worker_id": null, "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\"}", "agent_restarts": 0}}

4.2 入口

順着代碼我們深入挖掘一下。elastic_launch 的作用就是啓動一個 torchelastic agent,然後通過這個 agent來調用用戶程序入口,agent 會啓動 worker 進行訓練,並且管理 worker 生命週期

class elastic_launch:
    """
    Launches an torchelastic agent on the container that invoked the entrypoint.

        1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
           ``entrypoint`` can be a function or a command.
        2. The return value is a map of each worker's output mapped
           by their respective global rank.
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        self._config = config
        self._entrypoint = entrypoint

    def __call__(self, *args, **kwargs):
        return launch_agent(self._config, self._entrypoint, list(args)) # 內部會調用用戶程序

4.3 啓動代理

launch_agent 啓動了一個 LocalElasticAgent,調用了其 run 方法。

@record
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    agent = None
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    try:
        spec = WorkerSpec( # 1. 得到spec
            role=config.role,
            local_world_size=config.nproc_per_node,
            entrypoint=entrypoint,
            args=tuple(args),
            rdzv_handler=rdzv_handler, # RendezvousHandler
            max_restarts=config.max_restarts,
            monitor_interval=config.monitor_interval,
            redirects=config.redirects,
            tee=config.tee,
            master_addr=master_addr,
            master_port=master_port,
        )

        cfg = metrics.MetricsConfig(config.metrics_cfg) if config.metrics_cfg else None
        metrics.initialize_metrics(cfg)

        agent = LocalElasticAgent( # 2. 構建代理
            spec=spec, start_method=config.start_method, log_dir=config.log_dir
        )

        result = agent.run() # 3. 啓動代理
        events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )
        else:
            return result.return_values
    except ChildFailedError:
        raise
    except Exception:
        if agent:
            events.record(agent.get_agent_status_event(WorkerState.FAILED))
        else:
            events.record(_construct_event(config))
        raise
    finally:
        rdzv_handler.shutdown()

這裏有幾個關鍵點:

4.3.1 WorkerSpec

WorkerSpec :這是配置信息,裏面包含了代理所需要的某些全局信息,比如 RendezvousHandler,role,entry(用戶函數)。

spec = {WorkerSpec} 
   args = {tuple: 2} (tensor, tensor)
   fn = {NoneType} None
   local_world_size = {int} 1
   master_addr = {NoneType} None
   master_port = {NoneType} None
   max_restarts = {int} 1
   monitor_interval = {int} 1
   rdzv_handler = {DynamicRendezvousHandler}
   redirects = {Std} Std.NONE
   role = {str} 'trainer'
   tee = {Std} Std.NONE
   entry = worker_fn

代理會從這裏提取各種所需信息。比如_start_workers 會從中獲取 store。

use_agent_store = spec.rdzv_handler.get_backend() == "static"

此時邏輯爲:

+--------------------------+      +---------------------------------------------------+
|LocalElasticAgent         |      | WorkerSpec                                        |
|                          |      |                                                   |
|     WorkerSpec +--------------> |      rdzv_handler = {DynamicRendezvousHandler} --------+
|                          |      |                                                   |    |
|     rdzv_run_id          |      |      entry = worker_fn                            |    |
|                          |      |                                                   |    |
|     store                |      |      role = {str} 'trainer'                       |    |
|                          |      |                                                   |    |
|                          |      +---------------------------------------------------+    |
|                          |                                                               |
|                          |                                                               |
|                          |                                                               |
|                          |                                                               |
|                          |               +-----------------------------------------+     |
+--------------------------+               |DynamicRendezvousHandler                 |     |
                                           |                                         |     |
                                           |                                         |     |
                                           |   _settings: RendezvousSettings         | <---+
                                           |                                         |
                                           |   _store: Store                         |
                                           |                                         |
                                           |   _state_holder: _RendezvousStateHolder |
                                           |                                         |
                                           |   _op_executor: _RendezvousOpExecutor   |
                                           |                                         |
                                           +-----------------------------------------+

4.3.2 WorkerGroup

WorkerGroup 代表了一個工作組。WorkerGroup 作爲一個整體來管理多個 workers,進行批量處理。

class WorkerGroup:
    """
    Represents the set of ``Worker`` instances for the given ``WorkerSpec``
    managed by ``ElasticAgent``. Whether the worker group contains cross
    instance workers or not depends on the implementation of the agent.
    """

    __slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state"]

    def __init__(self, spec: WorkerSpec):
        self.spec = spec
        self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]

        # assigned after rdzv
        self.store = None
        self.group_rank = None
        self.group_world_size = None

        self.state = WorkerState.INIT

在SimpleElasticAgent 初始化之中,會建立一個 WorkerGroup。

class SimpleElasticAgent(ElasticAgent):
    """
    An ``ElasticAgent`` that manages workers (``WorkerGroup``)
    for a single ``WorkerSpec`` (e.g. one particular type of worker role).
    """

    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None
        self._exit_barrier_timeout = exit_barrier_timeout
        self._total_execution_time = 0

具體如下:

+-----------------------------+      +------------------------------------------------+
| LocalElasticAgent           |      | WorkerSpec                                     |
|                             |      |                                                |
| +------------------------+  |      |   rdzv_handler = {DynamicRendezvousHandler} -------+
| |WorkerGroup             |  |      |                                                |   |
| |            spec +--------------> |   entry = worker_fn                            |   |
| |            workers     |  |      |                                                |   |
| |            store       |  |      |   role = {str} 'trainer'                       |   |
| |            group_rank  |  |      |                                                |   |
| |       group_world_size |  |      +------------------------------------------------+   |
| |                        |  |                                                           |
| +------------------------+  |                                                           |
|                             |                                                           |
| rdzv_run_id                 |                                                           |
| store                       |            +-----------------------------------------+    |
|                             |            |DynamicRendezvousHandler                 |    |
+-----------------------------+            |                                         |    |
                                           |                                         |    |
                                           |   _settings: RendezvousSettings         | <--+
                                           |                                         |
                                           |   _store: Store                         |
                                           |                                         |
                                           |   _state_holder: _RendezvousStateHolder |
                                           |                                         |
                                           |   _op_executor: _RendezvousOpExecutor   |
                                           |                                         |
                                           +-----------------------------------------+

4.4 代理運行

SimpleElasticAgent 是 LocalElasticAgent 的基類,所以會先運行到WorkerSpec.run 方法這裏,run方法則調用了 _invoke_run。

    @prof
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        start_time = time.monotonic()
        try:
            result = self._invoke_run(role) # 調用
            self._total_execution_time = int(time.monotonic() - start_time)
            self._record_metrics(result)
            self._record_worker_events(result)
            return result
        finally:
            # record the execution time in case there were any exceptions during run.
            self._total_execution_time = int(time.monotonic() - start_time)
            self._shutdown()

4.5 代理主循環

代理在 invoke_run 之中做如下操作:

  • 啓動 _initialize_workers,這裏會使用 _rendezvous 構建一個 rendezvous,然後調用 _start_workers 啓動 workers。
  • 進入 while True 循環,在循環之中:
    • 通過 _monitor_workers 定期輪訓用戶程序運行情況,得到客戶進程運行結果,然後依據情況作出判斷。
      • 如果程序正常結束,則返回。
      • 如果程序出錯,則重試,即重啓所有 workers,如果重試次數達到依然有問題,就結束所有workers。
      • 如果節點成員關係有變化,比如scale up就會有新的節點在waiting,這時候就重啓所有workers。
    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group) # 啓動worker
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # 定期監控
            time.sleep(monitor_interval)
            # 監控客戶程序運行情況
            run_result = self._monitor_workers(self._worker_group) # 得到進程運行結果
            state = run_result.state
            self._worker_group.state = state

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                # 程序正常結束
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程序出錯
                if self._remaining_restarts > 0: # 重試
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group) # 重試次數達到,結束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # 節點成員關係有變化,比如scale up,就會有新節點waiting
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # 如果有新的節點在waiting,就重啓所有workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

於是最終邏輯如下:

+----------------------------------------------+
| LocalElasticAgent                            |
|                                              |    +---------------------------------------------------+
|  rdzv_run_id                                 |    | WorkerSpec                                        |
|                                              |    |                                                   |
|  store           +------------------------+  |    |      rdzv_handler = {DynamicRendezvousHandler} +-------+
|                  |WorkerGroup             |  |    |                                                   |    |
|  _pcontext       |            spec +------------> |      entry = worker_fn                            |    |
|                  |            workers     |  |    |                                                   |    |
|                  |            store       |  |    |      role = {str} 'trainer'                       |    |
|                  |            group_rank  |  |    |                                                   |    |
|                  |       group_world_size |  |    +---------------------------------------------------+    |
|                  |                        |  |                                                             |
|                  +------------------------+  |                                                             |
|  +----------------------------------------+  |                                                             |
|  | _invoke_run                            |  |                                                             |
|  |                                        |  |             +-----------------------------------------+     |
|  |   _initialize_workers +------------------------+        |DynamicRendezvousHandler                 |     |
|  |                                        |  |    |        |                                         |     |
|  |                                        |  |    |        |                                         |     |
|  |   while True:                          |  |    |        |   _settings: RendezvousSettings         | <---+
|  |       _monitor_workers(_worker_group)  |  |    |        |                                         |
|  |                +                       |  |    |        |   _store: Store                         |
|  |                | _pcontext.wait        |  |    |        |                                         |
|  |                |                       |  |    |        |   _state_holder: _RendezvousStateHolder |
|  +----------------------------------------+  |    |        |                                         |
|                   |                          |    |        |   _op_executor: _RendezvousOpExecutor   |
+----------------------------------------------+    |        |                                         |
                    |                               |        +-----------------------------------------+
                    |                               |
                    v                               v
         +-------------------------------------------------+
         |  +------------+  +------------+  +------------+ |
         |  |Process     |  |Process     |  |Process     | |
         |  |            |  |            |  |            | |
         |  |    work_fn |  |   work_fn  |  |    work_fn | |
         |  |            |  |            |  |            | |
         |  +------------+  +------------+  +------------+ |
         +-------------------------------------------------+

手機如下:

至此,腳本如何啓動和單體流程我們分析完畢,下一篇我們來具體分析代理。

0xFF 參考

[PyTorch Elastic源碼閱讀](

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章