[源碼解析] TensorFlow 分佈式 DistributedStrategy 之基礎篇

[源碼解析] TensorFlow 分佈式 DistributedStrategy 之基礎篇

前文之中我們已經介紹了 Strategy 這個基本概念,tf.distribute.Strategy 是一個可在多個 GPU、多臺機器或 TPU 上進行分佈式訓練的 TensorFlow API。使用此 API,您只需改動較少代碼就能基於現有模型和訓練代碼來實現單機多卡,多機多卡等情況的分佈式訓練。tf.distribute.Strategy 旨在實現以下目標:

  • 覆蓋不同維度的用戶用例。
  • 易於使用,支持多種用戶(包括研究人員和 ML 工程師等)。
  • 提供開箱即用的高性能。
  • 從用戶模型代碼之中解耦,這樣可以輕鬆切換策略。
  • 支持 Custom Training Loop,Estimator,Keras。
  • 支持 eager excution。

從系統角度或者說從開發者的角度看,Strategy 是基於Python作用域或裝飾器來實現的一套機制。它提供了一組命名的分佈式策略,如ParameterServerStrategy、CollectiveStrategy來作爲Python作用域,這些策略可以被用來捕獲用戶函數中的模型聲明和訓練邏輯,其將在用戶代碼開始時生效。在後端,分佈式系統可以重寫計算圖,並根據選擇的策略(參數服務器或集合)合併相應的語義。

因此我們分析的核心就是如何把數據讀取,模型參數,分佈式計算融合到Python作用域或裝飾器之中,本章我們就從 Strategy 的類體系結構和讀取數據開始。

依然安利兩個大神:

[TensorFlow Internals] (https://github.com/horance-liu/tensorflow-internals),雖然其分析的不是最新代碼,但是建議對 TF 內部實現機制有興趣的朋友都去閱讀一下,絕對大有收穫。
https://home.cnblogs.com/u/deep-learning-stacks/ 西門宇少,不僅僅是 TensorFlow,其公共號還有更多其他領域,業界前沿。

本系列其他文章是:

[翻譯] TensorFlow 分佈式之論文篇 "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[翻譯] TensorFlow 分佈式之論文篇 "Implementation of Control Flow in TensorFlow"

[源碼解析] TensorFlow 分佈式環境(1) --- 總體架構

[源碼解析] TensorFlow 分佈式環境(2)---Master 靜態邏輯

[源碼解析] TensorFlow 分佈式環境(3)--- Worker 靜態邏輯

[源碼解析] TensorFlow 分佈式環境(4) --- WorkerCache

[源碼解析] TensorFlow 分佈式環境(5) --- Session

[源碼解析] TensorFlow 分佈式環境(7) --- Worker 動態邏輯

[源碼解析] TensorFlow 分佈式環境(8) --- 通信機制

[翻譯] 使用 TensorFlow 進行分佈式訓練

1. StrategyBase

StrategyBase 是一個設備列表之上的狀態和計算分佈策略。是 v1 策略和 v2 策略類的基類。

1.1 初始化

StrategyBase 初始化方法之中最主要就是設定 extended,其類型是 StrategyExtendedV2 或者 StrategyExtendedV1。

class StrategyBase(object):

  def __init__(self, extended):
    
    self._extended = extended

    # Flag that is used to indicate whether distribution strategy is used with
    # Estimator. This is required for backward compatibility of loss scaling
    # when using v1 optimizer with estimator.
    self._scale_loss_for_estimator = False

    if not hasattr(extended, _retrace_functions_for_each_device):
        # extended._retrace_functions_for_each_device dictates
      # whether the same function will be retraced when it is called on
      # different devices.
      try:
        extended._retrace_functions_for_each_device = (
            len(extended.worker_devices) > 1)
        distribution_strategy_replica_gauge.get_cell(num_replicas).set(
            self.num_replicas_in_sync)
      except:  
        # Default for the case where extended.worker_devices can't return
        # a sensible value.
        extended._retrace_functions_for_each_device = True

    # Below are the dicts of axis(int) -> tf.function.
    self._mean_reduce_helper_fns = {}
    self._reduce_sum_fns = {}

    # Whether this strategy is designed to work with ClusterCoordinator.
    self._should_use_with_coordinator = False

  @property
  def extended(self):
    ```tf.distribute.StrategyExtended with additional methods.```
    return self._extended

1.2 使用

如果想使用 Keras compile/fit,請參照 https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras。

也可以將 tf.distribution.Strategy 的派生類傳遞給 tf.estimator.RunConfig 來指定 tf.estimator.Estimator 應該如何分配計算,具體可以參照 https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support。

在建立和執行模型時,應該首先使用 tf.distribution.Strategy.scope 來指定一個策略。 指定策略意味着這將使代碼處於這個策略的 cross-replica context 中,因此這個策略將負責控制比如 variable placement 這樣的功能。

1.3 CTL

如果您正在編寫一個自定義的訓練循環(custom training loop),您將需要多調用一些方法,

  • 使用 tf.distribut.Strategy.experimental_distribute_dataset 將 tf.data.Dataset 轉換,使之能產生 per-replica 值。如果您想手動指定數據集如何在各個副本之間進行劃分,請使用tf.distribut.Strategy.distribut_datasets_from_function。
  • 使用 tf.distribution.Strategy.run 爲每個副本運行函數,該函數使用 per-replica 的值(例如來自tf.distribution.DistributedDataset對象)並返回一個 per-replica。這個函數是在 副本上下文 中執行的,這意味着每個操作都在每個副本上單獨執行。
  • 最後使用一個方法(如tf.distributed.Strategy.reduce)將得到的 per-replica 的值轉換成普通的張量。

下面代碼是 CTL 一個典型用例,其使用一個普通的 dataset 和 replica_fn 在名爲 my_strategy 的特定 tf.distribution.Strategy 下分佈式運行。在 replica_fn 中創建的任何變量都是使用 my_strategy 的策略創建的。

用戶可以使用 reduce API 來聚合各副本的結果,並將其作爲對 tf.distributedDataset 進行一次迭代的返回值。用戶也可以使用 tf.keras.metrics(如損失、準確度等)來累積各步驟的度量。

  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)

1.4 Scope

分發策略的範圍(作用域)決定了如何創建變量以及在何處創建變量,比如對於 MultiWorkerMirroredStrategy 而言,創建的變量類型是 MirroredVariable ,策略將它們複製到每個工作者之上。Scope 的方法主要是通過調用 _extended._scope 來完成。該方法返回了一個 Context manager,這可以設置本策略爲當前策略,並且分發變量。

def scope(self):
    
    """Context manager to make the strategy current and distribute variables.
   
    Returns:
    A context manager.
   """
  return self._extended._scope(self)  

1.4.1 使用

具體使用方法如下:

