[源码解析] PyTorch 分布式之弹性训练(7)---节点变化

[源码解析] PyTorch 分布式之弹性训练(7)---节点变化

0x00 摘要

本文分析如何处理节点变化。即对成员更改作出反应,并使用新的成员来重启所有workers,从而实现弹性训练。

总体思路是和当工作进程失败时的处理一样:相应elastic agent将杀死该节点上的所有工作进程,与其他代理建立会合(rendezvous),并使用新的会合(rendezvous)信息重新启动所有工作进程。

弹性训练系列文章如下:

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

[源码解析] PyTorch 分布式之弹性训练(3)---代理

[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

[源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎

[源码解析] PyTorch 分布式之弹性训练(6)---监控/容错

0x01 变化方式

节点变化有两点方式。

1.1 Scale-down

节点离开(scale-down)的处理如下:

  • 当Scale down事件发生时,rendezvous将不会通知 torchelastic agent。
  • torchelastic agent 自己会监控到有进程错误,从而进行处理。
  • 如果TE agent以max_restarts=0配置启动,它依赖于底层调度程序来处理作业重新启动。
  • 如果max_restarts>0,TE代理将终止workers并开始新一轮rendezvous。
    • 代理得到离开的通知,于是现有workers(所有节点上的)都全部停止。
    • 这些workers将形成一个新的“WorkerGroup”,所有worker都将以新的RANKWORLD_SIZE 运行。

1.2 Scale-up

节点加入(scale-up)的处理如下:

  • 当Scale up事件发生时,新节点被提交到作业,torchelastic rendezvous将检测到有新节点试图加入。
    • 如果rendezvous已经达到最多节点数,新节点将不会添加到等待列表,因为已经满了,所以没有必要拆除已经完全体的rendezvous。新节点将一直等待直到超时(默认为600秒)。
    • 新节点将定期检查参与节点数目。如果数目变为小于max_nodes,等待节点将被加入到等待列表中。否则它将在600秒之后超时。
  • 当代理决定处理 Scale up时:
    • torchelastic rendezvous将停止所有workers并执行新一轮的 re-rendezvous。
    • 这些workers(现有以及新加入的)将形成一个新的“WorkerGroup”,所有worker都将以新的RANKWORLD_SIZE 运行。

注:scale up发生时,max_restarts 将不会减少。

0x02 节点加入

2.1 新节点加入

假设目前已经有了一个弹性训练集群正在运行,弹性区间为 (min=1, max=4)。目前已经有2个节点在运行,用户想启动第三个节点,于是使用如下方法启动一个新进程。

python -m torch.distributed.run
        --nnodes=1:4
        --nproc_per_node=$NUM_TRAINERS
        --rdzv_id=$JOB_ID
        --rdzv_backend=c10d
        --rdzv_endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

新进程会启动一个代理。代理经过一系列操作,调用 next_rendezvous,其中启动一个 ExitOp,一个 JoinOp 。

def next_rendezvous(self) -> Tuple[Store, int, int]:
    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()
    
    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)    

2.2 处理 Join 操作

以下操作是在 _DistributedRendezvousOpExecutor 之中。

有了前文分析,我们知道,业务流程是 run 调用 Join 算子来分析出来下一个 Action,然后根据 Action 来执行对应的业务操作

2.2.1 run处理

_DistributedRendezvousOpExecutor.run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。对于我们示例,state_handler 就是_RendezvousJoinOp。

    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """See base class."""
        action = None

        while action != _Action.FINISH: # 一直循环,直到结束
            
            # 这里很重要,在所有node之间做信息同步
            has_set = self._state_holder.sync() # 因为最新状态在 rendezvous。
            self._state = self._state_holder.state
            # 利用最新状态构建了 ctx
            ctx = _RendezvousContext(self._node, self._state, self._settings)

            # Determine the next action to take based on the current state of
            # the rendezvous.
            action = state_handler(ctx, deadline) # 调用_RendezvousJoinOp,决定下一个操作

            # 省略后续部分

2.2.2 Join操作

因为之前做了同步,所以这里的ctx就包括了最新的state,这就是Rendezvous的全局状态。因为此时,Rendezvous 已经结束了,所以 state 的状态是 complete,进入如下流程,返回 _Action.ADD_TO_WAIT_LIST。

    if state.complete:
        # If we are here, it means we are not part of the rendezvous. In
        # case the rendezvous has capacity for additional participants add
        # ourself to the wait list for the next round.
        if len(state.participants) < ctx.settings.max_nodes: # 如果当前节点数目小于最大配置
            if ctx.node not in state.wait_list: # 如果当前node不在等待列表之中
                return _Action.ADD_TO_WAIT_LIST  # 发送一个等待action

总体代码如下:

class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state # 从上下文之中提取 _RendezvousState 状态

        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED # 如果已经结束,就返回 _Action.ERROR_CLOSED

        is_participant = ctx.node in state.participants # 看看是参与者

        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant: # 如果是参与者且状态结束,就返回 _Action.FINISH
            return _Action.FINISH

        now = time.monotonic()
        if now > deadline: # 如果已经超时
            rollback_period = 5  # 5 seconds

            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period: # 如果还有时间来 rollback
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant: # 已经是参与者了
                    return _Action.REMOVE_FROM_PARTICIPANTS # 需要从参与者列表移除
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list: # 已经在等待列表之中
                    return _Action.REMOVE_FROM_WAIT_LIST # 需要从等待列表移除
            return _Action.ERROR_TIMEOUT # 返回超时

        if state.complete: # 如果 rendezvous 已经结束
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes: # 如果还没有达到最大节点数
                if ctx.node not in state.wait_list: # 如果当前node不在等待列表之中
                    return _Action.ADD_TO_WAIT_LIST # 就加入到等待列表,发送一个等待action
        elif is_participant: # 如果已经在参与者列表
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes: # 如果达到了最小节点数
                if cast(datetime, state.deadline) < datetime.utcnow(): # 如果达到了超时
                    return _Action.MARK_RENDEZVOUS_COMPLETE # 标示 rendezvous 已经结束
        else: # 否则就直接加入到参与者
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS

        if _should_keep_alive(ctx): # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
            return _Action.KEEP_ALIVE

        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC # 否则返回同步状态 _Action.SYNC

2.2.3 等待业务操作

_DistributedRendezvousOpExecutor 之中,run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。

    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """See base class."""
        action = None

        while action != _Action.FINISH: # 一直循环,直到结束
     
            # 这里很重要,在所有node之间做信息同步
            has_set = self._state_holder.sync() # 因为最新状态在 rendezvous。
            self._state = self._state_holder.state
					  # 使用最新state构建ctx
            ctx = _RendezvousContext(self._node, self._state, self._settings)

            # Determine the next action to take based on the current state of
            # the rendezvous.
            action = state_handler(ctx, deadline) # 调用_RendezvousJoinOp,决定下一个操作,这里得到了 _Action.ADD_TO_WAIT_LIST

            if action == _Action.SYNC:
                _delay(seconds=1)
            else:
                if action == _Action.KEEP_ALIVE:
                    self._keep_alive()
                elif action == _Action.ADD_TO_WAIT_LIST: # 从 Join 算子得到了_Action.ADD_TO_WAIT_LIST
                    self._add_to_wait_list() # 进行业务逻辑
                # 省略其他action

                # Attempt to sync our changes back to other nodes.
                self._state_holder.mark_dirty() # 同步回其他节点

具体处理等待操作就是加入到等待列表。

def _add_to_wait_list(self) -> None:
    self._state.wait_list.add(self._node)
    self._keep_alive()

我们回忆一下 _RendezvousState。_RendezvousState 是rendezvous的状态。是动态信息。

  • round:Rendezvous的当前轮次
  • complete:一个布尔值,指示rendezvous当前一轮是否完成了。
  • deadline:截止时间,如果如果当前轮次一直在等待节点加入,如果这个参数设置了,就是等待的截至时间。
  • closed:一个布尔值,指示rendezvous是否结束了。
  • participants:字典,存放参与者和它们对应ranks。
  • wait_list:set结构,存放等待参与下一轮rendezvous操作的一组节点
  • last_heartbeats:字典,包含每个节点上次心跳时间。
class _RendezvousState:
    round: int
    complete: bool
    deadline: Optional[datetime]
    closed: bool
    participants: Dict[_NodeDesc, int] # 参与者,未来会用到的成员变量
    wait_list: Set[_NodeDesc]  # 等待者,这里用到的成员变量
    last_heartbeats: Dict[_NodeDesc, datetime]

    def __init__(self) -> None:
        self.round = 0
        self.complete = False
        self.deadline = None
        self.closed = False
        self.participants = {}
        self.wait_list = set() # 这里用到的成员变量
        self.last_heartbeats = {}

目前逻辑如下:

  1. 启动一个新 worker。此时下图右侧上方的 _RendezvousState 之中,wait_list 为空。
  2. 调用 next_rendezvous,发起新一轮 rendezvous。
  3. _RendezvousJoinOp 内部运行,生成 ADD_TO_WAIT_LIST。
  4. executor . run 内部运行 _add_to_wait_list。
  5. 往 wait_list 添加一个新的 node。此时下图右侧上方的 _RendezvousState 之中,wait_list 多了一个 1。
  python -m torch.distributed.run             +-------------------------+     +
      --nnodes=xxx TRAINING_SCRIPT.py         | _RendezvousState        |     |
                 +                            |                         |     |
                 |                            |    participants = [1,2] |     |
                 | 1                          |                         |     |
                 v                            |    wait_list = []       |     |
          next_rendezvous                     |                         |     |
                 +                            +------------+------------+     |
                 | 2                                       |                  |
                 |                                         |                  |
                 v                                         |                  |
+----------------+-----------------------+                 |                  |
| _op_executor.run(_RendezvousJoinOp)    |                 |                  |
|           +              +             |                 |                  |
|           |              | 3           |                 |                  |
|           |              |             |                 |                  |
|           |              v             |                 |                  |
|           |   _Action.ADD_TO_WAIT_LIST |                 v                  |
|           |              +             |                                    |
|           |              |             |    +--------------------------+    |
|           +<-------------+             |    | _RendezvousState         |    |
|           |                            |    |                          |    |
|           |                            |    |    participants = [1,2]  |    |
|           v       4                    | 5  |                          |    |
|      self._add_to_wait_list() +----------------> wait_list = [3]       |    |
|                                        |    |                          |    |
+----------------------------------------+    +--------------------------+    |
                                                                              |
                                                                              v

                                                                         Timeline

2.3 Agent 处理

_DistributedRendezvousOpExecutor . run 处理之后,操作回到了代理之中。代理主循环之中,程序会进入 while 循环,然后通过 _monitor_workers 定期轮训用户程序运行情况,依据情况作出判断。

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group) # 启动worker
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # 定期监控
            time.sleep(monitor_interval)
            # 监控客户程序运行情况
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state # 进程运行情况
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # 程序正常结束
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程序出错
                if self._remaining_restarts > 0: # 重试
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group) # 重试次数达到,结束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
								# 程序正常运行
                # 节点成员关系有变化,比如scale up
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # 如果有新的节点在waiting,就重启所有workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

