[Source Code Analysis] TensorFlow Distributed: ClusterCoordinator

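This article looks at how ParameterServerStrategy dispatches computation, that is, how ClusterCoordinator works. It is the final installment of the TensorFlow distributed series.

Two GitHub repositories are worth recommending as excellent learning material.

https://github.com/yuhuiaws/ML-study

https://github.com/Jack47/hack-SysML

I also recommend the latest article by 西門宇少 on TeraPipe, which parallelizes the pipeline at the token level for Transformer LMs.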

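The other articles in this series are:

[Translation] TensorFlow Distributed Papers: "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed Papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session

[Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic

[Source Code Analysis] TensorFlow Distributed Environment (8) --- Communication Mechanism

[Translation] Distributed Training with TensorFlow

[Source Code Analysis] TensorFlow Distributed: DistributedStrategy Basics

[Source Code Analysis] TensorFlow: Distributed Variables

[Source Code Analysis] TensorFlow Distributed: MirroredStrategy

[Source Code Analysis] TensorFlow Distributed: MirroredStrategy Computation Dispatch

[Source Code Analysis] TensorFlow Distributed: ParameterServerStrategy V1

[Source Code Analysis] TensorFlow Distributed: ParameterServerStrategy V2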

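1. Approach

TensorFlow 2 recommends a central-coordinator architecture for parameter server training. Every worker and parameter server runs a tf.distribute.Server, and on top of that a coordinator task creates resources on the workers and parameter servers, dispatches functions, and coordinates training. The coordinator uses tf.distribute.experimental.coordinator.ClusterCoordinator to coordinate the cluster and tf.distribute.experimental.ParameterServerStrategy to define the variables on the parameter servers and the computation on the workers.

ClusterCoordinator is an object for scheduling and coordinating the execution of remote functions. It is used to create fault-tolerant resources and to dispatch functions to remote TensorFlow servers. The class cannot currently be used on its own; it must be paired with a tf.distribute strategy designed to work with it, and at present that is only tf.distribute.experimental.ParameterServerStrategy.

1.1 Usage

Once all computation has been defined with ParameterServerStrategy, the tf.distribute.experimental.coordinator.ClusterCoordinator class is used to create resources and distribute training steps to remote workers.

First, create a ClusterCoordinator object and pass in the strategy object.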

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

Next, create a per-worker dataset and an iterator. In the per_worker_dataset_fn below, wrapping dataset_fn in strategy.distribute_datasets_from_function is recommended, since it allows data to be prefetched to GPUs seamlessly and efficiently.

@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

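The last step is to distribute the computation to remote workers with ClusterCoordinator.schedule.

  • The schedule method puts a tf.function into a queue and immediately returns a future-like RemoteValue. Queued functions are dispatched to remote workers from background threads, and the RemoteValue is filled in asynchronously.
  • The join method (ClusterCoordinator.join) can be used to wait until all scheduled functions have executed.

@tf.function
def step_fn(iterator):
  return next(iterator)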

num_epoches = 4
steps_per_epoch = 5
for i in range(num_epoches):
  # `accuracy` is assumed to be a Keras metric created under strategy.scope().
  accuracy.reset_states()
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  # Wait at epoch boundaries.
  coordinator.join()
  print ("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))

The result of a RemoteValue can be retrieved as follows.

loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print ("Final loss is %f" % loss.fetch())

Alternatively, all steps can be launched up front and other work done while waiting for them to finish.

for _ in range(total_steps):
  coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
  time.sleep(10)
  # Do something like logging metrics or writing checkpoints.

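1.2 Open Questions

Based on the code above, the questions to answer are:

  • How does a worker know which devices to use?
  • How exactly is the user function executed?
  • How is the data obtained?

The rest of this article tries to answer them by reading the code.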

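2. Definition

The main ideas behind ClusterCoordinator are:

  • The coordinator is not one of the training workers. Instead, it creates resources such as variables and datasets, dispatches tf.function instances, saves checkpoints, and so on.
  • To make training proceed, the coordinator dispatches tf.function instances to be executed on remote workers.
  • On receiving a request from the coordinator, a worker executes the tf.function by reading variables from the parameter servers, running the ops, and updating the variables on the parameter servers.
  • Each worker only handles requests from the coordinator and communicates with the parameter servers; it does not interact directly with the other workers in the cluster.

ClusterCoordinator is defined as follows. As we can see, it mainly sets the _strategy member and creates the _cluster member.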

@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
    
  def __new__(cls, strategy):
    #  ClusterCoordinator  is kept as a single instance to a given  Strategy .
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super(
          ClusterCoordinator, cls).__new__(cls)
    return strategy._cluster_coordinator

  def __init__(self, strategy):
    """Initialization of a  ClusterCoordinator  instance.

    Args:
      strategy: a supported  tf.distribute.Strategy  object. Currently, only
         tf.distribute.experimental.ParameterServerStrategy  is supported.

    Raises:
      ValueError: if the strategy being used is not supported.
    """
    if not getattr(self, "_has_initialized", False):
      if not isinstance(strategy,
                        parameter_server_strategy_v2.ParameterServerStrategyV2):
        raise ValueError(
            "Only  tf.distribute.experimental.ParameterServerStrategy  "
            "is supported to work with "
            " tf.distribute.experimental.coordinator.ClusterCoordinator  "
            "currently.")
      self._strategy = strategy
      self.strategy.extended._used_with_coordinator = True
      self._cluster = Cluster(strategy)
      self._has_initialized = True

  def __del__(self):
    self._cluster.stop()

  @property
  def strategy(self):
    """Returns the  Strategy  associated with the  ClusterCoordinator ."""
    return self._strategy
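
The combination of __new__ and the _has_initialized guard above means that a strategy gets at most one coordinator, however often the constructor is called. The following is a minimal pure-Python sketch of that pattern only; ToyStrategy and ToyCoordinator are hypothetical names and do not exist in TensorFlow.

```python
class ToyStrategy:
    """Hypothetical stand-in for a strategy object; not a TensorFlow class."""

    def __init__(self):
        self._cluster_coordinator = None


class ToyCoordinator:
    """Keeps at most one coordinator instance per strategy."""

    init_calls = 0  # counts how many times real initialization ran

    def __new__(cls, strategy):
        # Reuse the instance cached on the strategy, if there is one.
        if strategy._cluster_coordinator is None:
            strategy._cluster_coordinator = super().__new__(cls)
        return strategy._cluster_coordinator

    def __init__(self, strategy):
        # __init__ runs on every ToyCoordinator(strategy) call, so guard it.
        if not getattr(self, "_has_initialized", False):
            self._strategy = strategy
            self._has_initialized = True
            ToyCoordinator.init_calls += 1


strategy = ToyStrategy()
first = ToyCoordinator(strategy)
second = ToyCoordinator(strategy)      # same strategy: same object comes back
other = ToyCoordinator(ToyStrategy())  # different strategy: new object
```

Without the guard in __init__, the second constructor call would re-run initialization on the shared instance, which in the real class would mean building a second Cluster.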

2.1 Schedule

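The most important API of a ClusterCoordinator object is schedule, which dispatches a tf.function to a worker for asynchronous execution:

  • The method is non-blocking: it puts fn into a queue and immediately returns a tf.distribute.experimental.coordinator.RemoteValue object. fn waits in the queue to be executed later.
  • Queued functions are dispatched to remote workers from background threads and run asynchronously; their RemoteValue objects are filled in asynchronously.
  • Because schedule does not take a worker assignment, the tf.function passed in may run on any available worker.
  • fetch can be called to wait for the function to finish and retrieve its output from the remote worker. Alternatively, tf.distribute.experimental.coordinator.ClusterCoordinator.join waits for all scheduled functions to finish.

Failure handling and fault tolerance work as follows:

  • Because a worker can fail at any point while executing a function, the function may be partially executed, but tf.distribute.experimental.coordinator.ClusterCoordinator guarantees that in such cases the function will eventually be executed on some available worker.
  • schedule guarantees that fn is executed on a worker at least once. Since function execution is not atomic, a function may be executed more than once if its worker fails midway.
  • If the worker running the function becomes unavailable before it finishes, the function is retried on another available worker.
  • If any previously scheduled function raised an error, schedule raises one of those errors and clears the errors collected so far. Users can call fetch on the returned tf.distribute.experimental.coordinator.RemoteValue objects to check whether they have executed, failed, or been cancelled, and reschedule the corresponding functions if needed. When schedule raises, it guarantees that no function is still executing.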
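
The at-least-once guarantee has a practical consequence: side effects of a partially executed function can be observed more than once. The sketch below illustrates this in plain Python under simplified assumptions. WorkerUnavailable, run_at_least_once and the dictionary-based workers are all hypothetical and are not how TensorFlow represents workers.

```python
class WorkerUnavailable(Exception):
    """Hypothetical error raised when a toy worker dies mid-function."""


def run_at_least_once(fn, workers):
    """Try fn on each worker in turn until one of them finishes."""
    for worker in workers:
        try:
            return fn(worker)
        except WorkerUnavailable:
            continue  # retry the whole function on the next worker
    raise RuntimeError("no worker available")


applied_updates = []  # side effects visible outside the function


def step(worker):
    applied_updates.append(worker["name"])  # first half of the step
    if worker["fails_midway"]:
        raise WorkerUnavailable(worker["name"])
    return "done on " + worker["name"]


workers = [
    {"name": "worker0", "fails_midway": True},
    {"name": "worker1", "fails_midway": False},
]
result = run_at_least_once(step, workers)
```

After the call, applied_updates holds two entries even though the step logically ran once, which is why a scheduled function should tolerate being re-executed.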

schedule is defined as follows. The data iterator is passed in, together with fn, as one of the arguments.

  def schedule(self, fn, args=None, kwargs=None):
    """Schedules  fn  to be dispatched to a worker for asynchronous execution.

    This method is non-blocking in that it queues the  fn  which will be
    executed later and returns a 
     tf.distribute.experimental.coordinator.RemoteValue  object immediately.
     fetch  can be called on it to wait for the function execution to finish
    and retrieve its output from a remote worker. On the other hand, call
     tf.distribute.experimental.coordinator.ClusterCoordinator.join  to wait for
    all scheduled functions to finish.

     schedule  guarantees that  fn  will be executed on a worker at least once;
    it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since worker can fail at any point when
    executing the function, it is possible that the function is partially
    executed, but  tf.distribute.experimental.coordinator.ClusterCoordinator 
    guarantees that in those events, the function will eventually be executed on
    any worker that is available.

    If any previously scheduled function raises an error,  schedule  will raise
    any one of those errors, and clear the errors collected so far. What happens
    here, some of the previously scheduled functions may have not been executed.
    User can call  fetch  on the returned
     tf.distribute.experimental.coordinator.RemoteValue  to inspect if they have
    executed, failed, or cancelled, and reschedule the corresponding function if
    needed.

    When  schedule  raises, it guarantees that there is no function that is
    still being executed.

    At this time, there is no support of worker assignment for function
    execution, or priority of the workers.

     args  and  kwargs  are the arguments passed into  fn , when  fn  is
    executed on a worker. They can be
     tf.distribute.experimental.coordinator.PerWorkerValues  and in this case,
    the argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
     tf.distribute.experimental.coordinator.PerWorkerValues  will be passed into
     fn  as-is. Currently,  tf.distribute.experimental.coordinator.RemoteValue 
    is not supported to be input  args  or  kwargs .

    Args:
      fn: A  tf.function ; the function to be dispatched to a worker for
        execution asynchronously. Regular python function is not supported to be
        scheduled.
      args: Positional arguments for  fn .
      kwargs: Keyword arguments for  fn .

    Returns:
      A  tf.distribute.experimental.coordinator.RemoteValue  object that
      represents the output of the function scheduled.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    if not isinstance(fn,
                      (def_function.Function, tf_function.ConcreteFunction)):
      raise TypeError(
          " tf.distribute.experimental.coordinator.ClusterCoordinator.schedule "
          " only accepts a  tf.function  or a concrete function.")
    # Slot variables are usually created during function tracing time; thus
    #  schedule  needs to be called within the  strategy.scope() .
    with self.strategy.scope():
      self.strategy.extended._being_scheduled = True  
      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
      self.strategy.extended._being_scheduled = False  
      return remote_value

2.2 Join

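The join method blocks until all scheduled functions have finished executing:

  • If any previously scheduled function raised an error, join fails by raising one of those errors and clears the errors collected so far. When this happens, some previously scheduled functions may not have been executed.
  • Users can call fetch on the returned tf.distribute.experimental.coordinator.RemoteValue objects to check whether they have executed, failed, or been cancelled.
  • If some cancelled functions need to be rescheduled, users should call schedule with them again.
  • When join returns or raises, it guarantees that no function is still executing.
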
  def join(self):
    """Blocks until all the scheduled functions have finished execution.

    If any previously scheduled function raises an error,  join  will fail by
    raising any one of those errors, and clear the errors collected so far. If
    this happens, some of the previously scheduled functions may have not been
    executed. Users can call  fetch  on the returned
     tf.distribute.experimental.coordinator.RemoteValue  to inspect if they have
    executed, failed, or cancelled. If some that have been cancelled need to be
    rescheduled, users should call  schedule  with the function again.

    When  join  returns or raises, it guarantees that there is no function that
    is still being executed.

    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    self._cluster.join()
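
The "raise one error, then clear" behaviour described in the docstring can be modelled with a small pure-Python sketch. ToyErrorTracker is a hypothetical name, and keeping only the first error is a simplifying assumption for illustration.

```python
import threading


class ToyErrorTracker:
    """Remembers the first error seen and hands it out exactly once."""

    def __init__(self):
        self._lock = threading.Lock()
        self._error = None

    def record(self, error):
        with self._lock:
            if self._error is None:
                self._error = error

    def raise_if_error(self):
        with self._lock:
            error, self._error = self._error, None  # clear before raising
        if error is not None:
            raise error


tracker = ToyErrorTracker()
tracker.record(ValueError("step 3 failed"))
tracker.record(KeyError("step 5 failed"))  # ignored: one error already pending

try:
    tracker.raise_if_error()  # plays the role of join() or schedule()
    first_call = "no error"
except ValueError as e:
    first_call = str(e)

tracker.raise_if_error()  # state was cleared, so nothing is raised here
second_call = "no error"
```

Clearing the error on the first raise is what lets a training loop catch the exception, reschedule the cancelled work, and keep going.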

2.3 Done

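The done method returns whether all dispatched functions have finished executing. If any previously dispatched function raised an error, done fails by raising one of those errors.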

  def done(self):
    """Returns whether all the scheduled functions have finished execution.

    If any previously scheduled function raises an error,  done  will fail by
    raising any one of those errors.

    When  done  returns True or raises, it guarantees that there is no function
    that is still being executed.

    Returns:
      Whether all the scheduled functions have finished execution.
    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    return self._cluster.done()

2.4 Fetch

fetch retrieves the results of remote values.

  def fetch(self, val):
    """Blocking call to fetch results from the remote values.

    This is a wrapper around
     tf.distribute.experimental.coordinator.RemoteValue.fetch  for a
     RemoteValue  structure; it returns the execution results of
     RemoteValue s. If not ready, wait for them while blocking the caller.

    Example:
    ```python
    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1
    ```

    Args:
      val: The value to fetch the results from. If this is structure of
         tf.distribute.experimental.coordinator.RemoteValue ,  fetch()  will be
        called on the individual
         tf.distribute.experimental.coordinator.RemoteValue  to get the result.

    Returns:
      If  val  is a  tf.distribute.experimental.coordinator.RemoteValue  or a
      structure of  tf.distribute.experimental.coordinator.RemoteValue s,
      return the fetched  tf.distribute.experimental.coordinator.RemoteValue 
      values immediately if they are available, or block the call until they are
      available, and return the fetched
       tf.distribute.experimental.coordinator.RemoteValue  values with the same
      structure. If  val  is other types, return it as-is.
    """

    def _maybe_fetch(val):
      if isinstance(val, RemoteValue):
        return val.fetch()
      else:
        return val

    return nest.map_structure(_maybe_fetch, val)
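
The idea behind nest.map_structure(_maybe_fetch, val) is to walk a nested structure, resolve every remote value in it, and leave everything else untouched. A rough standard-library analogue is sketched below, using concurrent.futures.Future as a stand-in for RemoteValue. fetch_structure is a hypothetical helper and handles only dicts, lists and tuples.

```python
from concurrent.futures import Future


def fetch_structure(val):
    """Recursively replace every Future in a nested structure by its result."""
    if isinstance(val, Future):
        return val.result()  # blocks until the value is available
    if isinstance(val, dict):
        return {k: fetch_structure(v) for k, v in val.items()}
    if isinstance(val, (list, tuple)):
        return type(val)(fetch_structure(v) for v in val)
    return val  # anything else is returned as-is


loss = Future()
loss.set_result(0.25)
accuracy = Future()
accuracy.set_result(0.9)

fetched = fetch_structure({"loss": loss, "metrics": [accuracy, "epoch-1"]})
```

The result keeps the shape of the input, with the two futures replaced by their values and the plain string passed through unchanged.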

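3. Data

Besides dispatching remote functions, ClusterCoordinator also helps create datasets on all workers and rebuilds them when a worker recovers from a failure. The datasets are created on the worker devices by having the user-supplied dataset_fn invoked there. An example: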

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)

@tf.function
def worker_fn(iterator):
  return next(iterator)

def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(
      lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

per_worker_dataset = coordinator.create_per_worker_dataset(
    per_worker_dataset_fn)
per_worker_iter = iter(per_worker_dataset)
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
assert remote_value.fetch() == 3

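3.1 Creating the Dataset

The code above uses create_per_worker_dataset to create datasets on the workers. These datasets are produced by dataset_fn, and the call returns an object representing the collection of them. Calling iter on that collection returns a tf.distribute.experimental.coordinator.PerWorkerValues, which is a collection of iterators that have already been placed on their respective workers.

Note that calling next directly on a PerWorkerValues of iterators is not supported. The iterator should instead be passed as an argument to tf.distribute.experimental.coordinator.ClusterCoordinator.schedule. When a scheduled function is about to run on a worker, it receives the individual iterator that belongs to that worker, and it can call next on that iterator.

Currently schedule assumes that all workers are identical, and therefore that the datasets on different workers are the same, except that they may be shuffled differently when they contain a dataset.shuffle operation and no random seed is set. For this reason it is recommended to repeat the dataset indefinitely and schedule a finite number of steps, rather than rely on the dataset's OutOfRangeError to end training.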

  def create_per_worker_dataset(self, dataset_fn):
    """Create dataset on workers by calling  dataset_fn  on worker devices.

    This creates the given dataset generated by dataset_fn on workers
    and returns an object that represents the collection of those individual
    datasets. Calling  iter  on such collection of datasets returns a
     tf.distribute.experimental.coordinator.PerWorkerValues , which is a
    collection of iterators, where the iterators have been placed on respective
    workers.

    Calling  next  on a  PerWorkerValues  of iterator is unsupported. The
    iterator is meant to be passed as an argument into
     tf.distribute.experimental.coordinator.ClusterCoordinator.schedule . When
    the scheduled function is about to be executed by a worker, the
    function will receive the individual iterator that corresponds to the
    worker. The  next  method can be called on an iterator inside a
    scheduled function when the iterator is an input of the function.

    Currently the  schedule  method assumes workers are all the same and thus
    assumes the datasets on different workers are the same, except they may be
    shuffled differently if they contain a  dataset.shuffle  operation and a
    random seed is not set. Because of this, we also recommend the datasets to
    be repeated indefinitely and schedule a finite number of steps instead of
    relying on the  OutOfRangeError  from a dataset.

    Args:
      dataset_fn: The dataset function that returns a dataset. This is to be
        executed on the workers.

    Returns:
      An object that represents the collection of those individual
      datasets.  iter  is expected to be called on this object that returns
      a  tf.distribute.experimental.coordinator.PerWorkerValues  of the
      iterators (that are on the workers).
    """
    return values_lib.get_per_worker_dataset(dataset_fn, self)

get_per_worker_dataset returns either a PerWorkerDatasetFromDataset or a PerWorkerDatasetFromDatasetFunction.

def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)

3.2 PerWorkerDatasetFromDataset

PerWorkerDatasetFromDataset represents the worker-distributed datasets created from a dataset.

class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""

  def __init__(self, dataset, coordinator):
    """Makes an iterable from datasets created by the given dataset.

    It creates a dataset_fn which deserializes a dataset from a graph under the
    hood.

    Args:
      dataset: A tf.data.Dataset, a DistributedDataset or a
        DistributedDatasetsFromFunction
      coordinator: a  ClusterCoordinator  object, used to create dataset
        resources.
    """
    if isinstance(dataset, input_lib.DistributedDataset):
      original_dataset = dataset._original_dataset
      serialized = serialize_dataset_to_graph(original_dataset)

      def dataset_fn():
        deserialized = deserialize_dataset_from_graph(
            serialized, original_dataset.element_spec)
        dataset.build(dataset_to_replace=deserialized)
        return dataset
      
    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
      def dataset_fn():
        dataset.build()
        return dataset
      
    elif isinstance(dataset, dataset_ops.Dataset):
      serialized = serialize_dataset_to_graph(dataset)

      def dataset_fn():
        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
      
    else:
      raise ValueError("Unexpected dataset type!")

    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)

3.3 PerWorkerDatasetFromDatasetFunction

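PerWorkerDatasetFromDatasetFunction represents the worker-distributed datasets created from a dataset function.

Its __iter__ method does the following:

  • Defines _create_per_worker_iterator, which calls the dataset function and returns iter(dataset).

  • Calls self._coordinator._create_per_worker_resources to create one iterator for each worker.

  • Returns a PerWorkerDistributedIterator.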

class PerWorkerDatasetFromDatasetFunction(object):
  """Represents worker-distributed datasets created from dataset function."""

  def __init__(self, dataset_fn, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a  Dataset .
      coordinator: a  ClusterCoordinator  object, used to create dataset
        resources.
    """

    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in  dataset_fn  is not allowed.")

    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._coordinator = coordinator
    self._element_spec = None

  def __iter__(self):
    # We would like users to create iterators outside  tf.function s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple
    # times, for the same object it should only create and register resource
    # once. Using object id to distinguish different iterator resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (
          input_lib.get_iterator_spec_from_dataset(
              self._coordinator.strategy, self._dataset_fn.structured_outputs))

    return PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset.

    This property is subject to change without notice.
    """
    return self._dataset_fn.structured_outputs.element_spec

3.4 _create_per_worker_resources

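_create_per_worker_resources calls create_resource on every worker, so that each worker builds its own copy of the resource (here, the dataset iterator).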

def _create_per_worker_resources(self, fn, args=None, kwargs=None):
  """Synchronously create resources on the workers.

  The resources are represented by
   tf.distribute.experimental.coordinator.RemoteValue s.

  Args:
    fn: The function to be dispatched to all workers for execution
      asynchronously.
    args: Positional arguments for  fn .
    kwargs: Keyword arguments for  fn .

  Returns:
    A  tf.distribute.experimental.coordinator.PerWorkerValues  object, which
    wraps a tuple of  tf.distribute.experimental.coordinator.RemoteValue 
    objects.
  """
  results = []
  for w in self._cluster.workers:
    results.append(w.create_resource(fn, args=args, kwargs=kwargs))  
  return PerWorkerValues(tuple(results))

3.5 PerWorkerValues

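PerWorkerValues is a container holding a list of values, one per worker. A tf.distribute.experimental.coordinator.PerWorkerValues contains a collection of values, each located on its corresponding worker. When it is used as one of the args or kwargs of tf.distribute.experimental.coordinator.ClusterCoordinator.schedule(), the value specific to a worker is passed into the function executed on that worker.

The only supported way to create a tf.distribute.experimental.coordinator.PerWorkerValues object is to call iter on the distributed dataset instance returned by ClusterCoordinator.create_per_worker_dataset. A mechanism for creating a custom tf.distribute.experimental.coordinator.PerWorkerValues is not yet supported.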

@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
  """A container that holds a list of values, one value per worker.

   tf.distribute.experimental.coordinator.PerWorkerValues  contains a collection
  of values, where each of the values is located on its corresponding worker,
  and upon being used as one of the  args  or  kwargs  of
   tf.distribute.experimental.coordinator.ClusterCoordinator.schedule() , the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.

  Currently, the only supported path to create an object of
   tf.distribute.experimental.coordinator.PerWorkerValues  is through calling
   iter  on a  ClusterCoordinator.create_per_worker_dataset -returned
  distributed dataset instance. The mechanism to create a custom
   tf.distribute.experimental.coordinator.PerWorkerValues  is not yet supported.
  """

  def __init__(self, values):
    for v in values:
      if not isinstance(v, RemoteValue):
        raise AssertionError(
            " PerWorkerValues  should only take  RemoteValue s.")
    self._values = tuple(values)

  @property
  def _type_spec(self):
    return PerWorkerValuesTypeSpec(
        self._values[0]._type_spec,  
        type(self))
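
The per-worker substitution described above (a PerWorkerValues argument is replaced by the component belonging to the target worker, and other arguments pass through unchanged) can be sketched in plain Python. ToyPerWorkerValues and select_worker_slice are hypothetical names for illustration. The sketch handles only a flat tuple of args or a flat dict of kwargs, not arbitrary nesting.

```python
class ToyPerWorkerValues:
    """Hypothetical container: one value per worker, indexed by worker id."""

    def __init__(self, values):
        self._values = tuple(values)


def select_worker_slice(worker_id, structure):
    """Replace each ToyPerWorkerValues in a flat args tuple or kwargs dict."""

    def select(x):
        return x._values[worker_id] if isinstance(x, ToyPerWorkerValues) else x

    if isinstance(structure, dict):
        return {k: select(v) for k, v in structure.items()}
    return tuple(select(v) for v in structure)


iterators = ToyPerWorkerValues(["iter@worker0", "iter@worker1"])
args = (iterators, 0.01)  # the second argument is an ordinary value

args_for_worker1 = select_worker_slice(1, args)
kwargs_for_worker0 = select_worker_slice(0, {"it": iterators})
```

This is why the same scheduled call can run on any worker: the choice of iterator is deferred until the target worker is known.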

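To summarize the data path: create_per_worker_dataset wraps dataset_fn, calling iter on the result goes through _create_per_worker_resources to build one iterator per worker, and the resulting PerWorkerValues is passed to schedule, which hands each worker its own iterator.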

4. Cluster

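Cluster is what actually carries out the work.

4.1 Definition

Cluster is a cluster of workers. Its initializer does the following:

  • Configures how transient parameter server errors are ignored.
  • Sets the device names of the workers.
  • Creates the Worker objects.

The point to note here is how failures reported because of transient connection errors are ignored.

  • Transient connection problems between a worker and a parameter server are relayed by the worker to the coordinator, which leads the coordinator to believe that a parameter server has failed.
  • What distinguishes a transient from a permanent parameter server failure is the number of reports from workers. When the environment variable TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES (default 3 in the code below) is set to a positive integer K, the coordinator ignores up to K failure reports. Only when more than K execution errors are attributed to the same parameter server instance is that instance considered to have failed.
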
class Cluster(object):
  """A cluster with workers.

  We assume all function errors are fatal and based on this assumption our
  error reporting logic is:
  1) Both  schedule  and  join  can raise a non-retryable error which is the
  first error seen by the coordinator from any previously scheduled functions.
  2) When an error is raised, there is no guarantee on how many previously
  scheduled functions have been executed; functions that have not been executed
  will be thrown away and marked as cancelled.
  3) After an error is raised, the internal state of error will be cleared.
  I.e. functions can continue to be scheduled and subsequent calls of  schedule 
  or  join  will not raise the same error again.

  Attributes:
    failure_handler: The failure handler used to handler worker preemption
      failure.
    workers: a list of  Worker  objects in the cluster.
  """

  def __init__(self, strategy):
    """Initializes the cluster instance."""

    self._num_workers = strategy._num_workers
    self._num_ps = strategy._num_ps

    # How many transient parameter server errors to ignore
    self._transient_ps_failures_threshold = int(
        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
    self._potential_ps_failures_lock = threading.Lock()
    self._potential_ps_failures_count = [0] * self._num_ps

    self._closure_queue = _CoordinatedClosureQueue()
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
                                                   self)
    
    # Set the device names of the workers
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    # Create the Worker objects
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]
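
The threshold idea for transient parameter server failures can be sketched with a per-task counter guarded by a lock, mirroring the _potential_ps_failures_lock and _potential_ps_failures_count fields above. ToyPsFailureFilter and should_ignore are hypothetical names. The exact comparison follows the prose description (ignore up to K reports) and is an assumption, since the real method is not shown here.

```python
import threading


class ToyPsFailureFilter:
    """Counts failure reports per parameter server and ignores the first K."""

    def __init__(self, num_ps, threshold):
        self._threshold = threshold
        self._lock = threading.Lock()
        self._counts = [0] * num_ps  # one counter per parameter-server task

    def should_ignore(self, ps_task):
        """Returns True while the report still looks like a transient blip."""
        if self._threshold <= 0:
            return False  # filtering disabled: every report is fatal
        with self._lock:
            self._counts[ps_task] += 1
            return self._counts[ps_task] <= self._threshold


ps_filter = ToyPsFailureFilter(num_ps=2, threshold=3)
reports_for_ps0 = [ps_filter.should_ignore(0) for _ in range(4)]
first_report_for_ps1 = ps_filter.should_ignore(1)
```

Counting per task matters: several workers complaining about different parameter servers once each is treated as noise, while repeated complaints about the same one are not.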

4.2 Schedule

The most important API of this class is the schedule/join pair. schedule is non-blocking: it puts a tf.function into the queue and immediately returns a RemoteValue.

  def schedule(self, function, args, kwargs):
    """Schedules  function  to be dispatched to a worker for execution.

    Args:
      function: The function to be dispatched to a worker for execution
        asynchronously.
      args: Positional arguments for  fn .
      kwargs: Keyword arguments for  fn .

    Returns:
      A  RemoteValue  object.
    """
    closure = Closure(
        function,
        self._closure_queue._cancellation_mgr, 
        args=args,
        kwargs=kwargs)
    self._closure_queue.put(closure)
    return closure.output_remote_value

  def join(self):
    """Blocks until all scheduled functions are executed."""
    self._closure_queue.wait()
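
The schedule/join pair can be modelled with the standard library alone: a queue of closures, one draining thread per worker, and a future returned to the caller. This is only a toy model of the control flow. ToyCluster is a hypothetical class, concurrent.futures.Future stands in for RemoteValue, and the cancellation and failure handling of the real Cluster are left out.

```python
import queue
import threading
from concurrent.futures import Future


class ToyCluster:
    """Toy model: closures go into a queue, one thread per worker drains it."""

    def __init__(self, num_workers):
        self._queue = queue.Queue()
        for i in range(num_workers):
            threading.Thread(
                target=self._worker_loop, args=(i,), daemon=True).start()

    def schedule(self, fn, args=()):
        remote_value = Future()  # stand-in for RemoteValue
        self._queue.put((fn, args, remote_value))
        return remote_value  # returned immediately, filled in later

    def join(self):
        self._queue.join()  # blocks until every queued closure was processed

    def _worker_loop(self, worker_index):
        while True:
            fn, args, remote_value = self._queue.get()
            try:
                remote_value.set_result(fn(*args))
            except Exception as e:  # deliver the failure through the future
                remote_value.set_exception(e)
            finally:
                self._queue.task_done()


cluster = ToyCluster(num_workers=2)
results = [cluster.schedule(pow, args=(2, n)) for n in range(5)]
cluster.join()
values = [r.result() for r in results]
```

Which worker thread picks up which closure is not determined by the caller, which matches the statement that schedule does not support worker assignment.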

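In the overall flow, the dataset iterator is passed in along with the function. The Queue here is queue.Queue, imported with from six.moves import queue, which we will meet again in _CoordinatedClosureQueue.

In terms of the diagram in the official documentation, what has been covered so far is the circled part on the left.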

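4.3 Stopping

The stop code is shown below. It stops the failure handler and every worker, and then delegates to the closure queue's own stop method; done likewise delegates to the queue.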

  def stop(self):
    """Stop worker, worker preemption threads, and the closure queue."""
    self.failure_handler.stop()

    for worker in self.workers:
      worker.stop()
    self._closure_queue.stop()

  def done(self):
    """Returns true if all scheduled functions are executed."""
    return self._closure_queue.done()


5. The Task: Closure

Closure wraps a task (the function plus its arguments) and provides some additional functionality.

class Closure(object):
  """Hold a function to be scheduled and its arguments."""

  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    if not callable(function):
      raise ValueError("Function passed to  ClusterCoordinator.schedule  must "
                       "be a callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}

    _disallow_remote_value_as_input(self._args)
    _disallow_remote_value_as_input(self._kwargs)

    if isinstance(function, def_function.Function):
      replica_args = _select_worker_slice(0, self._args)
      replica_kwargs = _select_worker_slice(0, self._kwargs)

      # Note: no need to handle function registration failure since this kind of
      # failure will not raise exceptions as designed in the runtime. The
      # coordinator has to rely on subsequent operations that raise to catch
      # function registration failure.

      # Record the function tracing overhead. Note that we pass in the tracing
      # count of the def_function.Function as a state tracker, so that metrics
      # will only record the time for actual function tracing (i.e., excluding
      # function cache lookups).
      with metric_utils.monitored_timer(
          "function_tracing", state_tracker=function._get_tracing_count):  
        self._concrete_function = function.get_concrete_function(
            *nest.map_structure(_maybe_as_type_spec, replica_args),
            **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
    elif isinstance(function, tf_function.ConcreteFunction):
      self._concrete_function = function

    if hasattr(self, "_concrete_function"):
      # If we have a concrete function, we get to retrieve the output type spec
      # via the structured_output.
      output_type_spec = func_graph.convert_structure_to_signature(
          self._concrete_function.structured_outputs)
      self._function = cancellation_mgr.get_cancelable_function(
          self._concrete_function)
    else:
      # Otherwise (i.e. what is passed in is a regular python function), we have
      # no such information.
      output_type_spec = None
      self._function = function

    self.output_remote_value = RemoteValueImpl(self, output_type_spec)

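5.1 Execution

Closure's execute_on method does the actual running: it executes self._function, that is, the user-defined function, on the specified device. Note that the call is made inside with context.executor_scope(worker.executor), so it runs under that worker's executor.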

  def execute_on(self, worker):
    """Executes the closure on the given worker.

    Args:
      worker: a Worker object.
    """
    replica_args = _select_worker_slice(worker.worker_index, self._args)
    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)

    e = (
        _maybe_rebuild_remote_values(worker, replica_args) or
        _maybe_rebuild_remote_values(worker, replica_kwargs))
    if e:
      if not isinstance(e, InputError):
        e = InputError(e)
      self.output_remote_value._set_error(e) 
      return

    with ops.device(worker.device_name):  # run on the specified device
      with context.executor_scope(worker.executor):  # under this worker's executor
        with metric_utils.monitored_timer("closure_execution"):
          output_values = self._function(  # call the user's function
              *nest.map_structure(_maybe_get_remote_value, replica_args),
              **nest.map_structure(_maybe_get_remote_value, replica_kwargs))
    self.output_remote_value._set_values(output_values) 

self._function is the user-defined function. Here is another example of such a function; it shows that strategy.run can be used to dispatch the training step to a remote worker.

@tf.function
def worker_fn(iterator):

  def replica_fn(inputs):
    batch_data, labels = inputs
    # calculate gradient, applying gradient, metrics update etc.

  strategy.run(replica_fn, args=(next(iterator),))
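
A closure reports its outcome through its output RemoteValue: execute_on calls _set_values on success or _set_error on an input failure, and a cancelled closure is reported through _set_error as well. The sketch below models such a future-like holder with a threading.Event. ToyRemoteValue and ToyCancelledError are hypothetical names, and the real RemoteValueImpl does more than this.

```python
import threading


class ToyCancelledError(Exception):
    """Hypothetical stand-in for errors.CancelledError."""


class ToyRemoteValue:
    """Future-like holder: fetch() blocks until a value or an error is set."""

    def __init__(self):
        self._ready = threading.Event()
        self._value = None
        self._error = None

    def _set_values(self, value):
        self._value = value
        self._ready.set()

    def _set_error(self, error):
        self._error = error
        self._ready.set()

    def fetch(self):
        self._ready.wait()  # returns at once if already set
        if self._error is not None:
            raise self._error
        return self._value


finished = ToyRemoteValue()
finished._set_values(42)

cancelled = ToyRemoteValue()
cancelled._set_error(ToyCancelledError("Please reschedule the function."))

try:
    cancelled.fetch()
    outcome = "ok"
except ToyCancelledError:
    outcome = "cancelled"
```

Because the error travels inside the value, the caller learns about a failed or cancelled step only when it fetches, which is how users can inspect individual results after schedule or join has raised.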

5.2 Cancellation

A Closure can be marked as cancelled, which simply records a CancelledError in its output RemoteValue.

  def mark_cancelled(self):
    self.output_remote_value._set_error(  
        errors.CancelledError(
            None, None, "The corresponding function is "
            "cancelled. Please reschedule the function."))

5.3 ResourceClosure

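ResourceClosure is a subclass of Closure that builds the RemoteValue wrapping the closure on demand. In practice, per-worker resources such as iterators go through ResourceClosure (for example via the create_resource call seen in section 3.4), whereas Cluster.schedule above builds a plain Closure.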

class ResourceClosure(Closure):

  def build_output_remote_value(self):
    if self._output_remote_value_ref is None:
      # We need to remember the Closure object in the  RemoteValue  here.
      ret = RemoteValueImpl(self, self._output_type_spec)
      self._output_remote_value_ref = weakref.ref(ret)
      return ret
    else:
      return self._output_remote_value_ref()

6. Queue

_CoordinatedClosureQueue is the queue that holds the pending closures.

6.1 Definition

from six.moves import queue

class _CoordinatedClosureQueue(object):
  """Manage a queue of closures, inflight count and errors from execution.

  This class is thread-safe.
  """

  def __init__(self):
    #  self._inflight_closure_count  only tracks the number of inflight closures
    # that are "in generation". Once an error occurs, error generation is
    # incremented and all subsequent arriving closures (from inflight) are
    # considered "out of generation".
    self._inflight_closure_count = 0

    self._queue_lock = threading.Lock()

    # Condition indicating that all pending closures (either queued or inflight)
    # have been processed, failed, or cancelled.
    self._stop_waiting_condition = threading.Condition(self._queue_lock)

    # Condition indicating that an item becomes available in queue (not empty).
    self._closures_queued_condition = threading.Condition(self._queue_lock)
    self._should_process_closures = True

    # Condition indicating that a queue slot becomes available (not full).
    # Note that even with "infinite" queue size, there is still a "practical"
    # size limit for the queue depending on host memory capacity, and thus the
    # queue will eventually become full with a lot of enqueued closures.
    self._queue_free_slot_condition = threading.Condition(self._queue_lock)

    # Condition indicating there is no inflight closures.
    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)

    # Use to cancel in-flight closures.
    self._cancellation_mgr = cancellation.CancellationManager()

    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None

    # The following is a lock to make sure when  wait  is called and before it
    # returns no  put  can be executed during this period. It is because  wait 
    # won't know what to do with newly put closures. This lock adds an cutoff
    # for  wait  so that closures put into the queue while waiting would not be
    # taken responsible by this  wait .
    #
    # We cannot reuse the  self._queue_lock  since when  wait  waits for a
    # condition, the  self._queue_lock  will be released.
    #
    # We don't use a reader/writer's lock on purpose to reduce the complexity
    # of the code.
    self._put_wait_lock = threading.Lock()


6.2 Put and get

The put and get methods insert closures into the queue and take them out, respectively.

  def put(self, closure):
    """Put a closure into the queue for later execution.

    If  mark_failed  was called before  put , the error from the first
    invocation of  mark_failed  will be raised.

    Args:
      closure: The  Closure  to put into the queue.
    """
    with self._put_wait_lock, self._queue_lock:
      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
      self._queue.put(closure, block=False)
      self._raise_if_error()
      self._closures_queued_condition.notify()

  def get(self, timeout=None):
    """Return a closure from the queue to be executed."""
    with self._queue_lock:
      while self._queue.empty() and self._should_process_closures:
        if not self._closures_queued_condition.wait(timeout=timeout):
          return None
      if not self._should_process_closures:
        return None
      closure = self._queue.get(block=False)
      self._queue_free_slot_condition.notify()
      self._inflight_closure_count += 1
      return closure

put_back puts a closure that was not properly executed back into the queue.

  def put_back(self, closure):
    """Put the closure back into the queue as it was not properly executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to put_back.")
      if self._error:
        closure.mark_cancelled()
      else:
        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
        self._queue.put(closure, block=False)
        self._closures_queued_condition.notify()
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()
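
The bookkeeping behind put, get, and put_back can be sketched with a much smaller class. MiniClosureQueue below is a hypothetical, simplified model: it keeps the lock, the "queued" condition, and the inflight counter, and drops error handling, capacity limits, and cancellation. The mark_finished method is assumed to simply decrement the inflight counter.

```python
import queue
import threading


class MiniClosureQueue:
    """Hypothetical, simplified model of _CoordinatedClosureQueue bookkeeping."""

    def __init__(self):
        self._lock = threading.Lock()
        self._queued = threading.Condition(self._lock)
        self._queue = queue.Queue()
        self.inflight = 0

    def put(self, item):
        with self._lock:
            self._queue.put(item, block=False)
            self._queued.notify()

    def get(self, timeout=None):
        with self._lock:
            while self._queue.empty():
                if not self._queued.wait(timeout=timeout):
                    return None  # Timed out with nothing to run.
            item = self._queue.get(block=False)
            self.inflight += 1  # The item is now owned by a worker.
            return item

    def put_back(self, item):
        with self._lock:
            self._queue.put(item, block=False)
            self._queued.notify()
            self.inflight -= 1  # The worker gave the item up unfinished.

    def mark_finished(self):
        with self._lock:
            self.inflight -= 1


q = MiniClosureQueue()
q.put("step-1")
item = q.get(timeout=0.1)     # inflight becomes 1
q.put_back(item)              # worker failed: item is queued again, inflight is 0
retried = q.get(timeout=0.1)  # another worker picks it up
q.mark_finished()
print(retried, q.inflight)
```

The point of the counter is that an item is always accounted for: it is either in the queue or counted as inflight, which is what wait relies on.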

6.3 Waiting

The wait method blocks until all closures have finished.

  def wait(self, timeout=None):
    """Wait for all closures to be finished before returning.

    If  mark_failed  was called before or during  wait , the error from the
    first invocation of  mark_failed  will be raised.

    Args:
      timeout: A float specifying a timeout for the wait in seconds.

    Returns:
      True unless the given timeout expired, in which case it returns False.
    """
    with self._put_wait_lock, self._queue_lock:
      while (not self._error and
             (not self._queue.empty() or self._inflight_closure_count > 0)):
        if not self._stop_waiting_condition.wait(timeout=timeout):
          return False
      self._raise_if_error()
      return True
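
The return-value contract of wait (False on timeout, True once nothing is pending) follows directly from threading.Condition.wait. The sketch below isolates that pattern; PendingTracker is a hypothetical class, not part of TensorFlow.

```python
import threading


class PendingTracker:
    """Hypothetical tracker: wait() blocks until no work is pending."""

    def __init__(self):
        self._lock = threading.Lock()
        self._stop_waiting = threading.Condition(self._lock)
        self._pending = 0

    def start(self):
        with self._lock:
            self._pending += 1

    def finish(self):
        with self._lock:
            self._pending -= 1
            if self._pending == 0:
                self._stop_waiting.notify_all()

    def wait(self, timeout=None):
        with self._lock:
            while self._pending > 0:
                if not self._stop_waiting.wait(timeout=timeout):
                    return False  # Timed out while work was still pending.
            return True


tracker = PendingTracker()
tracker.start()
timed_out_result = tracker.wait(timeout=0.05)  # work is still pending

worker = threading.Thread(target=tracker.finish)
worker.start()
finished_result = tracker.wait(timeout=5)      # returns once finish() has run
worker.join()
print(timed_out_result, finished_result)
```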

6.4 Errors and completion

mark_failed and done work together to handle errors and completion.

  def mark_failed(self, e):
    """Sets error and unblocks any wait() call."""
    with self._queue_lock:
      # TODO(yuefengz): maybe record all failure and give users more
      # information?
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to mark_failed.")
      if self._error is None:
        self._error = e
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()
      self._stop_waiting_condition.notifyAll()

  def done(self):
    """Returns true if the queue is empty and there is no inflight closure.

    If  mark_failed  was called before  done , the error from the first
    invocation of  mark_failed  will be raised.
    """
    with self._queue_lock:
      self._raise_if_error()
      return self._queue.empty() and self._inflight_closure_count == 0


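6.5 Stopping

stop tells the queue to stop handing out closures, and _cancel_all_closures cancels the closures that remain. _raise_if_error calls the latter before re-raising a recorded error.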

  def stop(self):
    with self._queue_lock:
      self._should_process_closures = False
      self._closures_queued_condition.notifyAll()

  def _cancel_all_closures(self):
    """Clears the queue and sets remaining closures cancelled error.

    This method expects self._queue_lock to be held prior to entry.
    """
    self._cancellation_mgr.start_cancel()
    while self._inflight_closure_count > 0:
      self._no_inflight_closure_condition.wait()
    while True:
      try:
        closure = self._queue.get(block=False)
        self._queue_free_slot_condition.notify()
        closure.mark_cancelled()
      except queue.Empty:
        break
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
    # Note on thread-safety: this is triggered when one of theses
    # ClusterCoordinator APIs are called:  schedule ,  wait , and  done . At the
    # same time, no new closures can be constructed (which reads the
    # _cancellation_mgr to get cancellable functions).
    self._cancellation_mgr = cancellation.CancellationManager()

  def _raise_if_error(self):
    """Raises the error if one exists.

    If an error exists, cancel the closures in queue, raises it, and clear
    the error.

    This method expects self._queue_lock to be held prior to entry.
    """
    if self._error:
      logging.error("Start cancelling closures due to error %r: %s",
                    self._error, self._error)
      self._cancel_all_closures()
      try:
        raise self._error  
      finally:
        self._error = None
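
The try/finally in _raise_if_error is what makes an error surface exactly once. A minimal sketch of that "raise, then clear" pattern, using a hypothetical ErrorLatch class and omitting the cancellation step:

```python
class ErrorLatch:
    """Hypothetical holder that reports the first recorded error exactly once."""

    def __init__(self):
        self._error = None

    def mark_failed(self, error):
        # Only the first error is kept, as in mark_failed above.
        if self._error is None:
            self._error = error

    def raise_if_error(self):
        if self._error:
            try:
                raise self._error
            finally:
                # Runs even though the raise propagates, so the state is reset.
                self._error = None


latch = ErrorLatch()
latch.mark_failed(ValueError("first"))
latch.mark_failed(KeyError("second"))  # ignored, an error is already recorded

try:
    latch.raise_if_error()
except ValueError as e:
    reported = str(e)

latch.raise_if_error()  # nothing is pending any more, so nothing is raised
print(reported)
```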

7. Worker

A Worker is the executor of functions.

7.1 Definition

Worker is defined as follows. It starts a thread that runs _process_queue.

class Worker(object):
  """A worker in a cluster.

  Attributes:
    worker_index: The index of the worker in the cluster.
    device_name: The device string of the worker, e.g. "/job:worker/task:1".
    executor: The worker's executor for remote function execution.
    failure_handler: The failure handler used to handler worker preemption
      failure.
  """

  def __init__(self, worker_index, device_name, cluster):
    self.worker_index = worker_index
    self.device_name = device_name
    # an executor is created here
    self.executor = executor.new_executor(enable_async=False)
    self.failure_handler = cluster.failure_handler
    self._cluster = cluster
    self._resource_remote_value_refs = []
    self._should_worker_thread_run = True

    # Worker threads need to start after  Worker 's initialization.
    threading.Thread(target=self._process_queue,
                     name="WorkerClosureProcessingLoop-%d" % self.worker_index,
                     daemon=True).start()

new_executor calls TFE_NewExecutor.

def new_executor(enable_async):
  handle = pywrap_tfe.TFE_NewExecutor(enable_async)
  return Executor(handle)

TFE_NewExecutor is defined in tensorflow/c/eager/c_api_experimental.cc and creates a TFE_Executor.

TFE_Executor* TFE_NewExecutor(bool is_async) {
  return new TFE_Executor(is_async);
}

TFE_Executor is defined as follows. The Executor class is the abstraction of a session executor; TF2 also has EagerExecutor, which is what TFE_Executor wraps.

struct TFE_Executor {
  explicit TFE_Executor(bool async)
      : owned_executor(new tensorflow::EagerExecutor(async)) {}

  explicit TFE_Executor(tensorflow::EagerExecutor* executor)
      : owned_executor(nullptr), unowned_executor(executor) {}

  tensorflow::EagerExecutor* executor() {
    return owned_executor == nullptr ? unowned_executor : owned_executor.get();
  }

  std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
  tensorflow::EagerExecutor* unowned_executor;
};

7.2 Processing

The _process_queue method takes Closures out of the queue and runs them.

  def _process_queue(self):
    """Function running in a worker thread to process closure queues."""
    self._maybe_delay()
    while self._should_worker_thread_run:
      closure = self._cluster._closure_queue.get()  
      if not self._should_worker_thread_run or closure is None:
        return
      self._process_closure(closure)
      # To properly stop the worker and preemption threads, it is important that
      #  ClusterCoordinator  object is not held onto so its  __del__  can be
      # called. By removing the reference to the  closure  that has already been
      # processed, we ensure that the  closure  object is released, while
      # getting the next  closure  at above  self._cluster._closure_queue.get() 
      # call.
      del closure

7.2.1 Delay

_process_queue first calls _maybe_delay, which sleeps before scheduling starts if the corresponding environment variables are set.

  def _maybe_delay(self):
    """Delay if corresponding env vars are set."""
    # If the following two env vars variables are set. Scheduling for workers
    # will start in a staggered manner. Worker i will wait for
    #  TF_COORDINATOR_SCHEDULE_START_DELAY  * i seconds, not exceeding
    #  TF_COORDINATOR_SCHEDULE_START_DELAY_MAX .
    delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
    delay_cap = int(
        os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
    if delay_cap:
      delay_secs = min(delay_secs * self.worker_index, delay_cap)
    if delay_secs > 0:
      logging.info("Worker %d sleeping for %d seconds before running function",
                   self.worker_index, delay_secs)
    time.sleep(delay_secs)
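
The delay arithmetic is easy to check in isolation. The helper below, staggered_delay, is a hypothetical function that mirrors only the computation in _maybe_delay, without reading environment variables or sleeping.

```python
def staggered_delay(worker_index, delay_secs, delay_cap):
    """Mirrors the arithmetic in _maybe_delay for the given settings."""
    if delay_cap:
        delay_secs = min(delay_secs * worker_index, delay_cap)
    return delay_secs


# With a 10 second step and a 25 second cap, workers 0..3 wait:
delays = [staggered_delay(i, 10, 25) for i in range(4)]
print(delays)  # [0, 10, 20, 25]

# Without a cap, the multiplication is skipped and every worker waits delay_secs.
print(staggered_delay(3, 10, 0))  # 10
```

Note that, as the code is written, the per-worker multiplication only happens when the cap variable is set to a non-zero value.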

7.2.2 Processing a task

_process_queue then calls _process_closure to run the closure.

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)
        with metric_utils.monitored_timer("remote_value_fetch"):
          # Copy the remote tensor to local (the coordinator) in case worker
          # becomes unavailable at a later time.
          closure.output_remote_value.get()
        self._cluster._closure_queue.mark_finished()  
    except Exception as e:  
      # Avoid logging the derived cancellation error
      if not isinstance(e, errors.CancelledError):
        logging.error(
            "/job:worker/task:%d encountered the following error when "
            "processing closure: %r:%s", self.worker_index, e, e)
      closure.output_remote_value._set_error(e)  
      self._cluster._closure_queue.mark_failed(e)  
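
The overall shape of the worker thread (take an item, run it, record the result or the error) can be sketched without TensorFlow. Everything below is hypothetical: a plain queue.Queue replaces the coordinated queue, a None sentinel replaces the stop flag, and, unlike the real coordinator, the loop keeps processing later tasks after a failure instead of cancelling them.

```python
import queue
import threading

results = []
errors = []


def process_queue(work_queue):
    """Hypothetical worker loop: run callables until a None sentinel arrives."""
    while True:
        task = work_queue.get()
        if task is None:
            return
        try:
            results.append(task())
        except Exception as e:  # Record the failure instead of killing the thread.
            errors.append(e)


work_queue = queue.Queue()
thread = threading.Thread(target=process_queue, args=(work_queue,), daemon=True)
thread.start()

work_queue.put(lambda: 1 + 1)
work_queue.put(lambda: 1 / 0)   # raises ZeroDivisionError inside the worker
work_queue.put(lambda: "done")
work_queue.put(None)            # ask the worker to stop
thread.join(timeout=5)

print(results, [type(e).__name__ for e in errors])
```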


7.3 Data

Next, let's see how data reading is placed on the workers. As mentioned earlier, _create_per_worker_resources calls create_resource to build each worker's own resources.

  def create_resource(self, function, args=None, kwargs=None):
    """Synchronously creates a per-worker resource represented by a  RemoteValue .

    Args:
      function: the resource function to be run remotely. It should be a
         tf.function , a concrete function or a Python function.
      args: positional arguments to be passed to the function.
      kwargs: keyword arguments to be passed to the function.

    Returns:
      one or several RemoteValue objects depending on the function return
      values.
    """
    # Some notes about the concurrency: currently all the activities related to
    # the same worker such as creating resources, setting resources' aborted
    # status, and executing closures happen on the same thread. This allows us
    # to have simpler logic of concurrency.
    closure = ResourceClosure(
        function,
        self._cluster.closure_queue._cancellation_mgr,  
        args=args,
        kwargs=kwargs)
    resource_remote_value = closure.build_output_remote_value()
    self._register_resource(resource_remote_value)

    # The following is a short-term solution to lazily create resources in
    # parallel.
    resource_remote_value._set_aborted() 
    return resource_remote_value

_register_resource then registers each worker's resources on the Worker.

def _register_resource(self, resource_remote_value):
  if not isinstance(resource_remote_value, RemoteValue):
    raise ValueError("Resource being registered is not of type "
                     " tf.distribute.experimental.coordinator.RemoteValue .")
  self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))

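The logic is as follows, with dashed lines in the diagram denoting data flow. The user puts Closures into the queue through the put method, and Workers take Closures from the queue through the get method and execute them.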

7.4 Stopping

stop and a few related methods handle shutdown.

  def stop(self):
    """Ensure the worker thread is closed."""
    self._should_worker_thread_run = False

  def _set_resources_aborted(self):
    for weakref_resource in self._resource_remote_value_refs:
      resource = weakref_resource()
      if resource:
        resource._set_aborted()  # pylint: disable=protected-access

  def _set_dead(self):
    raise NotImplementedError("_set_dead is not implemented.")
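
_register_resource and _set_resources_aborted work as a pair: resources are tracked through weak references, and on recovery every resource that is still alive is flagged so that it gets rebuilt. A minimal sketch of that pairing, with hypothetical Resource and ResourceRegistry classes:

```python
import weakref


class Resource:
    """Hypothetical per-worker resource with an 'aborted' flag."""

    def __init__(self, name):
        self.name = name
        self.aborted = False


class ResourceRegistry:
    """Hypothetical registry that tracks resources without owning them."""

    def __init__(self):
        self._refs = []

    def register(self, resource):
        # Weak references let a resource be freed once its users drop it.
        self._refs.append(weakref.ref(resource))

    def set_all_aborted(self):
        for ref in self._refs:
            resource = ref()
            if resource is not None:  # Skip resources that were already collected.
                resource.aborted = True


registry = ResourceRegistry()
dataset = Resource("dataset")
registry.register(dataset)
registry.set_all_aborted()
print(dataset.aborted)  # True
```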

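7.5 Connection to Strategy

So far we have not formally connected any of this to the Strategy. Let's look at one more example. Notice that the function passed to the coordinator calls strategy.run(replica_fn, args=(next(iterator),)), and this is what ties it to the strategy.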

    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),)) # this is where the strategy comes in

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1

8. Failover

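8.1 Strategy

The overall strategy for handling failures is roughly as follows:

  • When a worker is found to have failed, the coordinator puts the function back into the queue so that it is sent to another worker, and a background thread waits for the failed worker to recover. Once it recovers, its resources are rebuilt and it is assigned work again.

  • As a result, the failure of some workers does not prevent the cluster from continuing to work, which allows instances in the cluster to be occasionally unavailable (for example, preemptible or spot instances). The coordinator and the parameter servers, however, must always be available for the cluster to make progress.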

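8.2 Worker failure

When a worker failure occurs, the specific logic is:

  • ClusterCoordinator has built-in fault tolerance for worker failures when used with tf.distribute.experimental.ParameterServerStrategy. That is, when some workers cannot be reached by the coordinator for any reason, training continues on the remaining workers.
  • When a worker recovers, the dataset function provided earlier (either to ClusterCoordinator.create_per_worker_dataset for a custom training loop, or tf.keras.utils.experimental.DatasetCreator for Model.fit) is invoked on that worker to re-create the dataset.
  • After a failed worker has recovered and the dataset created through create_per_worker_dataset has been rebuilt on it, the worker is added back to function execution.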
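
The "requeue and run elsewhere" behaviour described in the list above can be simulated in a few lines. This is only a single-threaded sketch under simplifying assumptions: WorkerUnavailable, run_with_requeue, and the two worker functions are hypothetical names, workers are picked round-robin, and at least one worker is assumed to stay healthy (otherwise the loop would never finish).

```python
from collections import deque


class WorkerUnavailable(Exception):
    """Hypothetical error standing in for a worker preemption."""


def run_with_requeue(tasks, workers):
    """Run each task on the next worker; requeue it if that worker is down."""
    pending = deque(tasks)
    completed = []
    turn = 0
    while pending:
        task = pending.popleft()
        worker = workers[turn % len(workers)]
        turn += 1
        try:
            completed.append(worker(task))
        except WorkerUnavailable:
            pending.append(task)  # Same idea as put_back: try again elsewhere.
    return completed


def healthy_worker(task):
    return f"{task} ok"


def preempted_worker(task):
    raise WorkerUnavailable(task)


completed = run_with_requeue(["a", "b", "c"], [preempted_worker, healthy_worker])
print(completed)
```

Every task eventually completes, but not necessarily in the order it was scheduled, which matches the fact that the coordinator gives no ordering guarantee across workers.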

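8.3 Parameter server or coordinator failure

When a parameter server fails, schedule, join, or done raises tf.errors.UnavailableError. In this case, in addition to bringing the failed parameter server back, the user should also restart the coordinator so that it reconnects to the workers and parameter servers, re-creates the variables, and loads checkpoints. If the coordinator fails, then after the user brings it back, the program automatically connects to the workers and parameter servers and continues from a checkpoint. Because the coordinator itself can become unavailable, the following practices are recommended so that training progress is not lost:

  • The program must save checkpoint files periodically and restore them at startup. If a tf.keras.optimizers.Optimizer is checkpointed, then after restoring from a checkpoint its iterations property roughly indicates the number of steps already taken. This can be used to decide how many epochs and steps remain before training finishes.
  • For Model.fit, use the BackupAndRestore callback, which handles saving and restoring progress automatically.
  • For a custom training loop, checkpoint the model variables periodically and load them from a checkpoint, if one exists, before training starts. If the optimizer is checkpointed, training progress can be roughly inferred from optimizer.iterations:
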
checkpoint_manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(model=model, optimizer=optimizer),
    checkpoint_dir,
    max_to_keep=3)
if checkpoint_manager.latest_checkpoint:
  checkpoint = checkpoint_manager.checkpoint
  checkpoint.restore(
      checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()

global_steps = int(optimizer.iterations.numpy())
starting_epoch = global_steps // steps_per_epoch

for _ in range(starting_epoch, num_epoches):
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  coordinator.join()
  checkpoint_manager.save()

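8.4 Returning a RemoteValue

Fetching a RemoteValue is guaranteed to succeed if the function was executed successfully. This is because the return value is currently copied to the coordinator immediately after a function finishes. If any worker fails during the copy, the function is retried on another available worker. Therefore, to optimize performance, you can schedule functions that have no return value.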

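8.5 Error reporting

Once the coordinator sees an error, such as an UnavailableError from a parameter server or another application error such as an InvalidArgument from tf.debugging.check_numerics, it cancels all pending and queued functions before raising the error. Fetching their corresponding RemoteValue raises a CancelledError.

After an error has been raised, the coordinator will not raise the same error again, nor any error from the cancelled functions.

ClusterCoordinator assumes that all function errors are fatal. Based on this assumption, its error-reporting logic is:

  • schedule and join can both raise a non-retryable error, which is the first error the coordinator has seen from any previously scheduled function.
  • When an error is raised, there is no guarantee about how many previously scheduled functions have been executed; functions that have not been executed are discarded and marked as cancelled.
  • After an error is raised, the internal error state is cleared.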

8.6 WorkerPreemptionHandler

WorkerPreemptionHandler is the main module for handling failures. It is defined as follows:

class WorkerPreemptionHandler(object):
  """Handles worker preemptions."""

  def __init__(self, server_def, cluster):
    self._server_def = server_def
    self._cluster = cluster
    self._cluster_update_lock = threading.Lock()
    self._cluster_due_for_update_or_finish = threading.Event()
    self._worker_up_cond = threading.Condition(self._cluster_update_lock)
    self._error_from_recovery = None
    self._should_preemption_thread_run = True
    self._preemption_handler_thread = threading.Thread(
        target=self._preemption_handler,
        name="WorkerPreemptionHandler",
        daemon=True)
    self._preemption_handler_thread.start()

8.6.1 Configuration

When the Cluster is created, it sets up a WorkerPreemptionHandler as its failure handler.

self.failure_handler = WorkerPreemptionHandler(context.get_server_def(), self)

8.6.2 Waiting

When a closure is processed, the call is wrapped in wait_on_failure to handle errors.

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    assert closure is not None
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)

The wait_on_failure method of WorkerPreemptionHandler is as follows:

  @contextlib.contextmanager
  def wait_on_failure(self,
                      on_failure_fn=None,
                      on_transient_failure_fn=None,
                      on_recovery_fn=None,
                      worker_device_name="(unknown)"):
    """Catches worker preemption error and wait until failed workers are back.

    Args:
      on_failure_fn: an optional function to run if preemption happens.
      on_transient_failure_fn: an optional function to run if transient failure
        happens.
      on_recovery_fn: an optional function to run when a worker is recovered
        from preemption.
      worker_device_name: the device name of the worker instance that is passing
        through the failure.

    Yields:
      None.
    """
    try:
      yield
    except (errors.OpError, InputError) as e:
      # If the error is due to temporary connectivity issues between worker and
      # ps, put back closure, ignore error and do not mark worker as failure.
      if self._cluster._record_and_ignore_transient_ps_failure(e):  
        if on_transient_failure_fn:
          on_transient_failure_fn()
        return

      # Ignoring derived CancelledErrors to tolerate transient failures in
      # PS-worker communication, which initially exposed as an UnavailableError
      # and then lead to sub-function cancellation, subsequently getting
      # reported from worker to chief as CancelledError.
      # We do not mark either worker or PS as failed due to only CancelledError.
      # If there are real (non-transient) failures, they must also be reported
      # as other errors (UnavailableError most likely) in closure executions.
      if isinstance(e, errors.CancelledError) and "/job:" in str(e):
        if on_transient_failure_fn:
          on_transient_failure_fn()
        return

      # This reraises the error, if it's not considered recoverable; otherwise,
      # the following failure recovery logic run. At this time, only worker
      # unavailability is recoverable. PS unavailability as well as other
      # errors in the user function is not recoverable.
      self._validate_preemption_failure(e)

      if on_failure_fn:
        on_failure_fn()

      with self._cluster_update_lock:
        self._cluster_due_for_update_or_finish.set()
        self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
        if self._error_from_recovery:
          try:
            raise self._error_from_recovery
          finally:
            self._error_from_recovery = None

      if on_recovery_fn:
        with self.wait_on_failure(
            on_recovery_fn=on_recovery_fn,
            on_transient_failure_fn=on_transient_failure_fn,
            worker_device_name=worker_device_name):
          on_recovery_fn()
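
The core trick in wait_on_failure is a generator-based context manager that swallows recoverable errors and lets everything else propagate. The sketch below shows only that skeleton. WorkerUnavailable and handle_failure are hypothetical names, and the blocking wait for the cluster update is reduced to a comment.

```python
import contextlib


class WorkerUnavailable(Exception):
    """Hypothetical recoverable error (a preempted worker)."""


@contextlib.contextmanager
def handle_failure(on_failure_fn=None, on_recovery_fn=None):
    """Swallow recoverable errors, run callbacks, and let other errors propagate."""
    try:
        yield
    except WorkerUnavailable:
        if on_failure_fn:
            on_failure_fn()
        # The real handler blocks here until the cluster has been updated.
        if on_recovery_fn:
            on_recovery_fn()


events = []

with handle_failure(on_failure_fn=lambda: events.append("put_back"),
                    on_recovery_fn=lambda: events.append("resources_aborted")):
    raise WorkerUnavailable("worker 1 preempted")

try:
    with handle_failure():
        raise ValueError("bug in user function")
except ValueError:
    events.append("reraised")

print(events)
```

Because the generator returns normally after handling WorkerUnavailable, the with statement suppresses that exception; the ValueError is never caught inside the generator, so it reaches the caller.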

_validate_preemption_failure is defined as follows:

  def _validate_preemption_failure(self, e):
    """Validates that the given exception represents worker preemption."""

    # Only categorize the failure as a worker preemption if the cancellation
    # manager did not attempt to cancel the blocking operations.
    if _is_worker_failure(e) and (
        not self._cluster._closure_queue._cancellation_mgr.is_cancelled):  
      return
    raise e

8.6.3 Handler

WorkerPreemptionHandler has a background thread, _preemption_handler_thread.

    self._preemption_handler_thread = threading.Thread(
        target=self._preemption_handler,
        name="WorkerPreemptionHandler",
        daemon=True)
    self._preemption_handler_thread.start()


_preemption_handler performs the necessary error handling.

  def _preemption_handler(self):
    """A loop that handles preemption.

    This loop waits for signal of worker preemption and upon worker preemption,
    it waits until all workers are back and updates the cluster about the
    restarted workers.
    """
    assert self._should_preemption_thread_run
    while True:
      self._cluster_due_for_update_or_finish.wait()
      if not self._should_preemption_thread_run:
        break

      with self._cluster_update_lock:
        try:
          context.context().update_server_def(self._server_def)

          # Cluster updated successfully, clear the update signal, and notify
          # all workers that they are recovered from failure.
          self._worker_up_cond.notify_all()
          # The check for _should_preemption_thread_run is necessary since the
          #  stop  may have already set _cluster_due_for_update_or_finish.
          if self._should_preemption_thread_run:
            self._cluster_due_for_update_or_finish.clear()
        except Exception as e:  
          try:
            self._validate_preemption_failure(e)
          except Exception as ps_e: 
            # In this case, a parameter server fails. So we raise this error to
            # the caller of  wait_on_failure .
            self._error_from_recovery = ps_e
            self._worker_up_cond.notify_all()
            if self._should_preemption_thread_run:
              self._cluster_due_for_update_or_finish.clear()
          # NOTE: Since the first RPC (GetStatus) of update_server_def is
          # currently blocking by default, error should only happen if:
          # (1) More workers failed while waiting for the previous workers to
          #     come back;
          # (2) Worker failed when exchanging subsequent RPCs after the first
          #     RPC returns.
          # Consider adding backoff retry logic if we see the error logged
          # too frequently.
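
The handshake between a failing worker thread and the preemption thread uses an Event ("an update is due") plus a Condition ("the workers are back") that share one lock. The following is a minimal single-round sketch with hypothetical module-level names; the log entry stands in for the call to update_server_def.

```python
import threading

update_lock = threading.Lock()
due_for_update = threading.Event()
worker_up = threading.Condition(update_lock)
log = []


def recover_once():
    """Hypothetical recovery thread: wait for a signal, 'update', then notify."""
    due_for_update.wait()
    with update_lock:
        log.append("cluster updated")  # stands in for update_server_def
        worker_up.notify_all()
        due_for_update.clear()


thread = threading.Thread(target=recover_once, daemon=True)
thread.start()

# A failing worker signals the recovery thread, then waits to be told it is back.
with update_lock:
    due_for_update.set()
    recovered = worker_up.wait(timeout=5)
thread.join(timeout=5)
print(recovered, log)
```

Setting the event while already holding the lock matters: the recovery thread cannot enter its critical section, and therefore cannot notify, until the worker has started waiting on the condition.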

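9. Summary

Based on the code above, we can answer the following questions:

  • How does a Worker know which device to use? When the cluster creates its workers, it assigns each worker a device.

  • How is the user function actually executed? When a worker runs a Closure, it pins execution to that worker's device and then runs the given function (self._function). self._function is the user-defined function, which can use strategy.run to dispatch the training step to remote workers.

  • How is data obtained? A PerWorkerValues is created for the workers. PerWorkerValues is a container holding a list of values, and each worker takes its own data from the corresponding PerWorkerValues.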
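
The last point can be illustrated with a small sketch of per-worker selection. PerWorker and select_worker_slice below are hypothetical simplifications of PerWorkerValues and _select_worker_slice: they handle only a flat list of arguments, whereas the real helper walks nested structures.

```python
class PerWorker:
    """Hypothetical container holding one value per worker, by worker index."""

    def __init__(self, values):
        self.values = tuple(values)


def select_worker_slice(worker_index, structure):
    """Replace each PerWorker entry in a flat list with that worker's value."""
    return [
        item.values[worker_index] if isinstance(item, PerWorker) else item
        for item in structure
    ]


args = [PerWorker(["iterator-0", "iterator-1"]), "shared-config"]
print(select_worker_slice(1, args))  # ['iterator-1', 'shared-config']
```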