>>> strategy = tf.distribute.MirroredStrategy([GPU:0, GPU:1])
>>> # Variable created inside scope:
>>> with strategy.scope():
...   mirrored_variable = tf.Variable(1.)
>>> mirrored_variable
MirroredVariable:{
  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
>>> # Variable created outside scope:
>>> regular_variable = tf.Variable(1.)
>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

1.4.2 功能

當進入了 Strategy.scope 之後,會執行如下操作:

  • strategy 被安裝在全局上下文內,作爲當前策略。 在這個範圍內,調用 tf.distribution.get_strategy() 將返回這個策略。在這個範圍之外,它將返回默認的無操作(no-op)策略。
  • 進入這個 scope 也就進入了 cross-replica context。
  • scope 內的變量創建將被策略攔截。每個策略都定義了它要如何影響變量的創建。像 MirroredStrategy、TPUStrategy 和 MultiWorkerMiroredStrategy 這樣的同步策略在每個副本上創建變量,而ParameterServerStrategy 在參數服務器上創建變量。這是在策略自定義的 tf.variable_creator_scope 之中完成的。
  • 在某些策略中也可以輸入默認的設備範圍:在 MultiWorkerMiroredStrategy 中,每個 worker 上輸入的默認設備範圍是 /CPU:0。

注意:進入 Scope 不會自動分配計算,除非是像 keras model.fit 這樣的高層訓練框架。如果您沒有使用 model.fit,您需要使用 strategy.run API 來明確分配該計算。

1.4.3 Scope 範圍

什麼在 Scope 之內?什麼在之外?

  • 任何創建分佈式變量的操作都必須在 strategy.scope 中調用。這可以通過在範圍上下文中直接調用變量創建函數來實現,或者由 strategy.run 或 Keras.Model.fit 自動爲您輸入。
  • 任何可能惰性創建變量的函數(例如,Model.call(),追蹤一個tf.function,等等)也應該在作用域內調用。
  • 變量創建的另一個來源可以是檢查點的恢復。
  • 任何在作用域之外創建的變量都不會被分發。

請注意,任何在策略內部創建的變量都會捕獲策略信息。因此,在 strategy.scope 之外對這些變量的讀寫也可以無縫進行,而不需要用戶進入 scope。

一些需要進入策略範圍的策略 API(如strategy.run和strategy.reduce)會自動進入 scope,這意味着在使用這些API 時,您不需要自己明確進入 scope。

模型、優化器、Metrics 可以在 TF 之中創建變量,這樣的對象應該總是在作用域內初始化。當 tf.keras.Model 在strategy.scope 內被創建,Model 對象會捕獲範圍信息。當高層的訓練框架方法,如 model.compile,model.fit 等被調用時,捕獲的範圍將被自動輸入,相關的策略將被用來分配訓練等。

警告:簡單地調用model(..)不會自動進入 Strategy 的範圍 -- 只有高水平的訓練框架 API 支持這種行爲:model.compile、model.fit、model.evaluation、model.predict 和 model.save 都可以在範圍內或範圍外調用。

1.5 StrategyExtendedV2

StrategyExtendedV2 爲需要分佈感知(distribution-aware)的算法提供額外的API。

@tf_export(distribute.StrategyExtended, v1=[])
class StrategyExtendedV2(object):
    # Additional APIs for algorithms that need to be distribution-aware.

1.5.1 locality

tf.distributed.DistributedValues 可以具有與分佈式變量相同的 locality,這導致 mirrored value 會駐留在與變量相同的設備上(而不是計算設備上)。針對 locality,用戶可以做如下操作:

  • 可以使用 tf.distribution.StrategyExtended.update 來更新變量的值。
  • 可以使用 tf.distribution.StrategyExtended.colocate_vars_with 來讓一個變量與另一個變量有相同的 locality。
  • 可以使用 tf.distribution.StrategyExtended.reduce_to 或 tf.distribution.StrategyExtended.batch_reduce_to 將 PerReplica value 轉換到另一個變量的 locality。

1.5.2 如何更新

接下來我們看看如何更新一個分佈式變量(distributed variable)。分佈式變量(distributed variable)是在多個設備上創建的變量,比如鏡像變量和同步讀取(SyncOnRead)變量。更新分佈式變量的標準模式是:

  1. 在傳遞給 tf.distribution.Strategy.run 的函數中來計算得到一個(update, variable)對列表。例如,更新可能是一個變量的損失梯度。
  2. 通過調用 tf.distribution.get_replica_context().merge_call() 來切換到 cross-replica 模式,調用時將更新和變量作爲參數。
  3. 通過調用 tf.distribution.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)(針對一個變量)或tf.distribution.StrategyExtended.batch_reduce_to(針對一個變量列表)來對更新進行求和。
  4. 可以爲每個變量調用 tf.distribution.StrategyExtended.update(v) 來更新它的值。

如果您在副本上下文中調用 tf.keras.optimizer.Optimizer.apply_gradients方法,則步驟 2 到 4 會由類tf.keras.optimizer.Optimizer 自動完成。

事實上,更新分佈式變量的更高層次的解決方案是對該變量調用 assign,就像您對普通的 tf.Variable 一樣操作。您可以在 replica contextcross-replica context 中調用該方法。

對於一個 mirrored 變量,在 replica context 中調用 assign 需要在變量構造函數中指定aggregation類型。在這種情況下,您需要自行處理在步驟2到4中描述的上下文切換和同步。如果您在 cross-replica context 中對 mirrored variable 調用 assign,您只能 assign 一個值,或者從一個鏡像的 tf.distribution.DistributedValues 中來 assign 值。對於一個 _SyncOnRead 變量,在 replica 上下文中,您可以簡單地調用 assign,而不發生任何聚合。在 cross-replica context 中,您只能給一個 SyncOnRead 變量分配一個值。

1.6 繼承關係

Strategy 繼承關係如下,其中 V1 版本是一條路線,V2 版本又是一條路線。

圖 1 Strategy 繼承關係

Extended 繼承關係如下:

圖 2 Extended 繼承關係

至此,我們分析了Strategy的類體系,但是還沒有領略Strategy的精妙之處,我們需要繼續分析下去,本文會先看看如何處理數據,下一篇看看如何處理變量。

2. 讀取數據

我們接下來看看如何讀取數據。對於輸入數據集,主要有兩種實現:

  • experimental_distribute_dataset :從 tf.data.Dataset 生成 tf.distribute.DistributedDataset,得到的數據集可以像常規數據集一樣迭代讀取。
  • _distribute_datasets_from_function :通過調用 dataset_fn 來分發 tf.data.Dataset。

我們接下來用 MirroredStrategy 來分析如何讀取數據。總體的邏輯大致如下:在每個工作者上對數據集進行復制,重新分批和分片。首先會按文件分片,這樣每個工作者將看到不同的文件子集。如果無法做到,工作者則將嘗試對最終輸入進行分片,這樣每個工作者會運行整個預處理流水線,但是隻收到自己的數據集分片,從而達到數據並行的目的。

2.1 直接讀取數據集

2.1.1 用例

以下是如何使用 experimental_distribute_dataset 來直接得到數據集。

>>> global_batch_size = 2
>>> # Passing the devices is optional.
... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
>>> # Create a dataset
... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
>>> # Distribute that dataset
... dist_dataset = strategy.experimental_distribute_dataset(dataset)
>>> @tf.function
... def replica_fn(input):
...   return input*2
>>> result = []
>>> # Iterate over the tf.distribute.DistributedDataset
... for x in dist_dataset:
...   # process dataset elements
...   result.append(strategy.run(replica_fn, args=(x,)))
>>> print(result)
[PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
}, PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
}]

2.1.2 基類實現

StrategyBase 方法之中,主要三種數據相關操作是:分批,分片,預取(大家可以回到PyTorch數據讀取部分看看異同)。

在上面的代碼片段中,分批操作具體是:

  • dataset 首先按照 global_batch_size 進行分批。
  • 其次調用 experimental_distribute_dataset 把 dataset 按照一個新分批大小(batch size)進行重新分批,新分批大小等於"全局分批大小除以同步副本數量"。用戶可以用 Pythonic for loop 來遍歷它。
  • x 是一個 tf.distribution.DistributedValues,其包含所有副本的數據,而每個副本會得到新批次大小的數據。
  • tf.distribution.Strategy.run 將負責把 x 中每個副本對應的數據(per-replica)分發給每個副本執行工作函數 replica_fn。

分片(Sharding)包含跨多個工作者的自動分片(autosharding)。

  • 首先,在多工作者(multi-worker)分佈式訓練中(使用tf.distribution.experimental.MultiWorkerMirroredStrategy 或 tf.distribution.TPUStrategy 時),在一組工作者上自動分片(autosharding)數據集意味着每個工作者被分配了整個數據集的一個子集(如果設置了正確的tf.data.experimental.AutoShardPolicy)。這是爲了確保在每個 step 中,每個工作者都會處理一個全局的,包含不重疊的數據集元素的批次。自動分片有幾個不同的選項,可以使用 tf.data.experimental.DistributeOptions 來指定。
  • 然後,每個工作者內的分片意味着該方法將在所有工作者設備之間分割數據(如果存在多個)。無論多工作者(multi-worker)是否設定自動分片,這都會發生。
  • 對於跨多個工作者的自動分片,默認模式是 tf.data.experimental.AutoShardPolicy.AUTO。如果數據集是從讀者數據集(例如tf.data.TFRecordDataset、tf.data.TextLineDataset等)中創建的,該模式將嘗試按文件分片,否則按數據分片,其中每個工作者將讀取整個數據集,但是隻處理分配給它的分片。然而,如果每個工作者的輸入文件少於一個,我們建議您通過設置 tf.data.experimental.DistributeOptions.auto_shard_policy 爲 tf.data.experimental.AutoShardPolicy.OFF 來禁止跨工作者的數據集自動分片。

對於預取(prefetch),默認情況下,該方法在用戶提供的 tf.data.Dataset 實例的末尾添加一個預取轉換。預取轉換的參數是 buffer_size,就是同步的副本(replicas in sync)的數量。

experimental_distribute_dataset 的定義如下,其實就是調用 extended 來完成操作。

  def experimental_distribute_dataset(self, dataset, options=None):
    """Creates tf.distribute.DistributedDataset from tf.data.Dataset.

    Args:
      dataset: tf.data.Dataset that will be sharded across all replicas using
        the rules stated above.
      options: tf.distribute.InputOptions used to control options on how this
        dataset is distributed.

    Returns:
      A tf.distribute.DistributedDataset.
    """
    distribution_strategy_input_api_counter.get_cell(
        self.__class__.__name__, "distribute_dataset").increase_by(1)

    return self._extended._experimental_distribute_dataset(dataset, options)  

2.1.3 MirroredExtended 實現

我們用 MirroredExtended 來看看具體實現,其實就是調用 input_lib.get_distributed_dataset 來進行處理,因此我們深入到 input_lib 之中。