所以,代理定期运行 _monitor_workers 监控worker运行情况才是关键。run_result.state 是进程运行情况,当状态是 WorkerState.HEALTHY,说明原有程序正常运行,接下来看看节点成员关系是否有变化。

调用 rdzv_handler.num_nodes_waiting() 拿到等待列表数目,如果有新的节点在waiting,就说明有新的节点试图加入集群,这时就会发生一个Re-rendezvous。代理将重启所有workers。重启时候,会把等待列表中的节点加入到参与列表之中。我们依次看看如何处理。

2.3.1 检查等待列表

处理时候,首先会调用 num_nodes_waiting 看看还有多少节点在等待,具体是看看 state.wait_list 的长度。我们通过之前 Join 操作知道,如果有新节点,会插入到这个列表之中。

num_nodes_waiting 方法的作用是 返回在 rendezvous barrier 上等待的节点数目(这些节点不会在当前工作组被包括)。调用者应该周期调用这个方法,来确定是否有新节点等候加入当前工作组,因此需要调用next_rendezvous() 来提交他们。

def num_nodes_waiting(self) -> int:
    """See base class."""
    with self._heartbeat_lock:
        self._state_holder.sync()

        return len(self._state_holder.state.wait_list)

目前逻辑如下:

  1. 启动一个新 worker。
  2. 调用 next_rendezvous,发起新一轮 rendezvous。
  3. _RendezvousJoinOp 内部运行,生成 ADD_TO_WAIT_LIST。
  4. executor.run 内部运行 _add_to_wait_list。
  5. 往 wait_list 添加一个新的 node。
  6. Agent 之中,定期(比如 30S)运行一次 _monitor_workers,获取worker 子进程状态。
  7. 如果是 HEALTHY,则调用num_nodes_waiting 获取 wait_list 个数。
  8. 如果 wait_list 之中等待节点数目大于 0,则:
  9. 调用 _restart_workers 重启进程组。
  python -m torch.distributed.run             +-------------------------+     +
      --nnodes=xxx TRAINING_SCRIPT.py         | _RendezvousState        |     |
                 +                            |                         |     |
                 |                            |    participants = [1,2] |     |
                 | 1                          |                         |     |
                 v                            |    wait_list = []       |     |
          next_rendezvous                     |                         |     |
                 +                            +------------+------------+     |
                 | 2                                       |                  |
                 |                                         |                  |
                 v                                         |                  |
