[Source Code Analysis] TensorFlow Distributed: ParameterServerStrategy V1

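This article looks at ParameterServerStrategy, specifically the first-version (V1) code. It is worth studying because many companies in industry still use it, and because its internals are relatively clear and easy to follow.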

Two GitHub repositories that are excellent study material and well worth recommending:

https://github.com/yuhuiaws/ML-study

https://github.com/Jack47/hack-SysML

Also recommended is the latest article by 西門宇少 on TeraPipe, which parallelizes the pipeline at the token level on Transformer LMs.

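Other articles in this series:

[Translation] TensorFlow Distributed, paper: "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed, paper: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1): Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2): Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3): Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4): WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5): Session

[Source Code Analysis] TensorFlow Distributed Environment (7): Worker Dynamic Logic

[Source Code Analysis] TensorFlow Distributed Environment (8): Communication Mechanism

[Translation] Distributed Training with TensorFlow

[Source Code Analysis] TensorFlow Distributed DistributedStrategy: Basics

[Source Code Analysis] TensorFlow: Distributed Variables

[Source Code Analysis] TensorFlow Distributed: MirroredStrategy

[Source Code Analysis] TensorFlow Distributed: MirroredStrategy, Dispatching Computation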

1. Approach

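Parameter server training is a common data-parallel method for scaling machine learning models across multiple machines. A parameter server training cluster consists of workers and parameter servers. Variables are created on the parameter servers, and workers read and update them at every step. By default, workers read and update these variables independently, without synchronizing with one another, which is why this configuration is called asynchronous training.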
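
To make "asynchronous" concrete, here is a toy simulation in plain Python with no TensorFlow involved. Everything in it (ToyParameterServer, run_async, the data) is hypothetical and only illustrates the idea: workers read the shared parameter, compute a gradient from their own possibly stale copy, and push the update without waiting for each other.

```python
# Toy simulation of asynchronous parameter-server training (no TensorFlow).
# All names here are hypothetical and exist only for this illustration.


class ToyParameterServer:
    """Holds a single scalar parameter w and applies gradient updates."""

    def __init__(self, w=0.0, lr=0.1):
        self.w = w
        self.lr = lr
        self.version = 0  # number of updates applied so far

    def read(self):
        return self.w, self.version

    def apply_gradient(self, grad):
        self.w -= self.lr * grad
        self.version += 1


def gradient(w, x, y):
    """Gradient of the squared error 0.5 * (w * x - y) ** 2 with respect to w."""
    return (w * x - y) * x


def run_async(ps, shards, steps):
    """Workers read first, then push updates computed from their own copy."""
    staleness = []
    for _ in range(steps):
        # Every worker reads the parameter (all see the same version here) ...
        reads = [ps.read() for _ in shards]
        # ... then each pushes a gradient computed from the value it read,
        # even if another worker has already changed the parameter.
        for (w_seen, version_seen), (x, y) in zip(reads, shards):
            ps.apply_gradient(gradient(w_seen, x, y))
            staleness.append(ps.version - 1 - version_seen)
    return staleness


ps = ToyParameterServer()
shards = [(1.0, 2.0), (2.0, 4.0)]  # each worker holds one sample of y = 2 * x
staleness = run_async(ps, shards, steps=50)
print(ps.w, max(staleness))
```

The second worker always applies a gradient computed from a value that is one update old. Real asynchronous training tolerates this kind of staleness in exchange for not having to synchronize workers.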

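TensorFlow supports two ways of implementing a parameter server: building a parameter server cluster with the low-level API, and the ParameterServerStrategy in tf.distribute.Strategy. The main job of ParameterServerStrategyV1 is to place variables on the ps tasks and computation on the workers. We will study it from the following angles:

  • How it connects to the cluster.
  • How it obtains data.
  • How it creates variables.
  • How it runs.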

1.1 Overall logic

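ParameterServerStrategyV1 is an asynchronous, multi-worker parameter server tf.distribute strategy. It requires two roles: workers and parameter servers. Variables and the updates to those variables are assigned to parameter servers, and all other operations are assigned to workers.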

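When each worker has more than one GPU, operations are replicated across all GPUs but variables are not; every worker shares a common view of which parameter server a given variable is assigned to. By default, ParameterServerStrategyV1 uses TFConfigClusterResolver to discover the multi-worker configuration. This requires a 'TF_CONFIG' environment variable, and 'TF_CONFIG' must contain a cluster spec.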
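
As an illustration of what such a 'TF_CONFIG' value looks like, the sketch below builds one with the standard library and reads it back. The host names and ports are made up; the "cluster" and "task" keys follow the usual TF_CONFIG layout, and the last few lines only mimic what a resolver conceptually extracts.

```python
import json
import os

# Hypothetical host names and ports, for illustration only.
tf_config = {
    "cluster": {
        "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
        "ps": ["ps0.example.com:2222"],
    },
    # This process is worker 1; every process sets its own "task" entry.
    "task": {"type": "worker", "index": 1},
}
os.environ["TF_CONFIG"] = json.dumps(tf_config)

# What a cluster resolver conceptually reads back from the environment.
parsed = json.loads(os.environ["TF_CONFIG"])
task_type = parsed["task"]["type"]
task_id = parsed["task"]["index"]
num_ps = len(parsed["cluster"].get("ps", []))
print(task_type, task_id, num_ps)
```

Every process in the cluster shares the same "cluster" section and differs only in its "task" section.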

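The class assumes that every worker independently runs the same code, while the parameter servers run a standard server. This means that although each worker synchronously computes a single gradient update across all of its GPUs, updates between workers proceed asynchronously. Even with only a CPU or a single GPU, "call_for_each_replica(fn, ...)" should still be called for any operation that could potentially be replicated across replicas (i.e. multiple GPUs). When defining "fn", keep the following in mind: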

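  1. Opening a device scope inside the strategy's scope is generally not recommended. A device scope (i.e. calling tf.device) will either merge with or override the device of operations, but it will not change the device of variables.
  2. Opening a colocation scope (i.e. calling tf.compat.v1.colocate_with) inside the strategy's scope is also not recommended. To colocate variables, use strategy.extended.colocate_vars_with instead. Colocating operations may create device assignment conflicts.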

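Note: this strategy works only with the Estimator API. Pass an instance of this strategy when you create the "RunConfig" (through its "train_distribute" argument, or through the "train_distribute" field of the config object given to "experimental_distribute"). That "RunConfig" instance should then be passed to the "Estimator" instance, on which you call "train_and_evaluate".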

1.2 Usage

An example of using ParameterServerStrategy:

  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(
      train_distribute=strategy)  # keyword for the training strategy
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator,...)

1.3 Definition

There is little to study in the definition and initialization of ParameterServerStrategyV1; it mainly delegates initialization to ParameterServerStrategyExtended. An excerpt:

@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional cluster_resolver.

    Args:
      cluster_resolver: Optional
        tf.distribute.cluster_resolver.ClusterResolver object. Defaults to a
        tf.distribute.cluster_resolver.TFConfigClusterResolver.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
      "ParameterServerStrategy")    

2. ParameterServerStrategyExtended

ParameterServerStrategyExtended derives from distribute_lib.StrategyExtendedV1 and provides additional APIs for distribution-aware algorithms.

class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

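2.1 Initialization

This part gathers the cluster information. Depending on the cluster spec, _initialize_strategy chooses between local and multi-worker initialization; we only study the multi-worker case.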

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

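_initialize_multi_worker performs a series of configuration steps:

  • Get the number of GPUs.
  • Read information from the cluster configuration.
  • Set the worker device and input device names.
  • Build the list of compute devices.
  • Set the device assignment policy for variables.
  • Build the list of parameter server devices.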

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of ClusterResolver object.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    # Get the number of GPUs.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    # Read information from the cluster configuration.
    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    # Set the worker device and input device names.
    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices which is a list of device strings and one for each
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    
    # Build the list of compute devices.
    if num_gpus > 0:
      compute_devices = tuple(
        "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from replica_device_setter are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.

    # Device assignment policy: decides which device each variable is placed on.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,  # number of parameter server tasks
        worker_device=self._worker_device,  # worker device
        merge_devices=True,
        cluster=cluster_spec)

    # The _parameter_devices is needed for the parameter_devices property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.

    # Build the list of parameter server devices.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = self._worker_device
    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
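
The device bookkeeping in _initialize_multi_worker can be mimicked in plain Python. The function below is a hypothetical sketch (sketch_devices is not a TensorFlow API): it only reproduces the string formatting seen above and leaves out device canonicalization and the variable device function.

```python
def sketch_devices(cluster, task_type, task_id, num_gpus):
    """Rebuild the device strings the strategy derives for one task.

    `cluster` is a plain dict such as {"worker": [...], "ps": [...]}.
    """
    worker_device = "/job:%s/task:%d" % (task_type, task_id)

    # One compute device per replica: each GPU if there are any, else the worker.
    if num_gpus > 0:
        compute_devices = tuple(
            "%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus))
    else:
        compute_devices = (worker_device,)

    # Every task of the "ps" job is a parameter device.
    num_ps = len(cluster.get("ps", []))
    parameter_devices = tuple(
        "/job:ps/task:{}".format(i) for i in range(num_ps))
    return worker_device, compute_devices, parameter_devices


# Hypothetical two-worker, two-ps cluster.
cluster = {"worker": ["w0:2222", "w1:2222"], "ps": ["p0:2222", "p1:2222"]}
worker_device, compute_devices, parameter_devices = sketch_devices(
    cluster, "worker", 1, num_gpus=2)
print(worker_device)
print(compute_devices)
print(parameter_devices)
```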

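2.2 Device assignment

Next we look at how devices are assigned. At this stage, assigning a device just means attaching a device name to each operation in the computation graph; when the graph actually runs, the system performs the concrete placement based on that name.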

2.2.1 replica_device_setter

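replica_device_setter returns a device function (in effect a placement policy). While the computation graph for a replica is being built, this function supplies the information that determines which device each operation is assigned to. A device function is used with `with tf.device(device_function)`: as Operation objects are constructed, they are automatically mapped to the device the function returns. Device constraints are added from the innermost context first, working outwards. If 'cluster' is 'None' and 'ps_tasks' is 0, the returned function is a no-op (the code returns None); otherwise the value of 'ps_tasks' is derived from 'cluster'. If 'ps_tasks' is non-zero, variables are placed on ps_device; otherwise they are placed on worker_device.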

@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0,
                          ps_device="/job:ps",
                          worker_device="/job:worker",
                          merge_devices=True,
                          cluster=None,
                          ps_ops=None,
                          ps_strategy=None):
  """Return a device function to use when building a Graph for replicas.

  Device Functions are used in with tf.device(device_function): statement to
  automatically assign devices to Operation objects as they are constructed,
  Device constraints are added from the inner-most context first, working
  outwards. The merging behavior adds constraints to fields that are yet unset
  by a more inner context. Currently the fields are (job, task, cpu/gpu).

  If cluster is None, and ps_tasks is 0, the returned function is a no-op.
  Otherwise, the value of ps_tasks is derived from cluster.

  Args:
    ps_tasks: Number of tasks in the ps job.  Ignored if cluster is
      provided.
    ps_device: String.  Device of the ps job.  If empty no ps job is used.
      Defaults to ps.
    worker_device: String.  Device of the worker job.  If empty no worker
      job is used.
    merge_devices: Boolean. If True, merges or only sets a device if the
      device constraint is completely unset. merges device specification rather
      than overriding them.
    cluster: ClusterDef proto or ClusterSpec.
    ps_ops: List of strings representing Operation types that need to be
      placed on ps devices.  If None, defaults to STANDARD_PS_OPS.
    ps_strategy: A callable invoked for every ps Operation (i.e. matched by
      ps_ops), that takes the Operation and returns the ps task index to
      use.  If None, defaults to a round-robin strategy across all ps
      devices.

  Returns:
    A function to pass to tf.device().

  Raises:
    TypeError if cluster is not a dictionary or ClusterDef protocol buffer,
    or if ps_strategy is provided but not a callable.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_dict()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
    # Get ps_job_name from ps_device by stripping "/job:".
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None

  if ps_ops is None:
    ps_ops = list(STANDARD_PS_OPS)

  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)

  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
                                  merge_devices, ps_ops, ps_strategy)
  return chooser.device_function

2.2.2 _RoundRobinStrategy

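By default only variable ops are placed on ps tasks, and the placement strategy distributes them across the ps tasks in round-robin order. Other strategies can be used instead, for example tf.contrib.training.GreedyLoadBalancingStrategy.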

# To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
# jobs on hosts worker0, worker1 and worker2.
cluster_spec = {
    "ps": ["ps0:2222", "ps1:2222"],
    "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
with tf.device(
    tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
  # Build your graph
  v1 = tf.Variable(...)  # assigned to /job:ps/task:0
  v2 = tf.Variable(...)  # assigned to /job:ps/task:1
  v3 = tf.Variable(...)  # assigned to /job:ps/task:0
# Run compute

_RoundRobinStrategy is defined as follows:

class _RoundRobinStrategy(object):
  """Returns the next ps task index for placement in round-robin order.

  This class is not to be used directly by users.  See instead
  replica_device_setter() below.
  """

  def __init__(self, num_tasks):
    """Create a new _RoundRobinStrategy.

    Args:
      num_tasks: Number of ps tasks to cycle among.
    """
    self._num_tasks = num_tasks
    self._next_task = 0

  def __call__(self, unused_op):
    """Choose a ps task index for the given Operation.

    Args:
      unused_op: An Operation to be placed on ps.

    Returns:
      The next ps task index to use for the Operation. Returns the next
      index, in the range [offset, offset + num_tasks).
    """
    task = self._next_task
    self._next_task = (self._next_task + 1) % self._num_tasks
    return task
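
To see why a size-aware strategy can matter, the sketch below compares round-robin placement with a simple greedy policy that always picks the least-loaded ps task. Both classes are hypothetical stand-ins written for this comparison; the greedy one is only in the spirit of a load-balancing strategy and is not the implementation of GreedyLoadBalancingStrategy. The variable sizes are made up.

```python
class RoundRobinSketch:
    """Cycles through ps task indices, ignoring the op entirely."""

    def __init__(self, num_tasks):
        self._num_tasks = num_tasks
        self._next_task = 0

    def __call__(self, unused_op):
        task = self._next_task
        self._next_task = (self._next_task + 1) % self._num_tasks
        return task


class GreedyBySizeSketch:
    """Picks the ps task with the smallest accumulated load so far."""

    def __init__(self, num_tasks, size_fn):
        self._loads = [0] * num_tasks
        self._size_fn = size_fn

    def __call__(self, op):
        # min() returns the first minimum, so ties go to the lowest index.
        task = min(range(len(self._loads)), key=self._loads.__getitem__)
        self._loads[task] += self._size_fn(op)
        return task


def total_load_per_task(strategy, variables, num_tasks):
    """Place (name, size) pairs with `strategy` and sum the sizes per task."""
    loads = [0] * num_tasks
    for variable in variables:
        loads[strategy(variable)] += variable[1]
    return loads


# Hypothetical variables: two large embeddings interleaved with two biases.
variables = [("emb_a", 100), ("bias_a", 1), ("emb_b", 100), ("bias_b", 1)]
rr_loads = total_load_per_task(RoundRobinSketch(2), variables, 2)
greedy_loads = total_load_per_task(
    GreedyBySizeSketch(2, size_fn=lambda v: v[1]), variables, 2)
print(rr_loads, greedy_loads)
```

With this ordering round-robin puts both large variables on the same task, while the greedy policy balances the two tasks.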

2.2.3 _ReplicaDeviceChooser

replica_device_setter returns _ReplicaDeviceChooser.device_function, which uses _ps_strategy to produce the device name. Based on _ps_tasks, it decides whether a variable goes on ps_device or on worker_device.

class _ReplicaDeviceChooser(object):
  """Class to choose devices for Ops in a replicated training setup.

  This class is not to be used directly by users.  See instead
  replica_device_setter() below.
  """

  def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
               ps_strategy):
    """Create a new _ReplicaDeviceChooser.

    Args:
      ps_tasks: Number of tasks in the ps job.
      ps_device: String.  Name of the ps job.
      worker_device: String.  Name of the worker job.
      merge_devices: Boolean. Set to True to allow merging of device specs.
      ps_ops: List of strings representing Operation types that need to be
        placed on ps devices.
      ps_strategy: A callable invoked for every ps Operation (i.e. matched by
        ps_ops), that takes the Operation and returns the ps task index to
        use.
    """
    self._ps_tasks = ps_tasks
    self._ps_device = ps_device
    self._worker_device = worker_device
    self._merge_devices = merge_devices
    self._ps_ops = ps_ops
    self._ps_strategy = ps_strategy

  def device_function(self, op):
    """Choose a device for op.

    Args:
      op: an Operation.

    Returns:
      The device to use for the Operation.
    """
    # If we don't return early here, either merge_devices is True, or op.device
    # is empty (in which case merging is a no-op). So we can always merge below.
    if not self._merge_devices and op.device:
      return op.device

    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # The ps_device will be used for specified ops (ps_ops) whenever it is
    # present and ps_tasks is non-zero. However, its task number will only be
    # set (using ps_strategy) if there is a job field in ps_device that won't be
    # changed by the job field (if present) in current_device.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
      ps_device = pydev.DeviceSpec.from_string(self._ps_device)

      current_job, ps_job = current_device.job, ps_device.job
      if ps_job and (not current_job or current_job == ps_job):
        # The placement strategy is applied here.
        ps_device = ps_device.replace(task=self._ps_strategy(op))

      ps_device = ps_device.make_merged_spec(current_device)
      return ps_device.to_string()

    worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
    worker_device = worker_device.make_merged_spec(current_device)
    return worker_device.to_string()
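
Stripped of DeviceSpec merging, the decision made by device_function reduces to: variable-like ops go to a ps task chosen by the strategy, and everything else stays on the worker. The sketch below shows only that core decision. Op, PS_OP_TYPES and make_device_function are hypothetical names, and the op type list is an assumed small subset used for illustration.

```python
import itertools
from collections import namedtuple

# Minimal stand-in for a graph operation: just a name and an op type.
Op = namedtuple("Op", ["name", "type"])

# Assumed subset of op types that should live on parameter servers.
PS_OP_TYPES = ("Variable", "VariableV2", "VarHandleOp")


def make_device_function(ps_tasks, worker_device, ps_op_types=PS_OP_TYPES):
    """Return a function mapping an Op to a device string, or None if no ps."""
    if ps_tasks == 0:
        return None
    next_ps_task = itertools.cycle(range(ps_tasks))  # round-robin over ps tasks

    def device_function(op):
        if op.type in ps_op_types:
            return "/job:ps/task:%d" % next(next_ps_task)
        return worker_device

    return device_function


device_fn = make_device_function(2, "/job:worker/task:0")
ops = [Op("w", "VarHandleOp"), Op("matmul", "MatMul"),
       Op("b", "VarHandleOp"), Op("add", "AddV2")]
placement = {op.name: device_fn(op) for op in ops}
print(placement)
```

The real chooser additionally merges the chosen device with any device already set on the op, which this sketch ignores.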

The device-related logic is summarized below:

Figure 1: Device assignment

After initialization, ParameterServerStrategyExtended looks like this:

3. Data

Next we look at how training data is obtained. distribute_datasets_from_function calls the base class's distribute_datasets_from_function, so we need to look at StrategyBase.

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
        "InputReplicationMode.PER_REPLICA "
        "is only supported in "
        "experimental_distribute_datasets_from_function "
        "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    super(ParameterServerStrategyV1, self).distribute_datasets_from_function(
        dataset_fn=dataset_fn, options=options)

3.1 StrategyBase

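distribute_datasets_from_function distributes tf.data.Dataset instances by calling 'dataset_fn'. The dataset_fn argument supplied by the user is an input function that takes an InputContext argument and returns a tf.data.Dataset instance. The dataset returned by dataset_fn should already be batched by the per-replica batch size (the global batch size divided by the number of replicas in sync) and sharded. tf.distribute.Strategy.distribute_datasets_from_function itself does not batch or shard.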

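dataset_fn is called on the CPU device of each worker and produces a dataset there; every replica on that worker dequeues one batch of input per step (so if a worker has two replicas, two batches are dequeued from the Dataset at every step). This method serves several purposes. First, it lets you specify your own batching and sharding logic. (By contrast, tf.distribute.Strategy.experimental_distribute_dataset batches and shards for you.) For example, when experimental_distribute_dataset cannot shard the input files, this method can be used to shard the dataset manually, avoiding the slow fallback behavior in experimental_distribute_dataset. When the dataset is infinite, sharding can be done by creating dataset replicas that differ only in their random seed. In addition, dataset_fn should use the tf.distribute.InputContext instance to obtain batching and input sharding information.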
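
The batching and sharding arithmetic that dataset_fn is responsible for can be sketched without TensorFlow. InputContextSketch below is a hypothetical stand-in that mirrors the three fields of InputContext; plain lists stand in for a tf.data.Dataset, with slicing playing the role of shard and batch.

```python
from dataclasses import dataclass


@dataclass
class InputContextSketch:
    """Hypothetical stand-in carrying the same three fields as InputContext."""
    num_input_pipelines: int
    input_pipeline_id: int
    num_replicas_in_sync: int

    def get_per_replica_batch_size(self, global_batch_size):
        if global_batch_size % self.num_replicas_in_sync != 0:
            raise ValueError("global batch size must divide evenly")
        return global_batch_size // self.num_replicas_in_sync


def dataset_fn(ctx, records, global_batch_size):
    """Shard `records` for this input pipeline, then batch per replica."""
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    # Keep every num_input_pipelines-th record, starting at this pipeline's id.
    shard = records[ctx.input_pipeline_id::ctx.num_input_pipelines]
    return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


# Second of two input pipelines, four replicas in sync overall.
ctx = InputContextSketch(
    num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=4)
batches = dataset_fn(ctx, list(range(12)), global_batch_size=8)
print(batches)
```

With a real dataset, the same two steps would typically be expressed with the dataset's own shard and batch methods inside dataset_fn.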

It is called like this:

def per_worker_dataset_fn():
    return strategy.distribute_datasets_from_function(
        lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

Here we find that distribute_datasets_from_function comes back to the _distribute_datasets_from_function method of the derived (extended) class.

def distribute_datasets_from_function(self, dataset_fn, options=None):
    return self._extended._distribute_datasets_from_function(dataset_fn, options)    

3.2 _distribute_datasets_from_function

_distribute_datasets_from_function builds an InputContext and uses it to obtain the data.

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options), [input_context],
        self._container_strategy(),
        options=options)
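
The two helper calls above decide which input pipeline this task is and how many pipelines exist. The functions below are a simplified sketch of that logic as I understand it, written against plain dicts; the names with the _sketch suffix are hypothetical, and validation and cluster spec normalization are left out.

```python
def id_in_cluster_sketch(cluster, task_type, task_id):
    """Position of a task among the tasks that feed input (chief first)."""
    if task_type == "chief":
        return 0
    if task_type == "worker":
        # Workers are numbered after the chief, if there is one.
        return task_id + len(cluster.get("chief", []))
    if task_type == "evaluator":
        return task_id
    raise ValueError("There is no id for task_type %r" % task_type)


def worker_count_sketch(cluster, task_type):
    """Number of tasks that share the training input (chief plus workers)."""
    if task_type == "evaluator":
        return len(cluster.get("evaluator", []))
    if task_type in ("chief", "worker"):
        return len(cluster.get("chief", [])) + len(cluster.get("worker", []))
    raise ValueError("Unexpected task_type %r" % task_type)


# Hypothetical cluster with one chief, two workers and one ps.
cluster = {"chief": ["c0:2222"], "worker": ["w0:2222", "w1:2222"],
           "ps": ["p0:2222"]}
input_pipeline_id = id_in_cluster_sketch(cluster, "worker", 1)
num_input_pipelines = worker_count_sketch(cluster, "worker")
print(input_pipeline_id, num_input_pipelines)
```

Parameter server tasks do not read input, which is why the sketch assigns them no id.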

3.3 InputLib

This code lives in tensorflow/python/distribute/input_lib.py; its main job is to obtain the iterator.

def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None):
  """Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
  instances returned differ from each other in the APIs supported by each of
  them.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    input_contexts: A list of InputContext instances to be passed to call(s)
        to dataset_fn. Length and order should match worker order in
        worker_device_pairs.
    strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
    options: Default is None. tf.distribute.InputOptions used to control
        options on how this dataset is distributed.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if options.experimental_replication_mode and
    options.experimental_place_dataset_on_device are not consistent
  """
  if tf2.enabled():
    return DistributedDatasetsFromFunction(input_workers, strategy,
                                           input_contexts, dataset_fn, options)
  else:
    return DistributedDatasetsFromFunctionV1(input_workers, strategy,
                                             input_contexts, dataset_fn,
                                             options)

DistributedDatasetsFromFunctionV1 returns a DistributedIteratorV1. Once we have the iterator, data can be read from the dataset.

class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
  """Inputs created from dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use iter() instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # initialize on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "make_initializable_iterator() instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers, True,
                                             self._options)
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before passing to a multi device function, add a sync point
    # here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

4. Scope and variables

4.1 StrategyBase

scope simply calls the base class method.

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

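StrategyBase's scope method returns a context manager that uses the current strategy to create distributed variables. Entering Strategy.scope has these effects:

  • "strategy" becomes the "current" strategy in the global context. Inside this scope, tf.distribute.get_strategy() returns this strategy; outside it, it returns the default no-op strategy.
  • Entering this scope also enters the "cross-replica context".
  • Variable creation inside "scope" is intercepted by the strategy. Each strategy defines how it wants to affect variable creation. Synchronous strategies such as 'MirroredStrategy', 'TPUStrategy' and 'MultiWorkerMirroredStrategy' create variables replicated on each replica, whereas 'ParameterServerStrategy' creates variables on the parameter servers. This is done with a custom tf.variable_creator_scope.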
  • In some strategies a default device scope may also be entered: in "MultiWorkerMirroredStrategy", a default device scope of "/CPU:0" is entered on each worker.

  def scope(self):
    """Context manager to make the strategy current and distribute variables.

    This method returns a context manager, and is used as follows:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0","GPU:1"])
    >>> # Variable created inside scope:
    >>> with strategy.scope():
    ...   mirrored_variable = tf.Variable(1.)
    >>> mirrored_variable
    MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
    }
    >>> # Variable created outside scope:
    >>> regular_variable = tf.Variable(1.)
    >>> regular_variable
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

    Returns:
      A context manager.
    """
    return self._extended._scope(self)  
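
The "current strategy" behavior described above is essentially a per-thread stack managed by a context manager. The sketch below shows that mechanism in isolation. The names scope, get_strategy and DEFAULT_STRATEGY are hypothetical stand-ins here, and strategies are represented by plain strings.

```python
import contextlib
import threading

_state = threading.local()
DEFAULT_STRATEGY = "default-no-op-strategy"  # placeholder for the default


def get_strategy():
    """Return the innermost strategy whose scope is active, else the default."""
    stack = getattr(_state, "stack", [])
    return stack[-1] if stack else DEFAULT_STRATEGY


@contextlib.contextmanager
def scope(strategy):
    """Make `strategy` current for the duration of the with-block."""
    stack = getattr(_state, "stack", None)
    if stack is None:
        stack = _state.stack = []
    stack.append(strategy)
    try:
        yield strategy
    finally:
        stack.pop()  # restore the previous strategy even if the block raised


before = get_strategy()
with scope("parameter-server-strategy"):
    inside = get_strategy()
after = get_strategy()
print(before, inside, after)
```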

Since this calls into extended, we continue the analysis there.

4.2 StrategyExtendedV2

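_scope configures how variables are created, how they are retrieved, and how the variable scope is obtained. It returns a _CurrentDistributionContext to the user; when the user then creates a variable, creator_with_resource_vars calls the derived strategy's _create_variable to create it.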

  def _scope(self, strategy):
    """Implementation of tf.distribute.Strategy.scope()."""

    def creator_with_resource_vars(next_creator, **kwargs):
      """Variable creator to use in _CurrentDistributionContext."""
      _require_strategy_scope_extended(self)
      kwargs["use_resource"] = True
      kwargs["distribute_strategy"] = strategy

      # Unwrap initial_value if it is a CheckpointInitialValue to avoid
      # dereferencing a Tensor that is without a name. We still need to
      # propagate the metadata it's holding.
      if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
        checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
        kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
      elif isinstance(kwargs["initial_value"],
                      trackable.CheckpointInitialValueCallable):
        checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
      elif (isinstance(kwargs["initial_value"], functools.partial) and
            isinstance(kwargs["initial_value"].func,
                       trackable.CheckpointInitialValueCallable)):
        # Some libraries (e.g, Keras) create partial function out of initializer
        # to bind shape/dtype, for example:
        #  initial_val = functools.partial(initializer, shape, dtype=dtype)
        # Therefore to get the restore_uid we need to examine the "func" of
        # the partial function.
        checkpoint_restore_uid = kwargs[
          "initial_value"].func.checkpoint_position.restore_uid
      else:
        checkpoint_restore_uid = None

      # Call the derived strategy's _create_variable here.
      created = self._create_variable(next_creator, **kwargs)

      if checkpoint_restore_uid is not None:
        # pylint: disable=protected-access
        # Let the checkpointing infrastructure know that the variable was
        # already restored so it doesn't waste memory loading the value again.
        # In this case of CheckpointInitialValueCallable this may already be
        # done by the final variable creator, but it doesn't hurt to do it
        # again.
        created._maybe_initialize_trackable()
        created._update_uid = checkpoint_restore_uid
      return created

    def distributed_getter(getter, *args, **kwargs):
      return getter(*args, **kwargs)

    return _CurrentDistributionContext(
        strategy,
        variable_scope.variable_creator_scope(creator_with_resource_vars),
        variable_scope.variable_scope(
            variable_scope.get_variable_scope(),
            custom_getter=distributed_getter), self._default_device)

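4.3 Creating variables

As described above, creator_with_resource_vars calls the derived strategy's _create_variable to create variables. Here we look at how the PS strategy handles this. Initialization configured self._variable_device, which determines how variables are assigned to devices. The later code contains `with ops.device(self._variable_device)`, which places the variables created inside that scope on the devices chosen by self._variable_device.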

self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,  # number of parameter server tasks
        worker_device=self._worker_device,  # worker device
        merge_devices=True,
        cluster=cluster_spec)

Variables are created as follows:

  def _create_variable(self, next_creator, **kwargs):
    
    # Build the variable creator.
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      # self._variable_device is the function returned by replica_device_setter.
      with ops.device(self._variable_device):
        return var_creator(**kwargs)

The actual construction goes through _create_var_creator. The main point is that it calls ps_values.AggregatingVariable to produce the variable.

  def _create_var_creator(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: " + aggregation +
                         " for variable: " + kwargs["name"])

      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        
        # Wrap the variable.
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
    
      return var_creator
    else:
      return next_creator

4.4 PS variables

AggregatingVariable is a wrapper around a variable, so that operations on the variable are routed through the strategy. Only part of the code is shown here.

# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
class AggregatingVariable(variables_lib.Variable, core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      
      # The cross-replica context is checked here.
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        # Update through the strategy.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                "use_locking": use_locking,
                "name": name,
                "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  # Most of the code is omitted.

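The overall logic is shown below: the first sequence of operations creates the variable, and the second sequence processes it.

Figure 2: Creating variables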
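
To make the wrapper idea concrete, here is a minimal pure-Python sketch. It is not TensorFlow code: ToyVariable, ToyStrategy and ToyAggregatingVariable are hypothetical stand-ins invented for illustration, and the per-replica values are passed in as a plain list instead of arriving through merge_call. It only illustrates the two things the excerpt above relies on: __getattr__ forwarding reads to the wrapped variable, and assign-style calls being reduced first and then handed to the strategy's update.

```python
class ToyVariable:
    """Hypothetical stand-in for the real variable that lives on a ps."""

    def __init__(self, value, name):
        self.value = value
        self.name = name

    def assign_add(self, delta):
        self.value += delta
        return self.value


class ToyStrategy:
    """Hypothetical stand-in for a strategy; it only counts update calls."""

    def __init__(self):
        self.update_calls = 0

    def update(self, wrapper, fn, args=()):
        self.update_calls += 1
        # Run fn against the wrapped variable, loosely like extended.update.
        return fn(wrapper.get(), *args)


class ToyAggregatingVariable:
    def __init__(self, strategy, v, aggregation):
        self._strategy = strategy
        self._v = v
        self._aggregation = aggregation  # "sum" or "mean" in this toy

    def get(self):
        return self._v

    def __getattr__(self, name):
        # Only reached when normal lookup fails: forward to the wrapped variable.
        return getattr(self._v, name)

    def _aggregate(self, per_replica_values):
        total = sum(per_replica_values)
        if self._aggregation == "sum":
            return total
        if self._aggregation == "mean":
            return total / len(per_replica_values)
        raise ValueError("aggregation must be 'sum' or 'mean' in this toy")

    def assign_add(self, per_replica_values):
        # Reduce the per-replica values first, then apply a single update
        # through the strategy.
        reduced = self._aggregate(per_replica_values)
        return self._strategy.update(
            self, lambda var, d: var.assign_add(d), args=(reduced,))


strategy = ToyStrategy()
w = ToyAggregatingVariable(strategy, ToyVariable(10.0, "w"), "mean")
print(w.name)                    # forwarded to the wrapped variable
print(w.assign_add([1.0, 3.0]))  # the mean of the two deltas is applied once
```

The design point worth noticing is that callers never touch the wrapped variable directly, so every write passes through one place where aggregation and placement can be decided.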

5. Running

Next, let's look at how ParameterServerStrategyV1 runs.

5.1 Base class

ParameterServerStrategyV1 actually calls the run method of its base class StrategyV1, which is defined in tensorflow/python/distribute/distribute_lib.py. We analysed it in an earlier article; it is listed again here for completeness.

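This method is the primary way to distribute computation with a tf.distribute object. It invokes fn on each replica. If args or kwargs contain tf.distribute.DistributedValues, then when fn executes on a particular replica it receives the component of the tf.distribute.DistributedValues that corresponds to that replica.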

Examples of tf.distribute.DistributedValues include the elements yielded by a tf.distribute.DistributedDataset, which is produced by tf.distribute.Strategy.experimental_distribute_dataset or tf.distribute.Strategy.distribute_datasets_from_function.

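fn is invoked in a replica context, where it can call tf.distribute.get_replica_context() to access members such as all_reduce. Every argument in args or kwargs can be a nested structure of tensors, for example a list of tensors, in which case args and kwargs are passed as-is to the fn invoked on each replica. Alternatively, args or kwargs can be tf.distribute.DistributedValues containing tensors or composite tensors (tf.compat.v1.TensorInfo.CompositeTensor), in which case each fn call receives the component of the tf.distribute.DistributedValues corresponding to its replica.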
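
The dispatch rule just described can be illustrated with a small pure-Python sketch. ToyPerReplica and toy_run are hypothetical names, not TensorFlow APIs, and the replicas simply run one after another in a loop. The only point is that a plain argument is passed to every replica unchanged, while a distributed value is split so that each replica sees its own component.

```python
class ToyPerReplica:
    """Hypothetical container holding one value per replica."""

    def __init__(self, values):
        self.values = tuple(values)


def select_for_replica(replica_id, arg):
    # A distributed value contributes only its component for this replica;
    # anything else is passed through unchanged.
    if isinstance(arg, ToyPerReplica):
        return arg.values[replica_id]
    return arg


def toy_run(fn, args=(), kwargs=None, num_replicas=2):
    if not isinstance(args, (list, tuple)):
        raise ValueError(
            "positional args must be a list or tuple, got {}".format(type(args)))
    kwargs = kwargs or {}
    results = []
    for replica_id in range(num_replicas):
        replica_args = [select_for_replica(replica_id, a) for a in args]
        replica_kwargs = {
            k: select_for_replica(replica_id, v) for k, v in kwargs.items()}
        results.append(fn(*replica_args, **replica_kwargs))
    # The per-replica results are themselves returned as a distributed value.
    return ToyPerReplica(results)


def step(batch, scale):
    return sum(batch) * scale


batches = ToyPerReplica([[1, 2], [3, 4]])
out = toy_run(step, args=(batches,), kwargs={"scale": 10})
print(out.values)  # (30, 70)
```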

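Important: depending on the implementation of tf.distribute.Strategy and on whether eager execution is enabled, fn may be called one or more times. If fn is annotated with tf.function, or tf.distribute.Strategy.run is called inside a tf.function (eager execution is disabled inside a tf.function by default), fn is called once per replica to generate a TensorFlow graph, which is then reused for execution with new inputs.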

The run method mainly just calls call_for_each_replica.

  def run(self, fn, args=(), kwargs=None, options=None):
    """Invokes fn on each replica, with the given arguments.
    """
    del options

    if not isinstance(args, (list, tuple)):
      raise ValueError(
        "positional args must be a list or tuple, got {}".format(type(args)))

    with self.scope():
      # tf.distribute supports Eager functions, so AutoGraph should not be
      # applied when the caller is also in Eager mode.
      fn = autograph.tf_convert(
          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)

Extend

Execution now reaches StrategyExtendedV2, which in turn calls the derived class's _call_for_each_replica.

  def call_for_each_replica(self, fn, args=(), kwargs=None):
    """Run fn once per replica.

    fn may call tf.get_replica_context() to access methods such as
    replica_id_in_sync_group and merge_call().

    merge_call() is used to communicate between the replicas and
    re-enter the cross-replica context. All replicas pause their execution
    having encountered a merge_call() call. After that the
    merge_fn-function is executed. Its results are then unwrapped and
    given back to each replica call. After that execution resumes until
    fn is complete or encounters another merge_call().  Example:

    ```python
    # Called once in "cross-replica" context.
    def merge_fn(distribution, three_plus_replica_id):
      # sum the values across replicas
      return sum(distribution.experimental_local_results(three_plus_replica_id))

    # Called once per replica in distribution, in a "replica" context.
    def fn(three):
      replica_ctx = tf.get_replica_context()
      v = three + replica_ctx.replica_id_in_sync_group
      # Computes the sum of the v values across all replicas.
      s = replica_ctx.merge_call(merge_fn, args=(v,))
      return s + v

    with distribution.scope():
      # in "cross-replica" context
      ...
      merged_results = distribution.run(fn, args=[3])
      # merged_results has the values from every replica execution of fn.
      # This statement prints a list:
      print(distribution.experimental_local_results(merged_results))
    ```

    Args:
      fn: function to run (will be run once per replica).
      args: Tuple or list with positional arguments for fn.
      kwargs: Dict with keyword arguments for fn.

    Returns:
      Merged return value of fn across all replicas.
    """
    _require_cross_replica_or_default_context_extended(self)
    if kwargs is None:
      kwargs = {}
    with self._container_strategy().scope():
      return self._call_for_each_replica(fn, args, kwargs)
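
The merge_call() behaviour described in the docstring above can be mimicked without TensorFlow. The sketch below is a toy that assumes a single merge point per replica function. toy_call_for_each_replica is a hypothetical helper, and Python generators stand in for whatever mechanism the real implementation uses to pause replicas. A yield plays the role of replica_ctx.merge_call(merge_fn, args=(v,)): every replica runs up to that point, merge_fn runs once over all the values, and its result is handed back to each replica.

```python
def merge_fn(values):
    # Runs once, in the toy "cross-replica" context: sum across replicas.
    return sum(values)


def fn(replica_id, three):
    # Runs once per replica, in the toy "replica" context.
    v = three + replica_id
    s = yield v  # stand-in for merge_call: pause here, resume with merged value
    return s + v


def toy_call_for_each_replica(fn, merge_fn, args, num_replicas):
    replicas = [fn(replica_id, *args) for replica_id in range(num_replicas)]
    # Step 1: run every replica until it reaches its merge point.
    pending = [next(replica) for replica in replicas]
    # Step 2: execute merge_fn exactly once over all per-replica values.
    merged = merge_fn(pending)
    # Step 3: give the merged result back and let each replica finish.
    results = []
    for replica in replicas:
        try:
            replica.send(merged)
        except StopIteration as stop:
            results.append(stop.value)
        else:
            raise RuntimeError("this toy supports only one merge point")
    return results


print(toy_call_for_each_replica(fn, merge_fn, args=(3,), num_replicas=2))
# [10, 11]
```

With two replicas, the per-replica values are 3 and 4, the merged sum is 7, and each replica returns 7 plus its own value.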

5.2 Derived class

The _call_for_each_replica of the derived class ParameterServerStrategyExtended is as follows:

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

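The mirrored_run part was analysed in an earlier article and is not repeated here. The overall logic is as follows:

Figure 3: Running

Or, viewed from another angle, as shown in the figure below: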

0xFF References

https://www.youtube.com/watch?v=B2Tpv_N7wkg&ab_channel=TensorFlow