def _experimental_distribute_dataset(self, dataset, options):
  if (options and options.experimental_replication_mode ==
      distribute_lib.InputReplicationMode.PER_REPLICA):
    raise NotImplementedError(
        "InputReplicationMode.PER_REPLICA "
        "is only supported in "
        "distribute_datasets_from_function."
    )
  return input_lib.get_distributed_dataset(
      dataset,
      self._input_workers_with_options(options),
      self._container_strategy(),
      num_replicas_in_sync=self._num_replicas_in_sync,
      options=options)

2.1.4 input_lib 功能

input_lib 提供了關於處理輸入數據的一些基礎功能。get_distributed_dataset 是一個通用函數,其可以被所有策略用來返回分佈式數據集。返回的分佈式數據集實例是不同的,這取決於我們是在 TF1 還是 TF2 的背景下。返回的分佈式數據集實例的 API 也有所不同。這裏用到了 DistributedDataset 和 input_workers,所以我們有必要一一進行分析。

def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            num_replicas_in_sync=None,
                            input_context=None,
                            options=None,
                            build=True):
  """Returns a distributed dataset from the given tf.data.Dataset instance.

  Args:
    dataset: a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
    num_replicas_in_sync: Optional integer. If this is not None, the value is
        used to decide how to rebatch datasets into smaller batches so that
        the total batch size for each step (across all workers and replicas)
        adds up to dataset's batch size.
    input_context: InputContext for sharding. Only pass this in for between
        graph multi-worker cases where there is only one input_worker. In
        these cases, we will shard based on the input_pipeline_id and
        num_input_pipelines in the InputContext.
    options: Default is None. tf.distribute.InputOptions used to control
        options on how this dataset is distributed.
    build: whether to build underlying datasets when a DistributedDataset is
        created. This is only useful for ParameterServerStrategy now.

  Returns:
    A distributed dataset instance.
  """
  if tf2.enabled():
    return DistributedDataset( # 接下來會分析 DistributedDataset
        input_workers,
        strategy,
        dataset,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        build=build,
        options=options)
  else:
    return DistributedDatasetV1(
        dataset,
        input_workers, # 接下來會分析 InputWorkers
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)

2.1.5 InputWorkers

定義

InputWorkers 的作用是從輸入 worker 設備到計算設備的 1-to-many mapping。worker_device_pairs 就是映射關係列表,每個 item 是 (input device, a tuple of compute devices fed by that input device)。

class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  # TODO(ishark): Remove option canonicalize_devices and make all the callers
  # pass canonicalized or raw device strings as relevant from strategy.
  def __init__(self,
               worker_device_pairs,
               canonicalize_devices=True):
    """Initialize an InputWorkers object.

    Args:
      worker_device_pairs: A sequence of pairs: (input device, a tuple of
        compute devices fed by that input device).
      canonicalize_devices: Whether to canonicalize devices for workers fully or
        partially. If False, it will partially canonicalize devices by removing
        job and task.
    """
    self._worker_device_pairs = worker_device_pairs
    self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
    self._canonicalize_devices = canonicalize_devices
    if canonicalize_devices:
      self._fed_devices = tuple(
          tuple(device_util.canonicalize(d)
                for d in f)
          for _, f in self._worker_device_pairs)
    else:
      self._fed_devices = tuple(
          tuple(device_util.canonicalize_without_job_and_task(d)
                for d in f)
          for _, f in self._worker_device_pairs)

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices # 返回 device, worker 信息

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join("  %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)

  def serialize(self):
    return (self._worker_device_pairs, self._canonicalize_devices)

  def deserialize(self, serialized):
    return InputWorkers(serialized)

構建

在 MirroredStrategy 之中有成員變量 _input_workers,因此,如果調用時候就會生成 InputWorkers。

@property
def _input_workers(self):
  return self._input_workers_with_options()

_input_workers_with_options 會根據 self._devices 來進行配置,就是生成各種映射關係,然後配置進去。

def _input_workers_with_options(self, options=None):
  if not options: 
    # 沒有配置就直接建立
    return input_lib.InputWorkers(self._input_workers_devices)

  # 有配置就依據配置生成
  if (options.experimental_replication_mode ==
      distribute_lib.InputReplicationMode.PER_REPLICA):
    # PER_REPLICA 處理
    if options.experimental_place_dataset_on_device:
      self._input_workers_devices = (
          tuple(
              (device_util.canonicalize(d, d), (d,)) for d in self._devices))
    else:
      self._input_workers_devices = (
          tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                for d in self._devices))
    return input_lib.InputWorkers(self._input_workers_devices)
  else:
    if not options.experimental_fetch_to_device:
      return input_lib.InputWorkers([
          (host_device, (host_device,) * len(compute_devices))
          for host_device, compute_devices in self._input_workers_devices
      ])
    else:
      return input_lib.InputWorkers(self._input_workers_devices)

這裏使用了 device_util.canonicalize 方法,其作用是把設備分類。

def canonicalize(d, default=None):
  """Canonicalize device string.

  If d has missing components, the rest would be deduced from the default
  argument or from '/replica:0/task:0/device:CPU:0'. For example:
    If d = '/cpu:0', default='/job:worker/task:1', it returns
      '/job:worker/replica:0/task:1/device:CPU:0'.
    If d = '/cpu:0', default='/job:worker', it returns
      '/job:worker/replica:0/task:0/device:CPU:0'.
    If d = '/gpu:0', default=None, it returns
      '/replica:0/task:0/device:GPU:0'.

  Note: This uses "job:localhost" as the default if executing eagerly.

  Args:
    d: a device string or tf.config.LogicalDevice
    default: a string for default device if d doesn't have all components.

  Returns:
    a canonicalized device string.
  """
  if isinstance(d, context.LogicalDevice):
    d = tf_device.DeviceSpec.from_string(d.name)
  else:
    d = tf_device.DeviceSpec.from_string(d)

  # Fill in missing device fields using defaults.
  result = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)
  if ops.executing_eagerly_outside_functions():
    # Try to deduce job, replica and task in case it's in a multi worker setup.
    host_cpu = tf_device.DeviceSpec.from_string(
        config.list_logical_devices("CPU")[0].name)
    if host_cpu.job:
      result = result.make_merged_spec(host_cpu)
    else:
      # The default job is localhost if eager execution is enabled
      result = result.replace(job="localhost")
  if default:
    # Overrides any defaults with values from the default device if given.
    result = result.make_merged_spec(
        tf_device.DeviceSpec.from_string(default))

  # Apply d last, so that it's values take precedence over the defaults.
  result = result.make_merged_spec(d)
  return result.to_string()

2.1.6 DistributedDataset

DistributedDataset 支持預先分發數據到多個設備。

初始化

下面代碼中省略了大量檢查代碼,關鍵點是調用了 build 方法。