+----------------+-----------------------+                 |                  |
| _op_executor.run(_RendezvousJoinOp)    |                 |                  |
|           +              +             |                 |                  |
|           |              | 3           |                 |                  |
|           |              |             |                 |                  |
|           |              v             |                 |                  |
|           |   _Action.ADD_TO_WAIT_LIST |                 v                  |
|           |              +             |                                    |
|           |              |             |    +--------------------------+    |
|           +<-------------+             |    | _RendezvousState         |    |
|           |                            |    |                          |    |
|           |                            |    |    participants = [1,2]  |    |
|           v       4                    | 5  |                          |    |
|      self._add_to_wait_list() +----------------> wait_list = [3]       |    |
|                                        |    |                          |    |
+----------------------------------------+    +------------+-------------+    |
                                                           |                  |
+----------------------------------------+                 |                  |
| agent._invoke_run                      |                 |                  |
|                                        |                 |                  |
|                                        |                 |                  |
|        _monitor_workers Every 30S      |                 |                  |
|                +                       |                 |                  |
|                | 6                     |                 |                  |
|                |                       |                 v                  |
|                v                       |                                    |
|         WorkerState.HEALTHY            |     +--------------------------+   |
|                +                       |     | _RendezvousState         |   |
|                |                       |     |                          |   |
|                | 7                     |     |     participants = [1,2] |   |
|                v                       |  8  |                          |   |
|        num_nodes_waiting   <-------------------->  wait_list = [3]      |   |
|                +                       |     |                          |   |
|                | 9                     |     |                          |   |
|                |                       |     +--------------------------+   |
|                v                       |                                    |
|        _restart_workers                |                                    v
|                                        |
+----------------------------------------+                               Timeline

2.3.3 重启worker组

如果等待列表之中有节点,就会重启workers。我们走一下这个流程。

@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)
2.3.3.1 _stop_workers

首先会停止目前 workers,代码在torch/distributed/elastic/agent/server/local_elastic_agent.py。

@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()
2.3.3.2 _shutdown

_shutdown 就是让上下文关闭。

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
2.3.3.3 关闭上下文

在 MultiprocessContext 之中,close 方法是关闭所有子进程,然后等待其全部停止。

    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()
                proc.join()
2.3.3.4 _initialize_workers

当关闭了所有当前运行的子进程之后,会重新全部初始化。

@prof
def _initialize_workers(self, worker_group: WorkerGroup) -> None:
    r"""
    Starts a fresh set of workers for the worker_group.
    Essentially a rendezvous followed by a start_workers.

    The caller should first call ``_stop_workers()`` to stop running workers
    prior to calling this method.

    Optimistically sets the state of the worker group that
    just started as ``HEALTHY`` and delegates the actual monitoring
    of state to ``_monitor_workers()`` method
    """
    role = worker_group.spec.role

    # TODO after stopping workers, wait at least monitor_interval*2 for
    # workers on different nodes to fail on a collective op before waiting
    # on the rdzv barrier, this way we ensure that nodes enter rdzv
    # at around the same time and reduce false positive rdzv timeout errors
    self._rendezvous(worker_group)

    worker_ids = self._start_workers(worker_group)
    for local_rank, w_id in worker_ids.items():
        worker = worker_group.workers[local_rank]
        worker.id = w_id

    worker_group.state = WorkerState.HEALTHY