class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               input_workers,
               strategy,
               dataset=None,
               num_replicas_in_sync=None,
               input_context=None,
               components=None,
               element_spec=None,
               enable_get_next_as_optional=None,
               build=True,
               options=None):
    """Distribute the dataset on all workers.

    If num_replicas_in_sync is not None, we split each batch of the dataset
    into num_replicas_in_sync smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      input_workers: an InputWorkers object.
      strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
      dataset: tf.data.Dataset that will be used as the input source. Either
        dataset or components field should be passed when constructing
        DistributedDataset. Use this when contructing DistributedDataset from a
        new tf.data.Dataset. Use components when constructing using
        DistributedDatasetSpec.
      num_replicas_in_sync: Optional integer. If this is not None, the value
        is used to decide how to rebatch datasets into smaller batches so that
        the total batch size for each step (across all workers and replicas)
        adds up to dataset's batch size.
      input_context: InputContext for sharding. Only pass this in for between
        graph multi-worker cases where there is only one input_worker. In
        these cases, we will shard based on the input_pipeline_id and
        num_input_pipelines in the InputContext.
      components: datasets when DistributedDataset is constructed from
        DistributedDatasetSpec. Either field dataset or components should be
        passed.
      element_spec: element spec for DistributedDataset when constructing from
        DistributedDatasetSpec. This will be used to set the element_spec for
        DistributedDataset and verified against element_spec from components.
      enable_get_next_as_optional: this is required when components is passed
        instead of dataset.
      build: whether to build underlying datasets when this object is created.
        This is only useful for ParameterServerStrategy now.
      options: tf.distribute.InputOptions used to control options on how this
        dataset is distributed.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)

    self._input_workers = input_workers
    self._strategy = strategy
    self._options = options
    self._input_context = input_context
    self._num_replicas_in_sync = num_replicas_in_sync

    if dataset is not None:
      self._original_dataset = dataset
      self._built = False
      if build:
        self.build() # 這裏是關鍵
    else:
      self._cloned_datasets = components
      self._cardinality = _cardinality(self._cloned_datasets[0])
      self._enable_get_next_as_optional = enable_get_next_as_optional

      if element_spec != _create_distributed_tensor_spec(
          self._strategy, self._cloned_datasets[0].element_spec):
        raise ValueError("Mismatched element_spec from the passed components")
      self._element_spec = element_spec

      self._built = True

建立數據

build 主要作用是調用 _create_cloned_datasets_from_dataset。

def build(self, dataset_to_replace=None):
  dataset = dataset_to_replace or self._original_dataset
  self._cardinality = _cardinality(dataset)
  self._enable_get_next_as_optional = _enable_get_next_as_optional(
      self._strategy, dataset, self._cardinality)
  self._create_cloned_datasets_from_dataset(dataset, self._input_context,
                                            self._input_workers,
                                            self._strategy,
                                            self._num_replicas_in_sync)
  self._element_spec = _create_distributed_tensor_spec(
      self._strategy, self._cloned_datasets[0].element_spec)
  self._built = True

_create_cloned_datasets_from_dataset 在每個工作者上對數據集進行克隆和分片(這裏就使用到了InputWorkers以獲取設備信息)。首先會嘗試按文件分片,以便每個工作者看到不同的文件子集。如果無法做到,則將嘗試對最終輸入進行分片,這樣每個工作者將運行整個預處理管道,並且只收到自己的數據集分片。

此外,_create_cloned_datasets_from_dataset 將每個工作者上的數據集重新匹配成 num_replicas_in_sync 個更小的批次。這些更小的批次分佈在該工作者的副本中,這樣全局步驟(global step)的批次大小(跨越所有工作者和副本)加起來就等於原始數據集的批次大小。

def _create_cloned_datasets_from_dataset(self, dataset, input_context,
                                         input_workers, strategy,
                                         num_replicas_in_sync):
  if num_replicas_in_sync is not None:
    num_workers = input_context.num_input_pipelines if input_context else len(
        input_workers.worker_devices)
    # 用 _make_rebatch_fn 來重新 batch 數據
    rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
                                       num_replicas_in_sync)
  else:
    rebatch_fn = None
  self._cloned_datasets = []
  
  if input_context:
    # Between-graph where we rely on the input_context for sharding
    if rebatch_fn is not None:
      dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
    dataset = input_ops.auto_shard_dataset(dataset,
                                           input_context.num_input_pipelines,
                                           input_context.input_pipeline_id,
                                           num_replicas_in_sync)
    self._cloned_datasets.append(dataset)
  else:
    # 複製數據,返回 _RemoteDataset
    replicated_ds = distribute.replicate(dataset,
                                         input_workers.worker_devices)
    
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = replicated_ds[worker] # 找到某 worker 對應的數據集
        if rebatch_fn is not None:
          cloned_dataset = rebatch_fn(cloned_dataset, i) # 重新 batch,返回 _RebatchDataset
        # 自動分區,返回 _AutoShardDataset
        cloned_dataset = input_ops.auto_shard_dataset(
            cloned_dataset, len(input_workers.worker_devices), i,
            num_replicas_in_sync)
        self._cloned_datasets.append(cloned_dataset)

distribute.replicate 是用來複制數據,把數據複製到一系列設備上,這裏返回 _RemoteDataset。

def replicate(dataset, devices):
  """A transformation that replicates dataset onto a list of devices.

  Args:
    dataset: A tf.data.Dataset object.
    devices: A list of devices to replicate the dataset on.

  Returns:
    A dictionary mapping device name to a dataset on that device.
  """

  dataset_device = dataset._variant_tensor.device

  datasets = {}
  if len(devices) == 1 and devices[0] == dataset_device:
    datasets[devices[0]] = dataset
    return datasets

  with ops.colocate_with(dataset._variant_tensor):
    dataset = dataset._apply_debug_options()
    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=ExternalStatePolicy.WARN)
    
  for device in devices: # 遍歷設備,複製數據到設備上,每個設備一個 _RemoteDataset
    ds = _RemoteDataset(graph_def, device, dataset.element_spec)
    datasets[device] = ds
    
  return datasets

_make_rebatch_fn 返回一個把輸入數據集 rebatches 的 callable,這裏返回 _RebatchDataset。

def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
  """Returns a callable that rebatches the input dataset.

  Args:
    dataset: A tf.data.Dataset representing the dataset to be distributed.
    num_workers: An integer representing the number of workers to distribute
      dataset among.
    num_replicas_in_sync: An integer representing the number of replicas in
      sync across all workers.
  """
  if num_replicas_in_sync % num_workers:
    raise ValueError(
        "tf.distribute expects every worker to have the same number of "
        "replicas. However, encountered num_replicas_in_sync ({}) that "
        "cannot be divided by num_workers ({})".format(
            num_replicas_in_sync, num_workers))

  num_replicas_per_worker = num_replicas_in_sync // num_workers
  with ops.colocate_with(dataset._variant_tensor):  
    batch_size = distribute.compute_batch_size(dataset)

  def rebatch_fn(dataset, worker_index):
    try:
      def apply_rebatch():
        batch_sizes = distribute.batch_sizes_for_worker(
            batch_size, num_workers, num_replicas_per_worker, worker_index)
        return distribute._RebatchDataset(
            dataset, batch_sizes).prefetch(num_replicas_per_worker)

      def apply_legacy_rebatch():
        return distribute._LegacyRebatchDataset(
            dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)

      with ops.colocate_with(dataset._variant_tensor):
        return control_flow_ops.cond(
            math_ops.not_equal(batch_size, -1),
            true_fn=apply_rebatch,
            false_fn=apply_legacy_rebatch)
    except errors.InvalidArgumentError as e:
      if "without encountering a batch" in str(e):
        six.reraise(
            ValueError,
            ValueError(
                "Call the batch method on the input Dataset in order to be "
                "able to split your input across {} replicas.\n Please see "
                "the tf.distribute.Strategy guide. {}".format(
                    num_replicas_in_sync, e)),
            sys.exc_info()[2])
      else:
        raise

  return rebatch_fn

接下來是自動分片,這裏返回 _AutoShardDataset。

def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None):
  """Shard the input pipeline by sharding the underlying list of files.

  Args:
    dataset: A tf.data.Dataset instance, typically the result of a bunch of
      dataset transformations.
    num_shards: A tf.int64 scalar tf.Tensor, representing the number of
        shards operating in parallel. Same usage as in tf.data.Dataset.shard.
    index: A tf.int64 scalar tf.Tensor, representing the worker index.
      Same usage as in tf.data.Dataset.shard.
    num_replicas_in_sync: An integer representing the total number of replicas
      across all workers. This is used in the rewrite when sharding by data.

  Returns:
    A modified Dataset obtained by updating the pipeline sharded by the
    files. The input dataset will be returned if we cannot automatically
    determine a good way to shard the input dataset.
  """
  if (dataset.options().experimental_distribute.auto_shard_policy !=
      AutoShardPolicy.OFF):
    if num_replicas_in_sync is None:
      num_replicas_in_sync = 1
    if isinstance(dataset, dataset_ops.DatasetV1):
      return distribute._AutoShardDatasetV1(dataset, num_shards, index,
                                            num_replicas_in_sync)
    else:
      return distribute._AutoShardDataset(dataset, num_shards, index,
                                          num_replicas_in_sync)
  else:
    return dataset

此時流程圖如下,可以看到數據集功能逐漸加強,首先是 _RemoteDataset,然後升級到 _AutoShardDataset。

圖 3 建立數據集

數據集

因爲上面涉及了幾種數據集,所以我們要再仔細梳理一下這其中的關係,其具體可以理解爲在數據集 DatasetV2 基礎之上逐步添加功能,最終返回給用戶。

_RemoteDataset 對應遠端數據集。

_RemoteDataset 繼承了 DatasetSource。dataset_ops.DatasetSource 繼承 DatasetV2(就是data.Dataset)。

class DatasetSource(DatasetV2):
  """Abstract class representing a dataset with no inputs."""
 
@tf_export("data.Dataset", v1=[])
@six.add_metaclass(abc.ABCMeta)
class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
                composite_tensor.CompositeTensor):

具體 _RemoteDataset 如下,其利用with ops.device(device)把數據集設定到遠端設備上。

class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset on a given device given a graph def."""

  def __init__(self, graph_def, device, element_spec):
    self._elem_spec = element_spec
    with ops.device(device): # 這裏會把數據集設定到遠端設備上
      variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec

_RebatchDataset 代表重新分批,具體使用參見如下:

  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[2, 1, 1])
  for elem in ds:
    print(elem)
  >> [0, 1], [2], [3], [4, 5], [6], [7]

  ds = tf.data.Dataset.range(16)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[6])
  for elem in ds:
    print(elem)
  >> [0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [12, 13, 14, 15]

具體代碼如下:

class _RebatchDataset(dataset_ops.UnaryDataset):
  """A Dataset that rebatches elements from its input into new batch sizes.

  _RebatchDataset(input_dataset, batch_sizes) is functionally equivalent to
  input_dataset.unbatch().batch(N), where the value of N cycles through the
  batch_sizes input list. The elements produced by this dataset have the same
  rank as the elements of the input dataset.

  """

  def __init__(self, input_dataset, batch_sizes, drop_remainder=False):
    """Creates a _RebatchDataset.

    Args:
      input_dataset: Dataset to rebatch.
      batch_sizes: A tf.int64 scalar or vector, representing the size of
        batches to produce. If this argument is a vector, these values are
        cycled through in order.
      drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing
        whether the last batch should be dropped in the case it has fewer than
        batch_sizes[cycle_index] elements; the default behavior is not to drop
        the smaller batch.
    """
    self._input_dataset = input_dataset
    self._batch_sizes = ops.convert_to_tensor(
        batch_sizes, dtype=dtypes.int64, name="batch_sizes")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    new_batch_dim = self._compute_static_batch_dim()

    self._element_spec = nest.map_structure(
        lambda ts: ts._unbatch()._batch(new_batch_dim),
        dataset_ops.get_structure(input_dataset))

    # auto_shard rewrite assumes that there's normalize_to_dense before
    # rebatch_dataset.
    # LINT.IfChange
    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset_v2(
        input_dataset._variant_tensor,  
        batch_sizes=batch_sizes,
        drop_remainder=drop_remainder,
        **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)

  def _compute_static_batch_dim(self):
    """Computes the static batch dimension of a dataset if it can be determined.

    Given the _RebatchDataset parameters, determines the batch dimension of this
    dataset statically. Returns None if this cannot be determined or is
    variable.

    Returns:
      An integer representing the batch dimension of the dataset. If it cannot
      be determined statically, returns None.

    Raises:
      ValueError: The batch_sizes parameter is malformed, input_dataset is
      not batched, or input_dataset batch sizes are incompatible with each
      other.
    """
    new_batch_dim = tensor_util.constant_value(self._batch_sizes)
    if new_batch_dim is None:
      return None

    if isinstance(new_batch_dim, np.ndarray):
      if len(new_batch_dim.shape) == 1:
        if np.all(new_batch_dim == new_batch_dim[0]):
          new_batch_dim = new_batch_dim[0]
        else:
          return None
      elif len(new_batch_dim.shape) > 1:
        raise ValueError(
            f"Invalid batch_sizes. Expected batch_sizes to be a scalar or "
            f"a vector. Received batch_sizes of rank "
            f"{len(new_batch_dim.shape)}."
        )

    if self._may_form_partial_batches(new_batch_dim):
      return None

    return new_batch_dim

  def _may_form_partial_batches(self, desired_batch_size):
    """Returns whether this dataset may form partial batches."""
    if tensor_util.constant_value(self._drop_remainder):
      return False

    def get_batch_dim(type_spec):
      shape = type_spec._to_legacy_output_shapes()  
      if not isinstance(shape, tensor_shape.TensorShape):
        return None
      if shape.rank is None:
        return None
      return shape.dims[0].value

    input_batch_dims = [
        get_batch_dim(ts)
        for ts in nest.flatten(dataset_ops.get_structure(self._input_dataset))
    ]
    known_input_batch_dims = [d for d in input_batch_dims if d is not None]

    if not known_input_batch_dims:
      return True

    known_input_batch_dims = np.asarray(known_input_batch_dims)

    return known_input_batch_dims[0] % desired_batch_size != 0

  @property
  def element_spec(self):
    return self._element_spec

_AutoShardDataset 對數據集自動分片。

這個數據集接收了一個現有的數據集,並嘗試自動找出如何在多工作者的情況下使用圖來對數據集進行分片。

  • 如果 AutoShardPolicy 設置爲 FILE,它就會沿着數據集圖向上走,直到找到一個讀者數據集(reader dataset),然後在該節點之前插入一個 ShardDataset op,這樣每個工作者只能看到一些文件。

  • 如果 AutoShardPolicy 設置爲 DATA,它會在輸入流水線的末端,在 terminal PrefetchDataset(如果有)之前,插入一個 ShardDataset 操作。此外,如果輸入管道中有 RebatchDatasetV2,出於正確性考慮,它將被寫入 legacy RebatchDataset,因爲 RebatchDatasetV2 與數據分片不兼容。

  • 如果 AutoShardPolicy 設置爲 AUTO,它將嘗試進行基於文件的分片。如果找不到讀者數據集,它就會退回到進行基於數據的分片。

  • 如果 AutoShardPolicy 被設置爲 OFF,則不進行處理。

class _AutoShardDataset(dataset_ops.UnaryDataset):
  """A Dataset that shards the Dataset automatically.

  This dataset takes in an existing dataset and tries to automatically figure
  out how to shard the dataset in a multi-worker scenario using graph rewrites.

  If the AutoShardPolicy is set to FILE, it walks up the dataset graph until
  it finds a reader dataset, then inserts a ShardDataset op before that node
  so that each worker only sees some files.

  If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the
  end of the input pipeline, before any terminal PrefetchDataset if there is
  one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it
  is written to legacy RebatchDataset for correctness reasons, since
  RebatchDatasetV2 is incompatible with data sharding.

  If the AutoShardPolicy is set to AUTO, it tries to do file-based sharding.
  If it cannot find a reader dataset, it falls back to doing data-based
  sharding.

  If the AutoShardPolicy is set to OFF, it does nothing.

  Attributes:
    num_workers: Total number of workers to shard this dataset across.
    index: The current worker index (out of the total number of workers) this
      dataset is for.
    num_replicas: The total number of replicas across all workers. This is used
      only when sharding by data (either DATA or AUTO) in order to rewrite
      RebatchDatasetV2 to RebatchDataset.

  Raises:
    NotFoundError: If we cannot find a suitable reader dataset to begin
      automatically sharding the dataset.
  """

  def __init__(self, input_dataset, num_workers, index, num_replicas=None):
    self._input_dataset = input_dataset

    self._element_spec = input_dataset.element_spec
    variant_tensor = ged_ops.auto_shard_dataset(
        self._input_dataset._variant_tensor, 
        num_workers=num_workers,
        index=index,
        auto_shard_policy=int(
            input_dataset.options().experimental_distribute.auto_shard_policy),
        num_replicas=num_replicas,
        **self._flat_structure)
    super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec

在 tensorflow\core\grappler\optimizers\data\auto_shard.cc 之中有如下做自動分片的代碼,有興趣的讀者可以自行深入。

Status ApplyAutoShard(const NodeDef& sink_node, int64_t num_workers,
                      int64_t index, AutoShardPolicy policy,
                      int64_t num_replicas, MutableGraphView* graph,
                      AutoShardPolicy* policy_applied) {
  *policy_applied = policy;
  FunctionLibraryDefinition flib(OpRegistry::Global(),
                                 graph->graph()->library());
  switch (policy) {
    case AutoShardPolicy::OFF:
      return Status::OK();
    case AutoShardPolicy::FILE:
      return ShardByFile(sink_node, num_workers, index, &flib, graph);
    case AutoShardPolicy::DATA:
      return ShardByData(sink_node, num_workers, index, num_replicas, graph);
    case AutoShardPolicy::HINT:
      return ShardByHint(sink_node, num_workers, index, num_replicas, graph);
    case AutoShardPolicy::AUTO:
    default:
      Status s = ShardByFile(sink_node, num_workers, index, &flib, graph);
      if (errors::IsNotFound(s)) {
        LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy "
                        "as it failed to apply FILE sharding policy because of "
                        "the following reason: "
                     << s.error_message();
        *policy_applied = AutoShardPolicy::DATA;
        return ShardByData(sink_node, num_workers, index, num_replicas, graph);
      }
      *policy_applied = AutoShardPolicy::FILE;
      return s;
  }
}

具體關係如下,DistributedDataset 成員變量 _cloned_datasets 列表包括了多個 _AutoShardDataset,每個針對一個 Worker。

圖 4 數據集關係

迭代數據

我們接下來看看 DistributedDataset 如何迭代,iter 方法會針對每個 worker 建立一個 iterator,最後統一返回一個 DistributedIterator。

def __iter__(self):

  canonicalize_devices = getattr(self._strategy, "_canonicalize_devices", True)

  # 會針對每個 worker 建立一個 iterator
  worker_iterators = _create_iterators_per_worker(
      self._cloned_datasets,
      self._input_workers,
      enable_legacy_iterators=False,
      options=self._options,
      canonicalize_devices=canonicalize_devices)
    
  # 統一返回一個 DistributedIterator  
  iterator = DistributedIterator(
      self._input_workers,
      worker_iterators,
      self._strategy,
      cardinality=self._cardinality,
      enable_get_next_as_optional=self._enable_get_next_as_optional,
      options=self._options)
  iterator._element_spec = self._element_spec  # pylint: disable=protected-access

  # When async eager is enabled, sometimes the iterator may not finish
  # initialization before passing to a multi device function, add a sync point
  # here to make sure all underlying iterators are initialized.
  if context.executing_eagerly():
    context.async_wait()

  return iterator

_create_iterators_per_worker 爲每個 worker 建立一個 multidevice iterator。

def _create_iterators_per_worker(worker_datasets,
                                 input_workers,
                                 enable_legacy_iterators,
                                 options=None,
                                 canonicalize_devices=False):
  """Create a multidevice iterator on each of the workers."""
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      if tf2.enabled() and not enable_legacy_iterators:
        iterator = _SingleWorkerOwnedDatasetIterator(
            dataset=worker_datasets[i],
            worker=worker,
            devices=worker_devices,
            options=options,
            canonicalize_devices=canonicalize_devices)
      else:
        iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                                worker_devices, options)
      iterators.append(iterator)
  return iterators