_rendezvous经过一系列操作,调用 next_rendezvous,在其中启动一个 ExitOp,一个 JoinOp 。

def next_rendezvous(self) -> Tuple[Store, int, int]:

    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()
    
    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)    
2.3.3.5 _RendezvousJoinOp

我们又回来了,这是新一轮 Rendezvous 操作。_DistributedRendezvousOpExecutor 之中,run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。对于我们示例,state_handler 就是_RendezvousJoinOp

def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None

    while action != _Action.FINISH:
        # Reads or writes the latest rendezvous state shared by all nodes in
        # the rendezvous. Note that our local changes might get overridden
        # by another node if that node synced its changes before us.
        has_set = self._state_holder.sync()
        self._state = self._state_holder.state
        ctx = _RendezvousContext(self._node, self._state, self._settings)

        # Determine the next action to take based on the current state of
        # the rendezvous.
        # 调用到_RendezvousJoinOp,大家可以过一下 _RendezvousJoinOp 代码,发现此时将返回 ADD_TO_PARTICIPANTS
        action = state_handler(ctx, deadline) 

        if action == _Action.SYNC:
            # Delay the execution by one second to avoid overloading the
            # backend if we are asked to poll for state changes.
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_PARTICIPANTS: # 运行到这里
                self._add_to_participants()
            elif action == _Action.ADD_TO_WAIT_LIST:
                self._add_to_wait_list()
            elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                self._remove_from_participants()
            elif action == _Action.REMOVE_FROM_WAIT_LIST:
                self._remove_from_wait_list()
            elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                self._mark_rendezvous_complete()
            elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                self._mark_rendezvous_closed()

            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()

这次会生成 ADD_TO_PARTICIPANTS。

class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state # 从上下文之中提取 _RendezvousState 状态

        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED # 如果已经结束,就返回 _Action.ERROR_CLOSED

        is_participant = ctx.node in state.participants # 看看是参与者

        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant: # 如果是参与者且状态结束,就返回 _Action.FINISH
            return _Action.FINISH

        now = time.monotonic()
        if now > deadline: # 如果已经超时
            rollback_period = 5  # 5 seconds

            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period: # 如果还有时间来 rollback
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant: # 已经是参与者了
                    return _Action.REMOVE_FROM_PARTICIPANTS # 需要从参与者列表移除
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list: # 已经在等待列表之中
                    return _Action.REMOVE_FROM_WAIT_LIST # 需要从等待列表移除
            return _Action.ERROR_TIMEOUT # 返回超时

        if state.complete: # 如果 rendezvous 已经结束
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes: # 如果还没有达到最大节点数
                if ctx.node not in state.wait_list: # 如果当前node不在等待列表之中
                    return _Action.ADD_TO_WAIT_LIST # 就加入到等待列表,发送一个等待action
        elif is_participant: # 如果已经在参与者列表
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes: # 如果达到了最小节点数
                if cast(datetime, state.deadline) < datetime.utcnow(): # 如果达到了超时
                    return _Action.MARK_RENDEZVOUS_COMPLETE # 标示 rendezvous 已经结束
        else: # 否则就直接加入到参与者
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS

        if _should_keep_alive(ctx): # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
            return _Action.KEEP_ALIVE

        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC # 否则返回同步状态 _Action.SYNC
2.3.3.6 _add_to_participants

引擎收到 ADD_TO_PARTICIPANTS 之后,会调用 _add_to_participants 从 wait_list 移除节点,插入到 participants。