_SingleWorkerDatasetIterator 則會建立 MultiDeviceIterator。

class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
  """Iterator for a single DistributedDatasetV1 instance."""

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      if self._options is not None:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
            max_buffer_size=self._options.experimental_per_replica_buffer_size,
            prefetch_buffer_size=self._options
            .experimental_per_replica_buffer_size)
      else:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
        )

  def initialize(self):
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)

具體邏輯如下:

圖 5 獲取迭代器

2.1.7 DistributedIterator

我們接下來看 DistributedIterator。

DistributedIterator

DistributedIterator 其實沒有完成多少實際工作,主要功能是在基類 DistributedIteratorBase。

class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Input Iterator for a distributed dataset."""

  def __init__(self,
               input_workers=None,
               iterators=None,
               strategy=None,
               components=None,
               element_spec=None,
               cardinality=cardinality_lib.UNKNOWN,
               enable_get_next_as_optional=False,
               options=None):

    error_message = ("Either input_workers or "
                     "both components and element_spec need to be "
                     "provided.")
    self._options = options

    if iterators is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._input_workers = input_workers
      self._iterators = components
      self._strategy = strategy
      self._cardinality = cardinality
      self._enable_get_next_as_optional = enable_get_next_as_optional
    else:
      if (components is not None and element_spec is not None):
        raise ValueError(error_message)

      super(DistributedIterator,
            self).__init__(input_workers, iterators, strategy, cardinality,
                           enable_get_next_as_optional)

  @property
  def element_spec(self):
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batching handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec

  @property
  def _type_spec(self):
    # Note that we use actual element_spec instead of the rebatched-as-dynamic
    # one to create DistributedIteratorSpec, to be consistent with the
    # underlying iterators' specs.
    return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                   self._strategy,
                                   self._options,
                                   self._cardinality,
                                   self._enable_get_next_as_optional)

DistributedIteratorBase

DistributedIteratorBase 的主要方法和普通迭代器相同。

class DistributedIteratorBase(DistributedIteratorInterface):
  """Common implementation for all input iterators."""

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers, iterators, strategy, cardinality,
               enable_get_next_as_optional):

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy
    self._cardinality = cardinality
    self._enable_get_next_as_optional = enable_get_next_as_optional

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def __iter__(self):
    return self

get_next 完成了獲取數據功能,具體我們關注一下 _create_per_replica,這裏看起來和分佈式最爲相關,具體是:

  • 找到所有 worker 信息。
  • 計算副本數目。
  • 獲取數據,並且重新組合。
def get_next(self, name=None):
  """Returns the next input from the iterator for all replicas."""
  with distribution_strategy_context.enter_or_assert_strategy(
      self._strategy):

  if not self._enable_get_next_as_optional:
    return self._get_next_no_partial_batch_handling(name)

  optional_list = []
  # 找到 worker 信息
  for i, worker in enumerate(self._input_workers.worker_devices):
    with ops.device(worker):
      optional_list.append(self._iterators[i].get_next_as_optional_list())
      
  # 計算副本數目    
  num_replicas_with_values = _calculate_replicas_with_values(
      self._strategy, self._input_workers, optional_list)

  # 獲取數據,並且重新組合
  def _value_or_dummy():
    value_list = _get_value_or_dummy( # 獲取數據
        self._input_workers, optional_list, produce_dummy=True)
    return _create_per_replica(value_list, self._strategy)

  def _eof():
    # Optional.get_value raises InvalidArgumentError when there's no value,
    # so we need to call GetNext to raise EOFError.
    return self._get_next_no_partial_batch_handling()

  return control_flow_ops.cond(
      num_replicas_with_values > 0, _value_or_dummy, _eof, strict=True)

_calculate_replicas_with_values 計算出有數據的副本數目。

def _calculate_replicas_with_values(strategy, input_workers, optional_list):
  """Calcualates the number of replicas that have values.

  Args:
    strategy: the tf.distribute.Strategy.
    input_workers: the InputWorkers.
    optional_list: a list of lists tf.experimental.Optional. The values from
      each compute device grouped by the input device.

  Returns:
    A scalar Tensor.
  """
  worker_has_values = []
  for worker, optionals in zip(input_workers.worker_devices, optional_list):
    with ops.device(worker):
      device_has_values = [
          math_ops.cast(v.has_value(), dtypes.int64) for v in optionals
      ]
      worker_has_values.append(
          math_ops.reduce_sum(device_has_values, keepdims=True))
  client_has_values = math_ops.reduce_sum(worker_has_values, keepdims=True)
  if strategy.extended._in_multi_worker_mode():  
    global_has_values = strategy.reduce(
        reduce_util.ReduceOp.SUM, client_has_values, axis=None)
    return array_ops.reshape(global_has_values, [])
  else:
    return array_ops.reshape(client_has_values, [])

_get_value_or_dummy 獲取具體數據。

def _get_value_or_dummy(input_workers, optional_list, produce_dummy):
  """Returns the value of the optionals or dummy values.

  Args:
    input_workers: the InputWorkers.
    optional_list: a list of lists tf.experimental.Optional. The values from
      each compute device grouped by the input device.
    produce_dummy: a bool. Whether to produce dummy tensors when the optional
      doesn't have a value.

  Returns:
    A flatten list of Tensors.

  """
  value_list = []
  for i, worker in enumerate(input_workers.worker_devices): # 遍歷 worker
    with ops.device(worker):
      devices = input_workers.compute_devices_for_worker(i) # 遍歷 worker 之中的設備
      for j, device in enumerate(devices):
        with ops.device(device):
          if produce_dummy:
            value_list.append( # 累計數據
                control_flow_ops.cond(
                    optional_list[i][j].has_value(),
                    lambda: optional_list[i][j].get_value(),  
                    lambda: _dummy_tensor_fn(optional_list[i][j].element_spec),
                    strict=True,
                ))
          else:
            value_list.append(optional_list[i][j].get_value())
  return value_list

_create_per_replica 完成了具體數據的重新組合。

  • 對於 OneDeviceStrategy 以外的策略,它會創建一個 PerReplica,其類型規格被設置爲數據集的元素規格。這有助於避免對部分批次進行回溯。當多個客戶在不同的時間回溯時,回溯對於多客戶端來說是有問題的,因爲回溯改變了 tf.function 的集合鍵(collective keys),並導致客戶之間的不匹配。
  • 對於單客戶策略,_create_per_replica 只是調用 distribution_utils.regroup()。
def _create_per_replica(value_list, strategy):
  """Creates a PerReplica.

  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
  type spec is set to the element spec of the dataset. This helps avoid
  retracing for partial batches. Retracing is problematic for multi client when
  different client retraces different time, since retracing changes the
  collective keys in the tf.function, and causes mismatches among clients.

  For single client strategies, this simply calls distribute_utils.regroup().

  Args:
    value_list: a list of values, one for each replica.
    strategy: the tf.distribute.Strategy.

  Returns:
    a structure of PerReplica.

  """
  always_wrap = _always_wrap(strategy)
  per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
  return per_replicas

具體邏輯如下:

圖 6 迭代器處理數據

至此,對於讀取數據我們其實已經有了一個比較基礎的分析,其中最主要幾個類之間的邏輯如下:

  • InputWorker 的作用是從輸入 worker 設備到計算設備的 1-to-many mapping,可以認爲 InputWorker 把 worker 綁定到設備之上。
  • DistributedDataset 就是數據集了,其內部有一系列複雜處理,首先把數據集複製到一系列設備上,然後對數據集進行一系列加強,首先是 _RemoteDataset,然後升級到 _AutoShardDataset。
  • DistributedDataset 的 iter 方法會針對每個 worker 建立一個 iterator,最後統一返回一個 DistributedIterator。
  • DistributedIterator 的 get_next 方法完成了獲取數據功能,用 _create_per_replica 來舉例, 具體操作是:
    • 找到所有 worker 信息。
    • 計算副本數目。
    • 獲取數據,並且重新組合。

具體如下圖(只是大致邏輯概念,僅僅爲了更好的說明),數字與下圖之中對應。

  • 1.InputWorkers 提供了worker和設備的映射關係。
  • 2.數據集被分配到各個設備或者說worker之上。
  • 3.每個 worker 建立一個 iterator,最後統一返回一個 DistributedIterator。

2.2 通過方法初始化

如果上述的批量分割和數據集分片邏輯(即,直接讀取數據集邏輯)不能滿足需求,用戶可以使用tf.distribution.Strategy.distribution_datasets_from_function,它不會做任何自動的批量分割或分片。

2.2.1 StrategyBase

我們首先來到 StrategyBase。distribute_datasets_from_function 會分發 tf.data.Dataset,這些實例是通過執行dataset_fn 來創建的。

用戶傳入的參數 dataset_fn 是一個輸入函數,它有 tf.distribution.InputContext 參數,並返回一個 tf.data.Dataset 實例。從 dataset_fn 返回的數據集默認已經按每個副本的批處理量(即全局批處理量除以同步的副本數量)進行分批,也進行了分片處理。

tf.distribution.Strategy.distribution_datasets_from_function 不會對輸入函數返回的 tf.dataset 實例進行批處理或分片。dataset_fn 將在每個 worker 的 CPU 設備上被調用,每次調用都會生成一個數據集,在此調用之中,該 worker 的每個副本都會從數據集獲取一批輸入(例如,如果一個 worke r有兩個副本,每一步都會從 Dataset 中提取兩個批次)。在幾種情況下會使用這個方法。

  • 首先,它允許您指定您自己的批處理和分片邏輯,相比之下,tf.distribution.experimental_distribute_dataset 會爲您做批處理和分片。例如,當 experimental_distribute_dataset 無法對輸入文件進行分片時,可以用這個方法來手動分片(避免 experimental_distribute_dataset 的緩慢回退行爲)。
  • 在無限數據集的情況下,可以通過創建數據集副本來完成分片,這些副本只在隨機種子上有所不同。

dataset_fn 應該接受一個 tf.distribution.InputContext 實例,此實例包括了關於批處理和輸入副本的信息。由dataset_fn 返回的 tf.data.Dataset 應該有一個每副本(per-replica)的批次大小,這與 experimental_distribute_dataset 不同,後者使用全局批次大小。全局批次大小可以通過 input_context.get_per_replica_batch_size 來計算得到。

def distribute_datasets_from_function(self, dataset_fn, options=None):
  """
  Args:
    dataset_fn: A function taking a tf.distribute.InputContext instance and
      returning a tf.data.Dataset.
    options: tf.distribute.InputOptions used to control options on how this
      dataset is distributed.

  Returns:
    A tf.distribute.DistributedDataset.
  """
  distribution_strategy_input_api_counter.get_cell(
      self.__class__.__name__,
      "distribute_datasets_from_function").increase_by(1)

  return self._extended._distribute_datasets_from_function(  
      dataset_fn, options)

2.2.2 MirroredStrategy

我們依然用MirroredStrategy作爲例子來看。_distribute_datasets_from_function 這裏會初始化 Input worker,然後配置上下文,讀取數據。

def _distribute_datasets_from_function(self, dataset_fn, options):
  input_workers = self._input_workers_with_options(options) # 構建 InputWorkers
  input_contexts = []
  num_workers = input_workers.num_workers
  for i in range(num_workers):
    input_contexts.append(distribute_lib.InputContext(
        num_input_pipelines=num_workers,
        input_pipeline_id=i,
        num_replicas_in_sync=self._num_replicas_in_sync))

  return input_lib.get_distributed_datasets_from_function(
      dataset_fn, input_workers, input_contexts, self._container_strategy(),
      options)

2.2.3 建立 InputWorkers

_input_workers_with_options 建立了 InputWorkers。

def _input_workers_with_options(self, options=None):
  if not options:
    return input_lib.InputWorkers(self._input_workers_devices)
  if (options.experimental_replication_mode ==
      distribute_lib.InputReplicationMode.PER_REPLICA):
    if options.experimental_place_dataset_on_device:
      self._input_workers_devices = (
          tuple(
              (device_util.canonicalize(d, d), (d,)) for d in self._devices))
    else:
      self._input_workers_devices = (
          tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                for d in self._devices))
    return input_lib.InputWorkers(self._input_workers_devices)
  else:
    if not options.experimental_fetch_to_device:
      return input_lib.InputWorkers([
          (host_device, (host_device,) * len(compute_devices))
          for host_device, compute_devices in self._input_workers_devices
      ])
    else:
      return input_lib.InputWorkers(self._input_workers_devices)

2.2.4 input_contexts

input_contexts 是一個包裝輸入函數所需信息的類,是一個傳遞給用戶輸入函數的上下文類,包含了關於計算副本和輸入流水線的信息。

  • 利用計算副本的數量(同步訓練中)可以讓我們從每個副本所需的全局批次大小中計算出本地批次大小。

  • 利用輸入流水線的信息則可以用來在每個副本中返回不同的輸入子集(例如,分片輸入流水線,使用不同的 input 源等)。

@tf_export("distribute.InputContext")
class InputContext(object):
  """A class wrapping information needed by an input function.

  This is a context class that is passed to the user's input function and
  contains information about the compute replicas and input pipelines. The
  number of compute replicas (in sync training) helps compute the local batch
  size from the desired global batch size for each replica. The input pipeline
  information can be used to return a different subset of the input in each
  replica (for e.g. shard the input pipeline, use a different input
  source etc).
  """

  __slots__ = [
      "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
  ]

  def __init__(self,
               num_input_pipelines=1,
               input_pipeline_id=0,
               num_replicas_in_sync=1):
    """Initializes an InputContext object.

    Args:
      num_input_pipelines: the number of input pipelines in a cluster.
      input_pipeline_id: the current input pipeline id, should be an int in
        [0,num_input_pipelines).
      num_replicas_in_sync: the number of replicas that are in sync.
    """
    self._num_input_pipelines = num_input_pipelines
    self._input_pipeline_id = input_pipeline_id
    self._num_replicas_in_sync = num_replicas_in_sync

  @property
  def num_replicas_in_sync(self):
    """Returns the number of compute replicas in sync."""
    return self._num_replicas_in_sync

  @property
  def input_pipeline_id(self):
    """Returns the input pipeline ID."""
    return self._input_pipeline_id

  @property
  def num_input_pipelines(self):
    """Returns the number of input pipelines."""
    return self._num_input_pipelines

  def get_per_replica_batch_size(self, global_batch_size):
    """Returns the per-replica batch size.

    Args:
      global_batch_size: the global batch size which should be divisible by
        num_replicas_in_sync.

    Returns:
      the per-replica batch size.

    Raises:
      ValueError: if global_batch_size not divisible by
        num_replicas_in_sync.
    """
    if global_batch_size % self._num_replicas_in_sync != 0:
      raise ValueError("The global_batch_size %r is not divisible by "
                       "num_replicas_in_sync %r " %
                       (global_batch_size, self._num_replicas_in_sync))
    return global_batch_size // self._num_replicas_in_sync

  def __str__(self):
    return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
        self.input_pipeline_id, self.num_input_pipelines)

2.2.5 返回數據集

get_distributed_datasets_from_function 從給定的輸入函數返回一個分佈式數據集。這是一個通用函數,所有策略都使用它來返回分佈式數據集。取決於在 TF 1 還是 TF 2 的背景下而返回不同的分佈式數據集實例,從而分佈式數據集實例的 API 也有所不同。

def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None):
  """Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
  instances returned differ from each other in the APIs supported by each of
  them.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    input_contexts: A list of InputContext instances to be passed to call(s)
        to dataset_fn. Length and order should match worker order in
        worker_device_pairs.
    strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
    options: Default is None. tf.distribute.InputOptions used to control
        options on how this dataset is distributed.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if options.experimental_replication_mode and
    options.experimental_place_dataset_on_device are not consistent
  """

  if tf2.enabled():
    return DistributedDatasetsFromFunction(input_workers, strategy,
                                           input_contexts, dataset_fn, options)
  else:
    return DistributedDatasetsFromFunctionV1(input_workers, strategy,
                                             input_contexts, dataset_fn,
                                             options)

2.2.6 構建數據集

DistributedDatasetsFromFunction 會調用 _create_datasets_from_function_with_input_context。

class DistributedDatasetsFromFunction(_IterableInput,
                                      composite_tensor.CompositeTensor):
  """Inputs created from dataset function."""

  def __init__(self,
               input_workers,
               strategy,
               input_contexts=None,
               dataset_fn=None,
               options=None,
               components=None,
               element_spec=None):
    """Makes an iterable from datasets created by the given function.

    Args:
      input_workers: an InputWorkers object.
      strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
      input_contexts: A list of InputContext instances to be passed to call(s)
        to dataset_fn. Length and order should match worker order in
        worker_device_pairs.
      dataset_fn: A function that returns a Dataset given an InputContext.
        Either dataset_fn or components should be passed to construct
        DistributedDatasetsFromFunction. Use this when constructing
        DistributedDataset using a function. Use components when constructing
        using DistributedDatasetsFromFunctionSpec.
      options: tf.distribute.InputOptions used to control options on how this
        dataset is distributed.
      components: datasets when DistributedDatasetsFromFunction is constructed
        from DistributedDatasetsFromFunctionSpec. Only one of dataset or
        components should be passed.
      element_spec: element spec for DistributedDataset when constructing from
        DistributedDatasetSpec. This will be used to set the element_spec for
        DistributedDatasetsFromFunctionSpec and verified against element_spec
        from components.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)
    self._input_workers = input_workers
    self._strategy = strategy
    self._options = options

    if dataset_fn is not None:
      self._datasets, element_spec = (
          _create_datasets_from_function_with_input_context(
              input_contexts, self._input_workers, dataset_fn))
      self._element_spec = _create_distributed_tensor_spec(
          self._strategy, element_spec)
    else:
      self._element_spec = element_spec
      self._datasets = components

    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, self._datasets[0])

_create_datasets_from_function_with_input_context 函數會正式構建數據集。

def _create_datasets_from_function_with_input_context(input_contexts,
                                                      input_workers,
                                                      dataset_fn):
  """Create device datasets per worker given a dataset function."""
  datasets = []
  for i, ctx in enumerate(input_contexts): #遍歷上下文
    worker = input_workers.worker_devices[i] # 遍歷 worker
    with ops.device(worker):
      dataset = dataset_fn(ctx) # 獲取數據
      datasets.append(dataset)
  return datasets, dataset.element_spec

具體邏輯如下:

圖 7 通過方法構建數據

2.3 高層使用

2.3.1 Keras

我們首先看看 Keras 之中的使用。在 tensorflow/python/keras/distribute/distributed_training_utils_v1.py 之中有如下方法,這裏會生成策略的數據迭代器。

def get_iterator(dataset, distribution_strategy):
  with distribution_strategy.scope():
    iterator = distribution_strategy.make_dataset_iterator(dataset)
  initialize_iterator(iterator, distribution_strategy)
  return iterator

tensorflow/python/distribute/distribute_lib.py 則會使用 _extended,比如 StrategyBase 有:

def make_dataset_iterator(self, dataset):
  """DEPRECATED TF 1.x ONLY."""
  return self._extended._make_dataset_iterator(dataset)  