def _add_to_participants(self) -> None:
    log.debug(
        f"The node '{self._node}' added itself to the participants of round "
        f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
    )

    state = self._state
    state.wait_list.remove(self._node) # 移除节点

    # The ranks of the participants will be set once the rendezvous is
    # complete.
    state.participants[self._node] = 0 # 重新插入

    self._keep_alive()

    if len(state.participants) == self._settings.min_nodes:
        state.deadline = datetime.utcnow() + self._settings.timeout.last_call

    if len(state.participants) == self._settings.max_nodes:
        self._mark_rendezvous_complete()

我们这次从 _restart_workers 开始绘制。

  1. 调用 _stop_workers 来关闭worker子进程。此时下图右侧上方 _RendezvousState之中,participants=[1,2]。
  2. 通过 MultiprocessContext.close() 完成关闭操作。
  3. 通过 _initialize_workers 重新初始化 worker。
  4. 调用 next_rendezvous 完成新的同步操作。
  5. _RendezvousJoinOp 这次返回ADD_TO_PARTICIPANTS。
  6. 调用 _add_to_participants 进行状态切换。
  7. wait_list 之中的Node被移动到 participants。此时下图右侧上方 _RendezvousState之中,participants=[1,2,3]。
                         +-----------------------------+   +------------------------+  |
                         |  agent._invoke_run          |   | _RendezvousState       |  |
                         |                             |   |                        |  |
                         |       _restart_workers      |   |   participants = [1,2] |  |
                         |              +              |   |                        |  |
+----------------------+ |              |              |   |   wait_list = [3]      |  |
| MultiprocessContext  | |              | 1            |   |                        |  |
|                      | | 2            v              |   +------------------------+  |
|        close()  <-----------+  _stop_workers         |                               |
|                      | |              +              |                               |
+----------------------+ |              |              |                               |
                         |              | 3            |                               |
                         |              v              |                               |
                         |     _initialize_workers     |                               |
                         |              +              |                               |
                         |              |              |                               |
                         +-----------------------------+                               |
                                        |                                              |
                                        | 4                                            |
                                        v                                              |
                                 next_rendezvous                                       |
                                        +                                              |
                                        |                                              |
                                        v                                              |
            +---------------------------+---------------+                              |
            | _op_executor.run(_RendezvousJoinOp)       |                              |
            |           +               +               |                              |
            |           |               |               |                              |
            |           |               | 5             |                              |
            |           |               v               |                              |
            |           |       ADD_TO_PARTICIPANTS     |                              |
            |           |               +               |   +-----------------------+  |
            |           |               |               |   | _RendezvousState      |  |
            |           | <-------------+               |   |                       |  |
            |           |                               |   | participants = [1,2,3]|  |
            |           v     6                  7      |   |                       |  |
            |        _add_to_participants  +--------------> | wait_list = []        |  |
            |                                           |   |                       |  |
            +-------------------------------------------+   +-----------------------+  v

                                                                                 Timeline


0x03 节点离开

3.1 处理机制

节点离开(scale-down)的处理如下:

  • 当Scale down事件发生时,rendezvous将不会通知 torchelastic agent。
  • 如果TE agent以“max_restarts=0”启动,它依赖于底层调度程序来处理作业重新启动。
  • 如果“max_restarts>0”,TE代理将终止workers并开始新一轮rendezvous。
    • 代理得到离开的通知,于是现有workers(所有节点上)都全部停止。
    • 这些workers将形成一个新的“WorkerGroup”,所有worker都将以新的RANKWORLD_SIZE 运行。、

3.2 如何模拟

如果想模拟调试的同学,可以在 test/distributed/elastic/agent/server/test/local_elastic_agent_test.py 之中找到示例代码。

def test_double_agent_elastic(self):
    """
    start ``nnodes`` agents, kill odd ones (do not restart), validate
    elasticity (scale-down) works. (scale-up covered in fault_tolerance test)
    """
    min_nodes = 1
    max_nodes = 2
    wait = 2
    node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2)
    agent_results = mp.Queue()
    agent_args = {
        "conf": node_conf,
        "agent_results": agent_results,
        "min_nodes": min_nodes,
        "max_nodes": max_nodes,
        "max_restarts": 2,
    }

    procs = []
    for _ in range(max_nodes):
        p = mp.Process(
            target=self.run_agent,
            kwargs=agent_args,
        )
        procs.append(p)
        p.start()

    # kill odd agents
    for i in range(max_nodes):
        if i % 2 != 0:
            procs[i].kill()

    for i in range(max_nodes):
        p = procs[i]
        p.join()
        if i % 2 == 0:
            self.assertEqual(0, p.exitcode)
        else:
            self.assertEqual(-signal.SIGKILL, p.exitcode)

3.3 如何处理

节点离开,与错误处理是同一个代码。错误处理代码如下,如果重试尚未达到最大次数,则试图重启workers。如果已经达到了最大次数,则停止 workers。

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        
        # 省略
     
        while True:

            # 定期监控
            time.sleep(monitor_interval)
            # 监控客户程序运行情况
            run_result = self._monitor_workers(self._worker_group)
            
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            # 程序出错
            
            if self._remaining_restarts > 0: # 重试
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group) # 进行重启
            else:
                self._stop_workers(self._worker_group) # 重试次数达到,结束workers
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result

3.3.1 重启

_restart_workers 会停掉所有 workers,然后重新一轮 rendezvous 。

@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)

3.3.2 停止

停止 workers 就是关闭上下文。

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
        
@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()

在 MultiprocessContext 之中,close 方法是关闭所有子进程,然后等待其全部停止。

    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()
                proc.join()

流程图如下:

  1. 监控子进程状态。
  2. 发现 UNHEALTHY 或者 FAILED,看看重启次数是否还有。我们假定是3号进程失败。
  3. 如果没有,就调用 _stop_workers 结束子进程。
  4. 调用 MultiprocessContext.close 进行具体结束操作。
  5. 如果还可以重启,调用_restart_workers。
  6. 调用 _stop_workers 结束子进程。
  7. 调用 MultiprocessContext.close 进行具体结束操作。
  8. 调用 _initialize_workers 重新初始化worker。
  9. 调用 next_rendezvous 重新同步。
  10. 进行后续操作。
                                                                                 +
+-------------------------------------------+    +---------------------------+   |
| agent._invoke_run                         |    | _RendezvousState          |   |
|                                           |    |                           |   |
|                                           |    |                           |   |
|     _monitor_workers Every 30S            |    |    participants = [1,2,3] |   |
|             +                             |    |                           |   |
|             | 1                           |    |    wait_list = [ ]        |   |
|             |                             |    |                           |   |
|             v                             |    +---------------------------+   |
|     WorkerState.UNHEALTHY,FAILED          |                                    |
|             +                             |                                    |
|             |                             |                                    |
|             | 2                           |                                    |
|             v                             |                                    |
|   self._remaining_restarts > 0 ? +--+     |                                    |
|             +                       |     |                                    |
|          5  | YES                NO | 3   |                                    |
|             |                       |     |                                    |
|             v                       v     |    +----------------------+        |
|     _restart_workers        _stop_workers |    | MultiprocessContext  |        |
|             +                       +     |    |                      |        |
|             | 6                     |  4  |    |                      |        |
|             |                       +--------> |                      |        |
|             v                             |    |        close()       |        |
|      _stop_workers +-------------------------> |                      |        |
|             +                 7           |    +----------------------+        |
|             |                             |                                    |
|             | 8                           |                                    |
|             v                             |                                    |
|    _initialize_workers                    |                                    |
|             +                             |                                    |
|             |                             |                                    |
+-------------------------------------------+                                    |
              | 9                                                                |
              |                                                                  |
              v                                +--------------------------+      |
        next_rendezvous                        | _RendezvousState         |      |
              +                                |                          |      |
              |               10               |     participants = [1,2] |      |
              +---------------------------->   |                          |      |
              |                                |     wait_list = [ ]      |      v
              | 10                             +--------------------------+
              v                                                             Timeline

至此,弹性训练全部分析完毕,或者说PyTorch分布式分析就告一段落,我们下文会介绍其他框架/库的分布式实现,敬请期待。

0xFF 参考

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

[源码解析] PyTorch 分布式之弹性训练(3)---代理

[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

[源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章