對於 StrategyV1 有:

def make_dataset_iterator(self, dataset):
  """Makes an iterator for input provided via dataset.

  DEPRECATED: This method is not available in TF 2.x.

  Data from the given dataset will be distributed evenly across all the
  compute replicas. We will assume that the input dataset is batched by the
  global batch size. With this assumption, we will make a best effort to
  divide each batch across all the replicas (one or more workers).
  If this effort fails, an error will be thrown, and the user should instead
  use make_input_fn_iterator which provides more control to the user, and
  does not try to divide a batch across replicas.

  The user could also use make_input_fn_iterator if they want to
  customize which input is fed to which replica/worker etc.

  Args:
    dataset: tf.data.Dataset that will be distributed evenly across all
      replicas.

  Returns:
    An tf.distribute.InputIterator which returns inputs for each step of the
    computation.  User should call initialize on the returned iterator.
  """
  return self._extended._make_dataset_iterator(dataset) 

來到 tensorflow/python/distribute/mirrored_strategy.py,則有如下代碼生成 DatasetIterator:

def _make_dataset_iterator(self, dataset):
  return input_lib.DatasetIterator(
      dataset,
      self._input_workers,
      self._container_strategy(),
      num_replicas_in_sync=self._num_replicas_in_sync)

具體邏輯如下:

圖 9 Keras 使用

2.2.2 其他路徑

另一條執行路徑如下:

def single_loss_example(optimizer_fn, distribution, use_bias=False,
                        iterations_per_step=1):
  """Build a very simple network to use in tests and examples."""

  def dataset_fn():
    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()

  optimizer = optimizer_fn()
  layer = core.Dense(1, use_bias=use_bias)

  def loss_fn(ctx, x):
    del ctx
    y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
    return y * y

  single_loss_step = step_fn.StandardSingleLossStep(
      dataset_fn, loss_fn, optimizer, distribution, iterations_per_step)

StandardSingleLossStep 調用如下:

class StandardSingleLossStep(StandardInputStep):
  """A step function that implements a training step for a feed forward network.

  An instance of this class is intended to be used as a callable:

  ```python
  ...
  step = step_fn.StandardSingleLossStep(
      dataset, loss_fn, optimizer, distribution)

  # Run a single training step on a given DistributionStrategy:
  step(distribution)
  ...
  ```

  Args:
    dataset_fn: a function that returns a tf.data Dataset that produces the
      input for the model.
    loss_fn: a function that takes a context and inputs as arguments. It returns
      the loss for those inputs. context is an instance of
      values.MultiStepContext that will be passed when loss_fn is run.
      context can be used to specify the outputs to be returned from
      loss_fn, among other things.
    optimizer: an optimizer that implements an update rule.
    distribution: a DistributionStrategy object.
  """

  def __init__(self, dataset_fn, loss_fn, optimizer, distribution,
               iterations_per_step=1):
    super(StandardSingleLossStep, self).__init__(dataset_fn, distribution)
    self._loss_fn = loss_fn
    self._optimizer = optimizer
    self._iterations_per_step = iterations_per_step

  def __call__(self):
    with self._distribution.scope():
      def step_fn(ctx, inputs):
        """Function to run one iteration with one input."""
        gradients_fn = backprop.implicit_grad(self._loss_fn)
        gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)

        grads_and_vars = self.distribution.extended.call_for_each_replica(
            gradients_fn, args=(ctx, inputs))
        # If threads use layers, then we need to run the first step
        # sequentially, so that layers.build() is not executed in parallel.
        # Otherwise, multiple sets of mirrored variables are going to be
        # created.
        return self._optimizer._distributed_apply(  # pylint: disable=protected-access
            self.distribution, grads_and_vars)

      ctx = self.distribution.extended.experimental_run_steps_on_iterator(
          step_fn, self._iterator, self._iterations_per_step)
      return ctx.run_op

StandardInputStep 這裏生成了 _iterator。

class StandardInputStep(Step):
  """Step with a standard implementation of input handling.

  Args:
    dataset_fn: a function that returns a tf.data Dataset that produces the
      input for the model.
  """

  def __init__(self, dataset_fn, distribution):
    super(StandardInputStep, self).__init__(distribution)
    self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())

  def initialize(self):
    return self._iterator.initializer

StrategyV1之中有:

def make_input_fn_iterator(self,  # pylint: disable=useless-super-delegation
                           input_fn,
                           replication_mode=InputReplicationMode.PER_WORKER):
  """Returns an iterator split across replicas created from an input function.

  DEPRECATED: This method is not available in TF 2.x.

  The input_fn should take an tf.distribute.InputContext object where
  information about batching and input sharding can be accessed:

  ```
  def input_fn(input_context):
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
    return d.shard(input_context.num_input_pipelines,
                   input_context.input_pipeline_id)
  with strategy.scope():
    iterator = strategy.make_input_fn_iterator(input_fn)
    replica_results = strategy.experimental_run(replica_fn, iterator)
  ```

  The tf.data.Dataset returned by input_fn should have a per-replica
  batch size, which may be computed using
  input_context.get_per_replica_batch_size.

  Args:
    input_fn: A function taking a tf.distribute.InputContext object and
      returning a tf.data.Dataset.
    replication_mode: an enum value of tf.distribute.InputReplicationMode.
      Only PER_WORKER is supported currently, which means there will be
      a single call to input_fn per worker. Replicas will dequeue from the
      local tf.data.Dataset on their worker.

  Returns:
    An iterator object that should first be .initialize()-ed. It may then
    either be passed to strategy.experimental_run() or you can
    iterator.get_next() to get the next value to pass to
    strategy.extended.call_for_each_replica().
  """
  return super(StrategyV1, self).make_input_fn_iterator(
      input_fn, replication_mode)

StrategyBase 之中有:

@doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
def make_input_fn_iterator(self,
                           input_fn,
                           replication_mode=InputReplicationMode.PER_WORKER):
  """DEPRECATED TF 1.x ONLY."""
  if replication_mode != InputReplicationMode.PER_WORKER:
    raise ValueError(
        "Input replication mode not supported: %r" % replication_mode)
  with self.scope():
    return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
        input_fn, replication_mode=replication_mode)

最終來到 MirroredStrategy,生成了 InputFunctionIterator,其調用關係如下:

class InputFunctionIterator(DistributedIteratorV1)
class DistributedIteratorV1(DistributedIteratorBase)

具體代碼如下:

def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  input_contexts = []
  num_workers = self._input_workers.num_workers
  for i in range(num_workers):
    input_contexts.append(distribute_lib.InputContext(
        num_input_pipelines=num_workers,
        input_pipeline_id=i,
        num_replicas_in_sync=self._num_replicas_in_sync))
  return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                         input_contexts,
                                         self._container_strategy())

邏輯如下:

圖 10 使用示例

0xFF 參考

tensorflow源碼解析之distributed_runtime

TensorFlow分佈式訓練

TensorFlow內核剖析

源代碼

Tensorflow分佈式原理理解

TensorFlow架構與設計:概述

Tensorflow 跨設備通信

TensorFlow 篇 | TensorFlow 2.x 分佈式訓練概覽

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章