學習筆記TF061:分佈式TensorFlow,分佈式原理、最佳實踐

分佈式TensorFlow由高性能gRPC庫底層技術支持。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。

分佈式原理。分佈式集羣 由多個服務器進程、客戶端進程組成。部署方式,單機多卡、分佈式(多機多卡)。多機多卡TensorFlow分佈式。

單機多卡,單臺服務器多塊GPU。訓練過程:在單機單GPU訓練,數據一個批次(batch)一個批次訓練。單機多GPU,一次處理多個批次數據,每個GPU處理一個批次數據計算。變量參數保存在CPU,數據由CPU分發給多個GPU,GPU計算每個批次更新梯度。CPU收集完多個GPU更新梯度,計算平均梯度,更新參數。繼續計算更新梯度。處理速度取決最慢GPU速度。

分佈式,訓練在多個工作節點(worker)。工作節點,實現計算單元。計算服務器單卡,指服務器。計算服務器多卡,多個GPU劃分多個工作節點。數據量大,超過一臺機器處理能力,須用分佈式。

分佈式TensorFlow底層通信,gRPC(google remote procedure call)。gRPC,谷歌開源高性能、跨語言RPC框架。RPC協議,遠程過程調用協議,網絡從遠程計算機程度請求服務。

分佈式部署方式。分佈式運行,多個計算單元(工作節點),後端服務器部署單工作節點、多工作節點。

單工作節點部署。每臺服務器運行一個工作節點,服務器多個GPU,一個工作節點可以訪問多塊GPU卡。代碼tf.device()指定運行操作設備。優勢,單機多GPU間通信,效率高。劣勢,手動代碼指定設備。

多工作節點部署。一臺服務器運行多個工作節點。

設置CUDA_VISIBLE_DEVICES環境變量,限制各個工作節點只可見一個GPU,啓動進程添加環境變量。用tf.device()指定特定GPU。多工作節點部署優勢,代碼簡單,提高GPU使用率。劣勢,工作節點通信,需部署多個工作節點。https://github.com/tobegit3hub/tensorflow_examples/tree/master/distributed_tensorflow

CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0
CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1
CUDA_VISIBLE_DEVICES='0' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0
CUDA_VISIBLE_DEVICES='1' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1

分佈式架構。https://www.tensorflow.org/extend/architecture 。客戶端(client)、服務端(server),服務端包括主節點(master)、工作節點(worker)組成。

客戶端、主節點、工作節點關係。TensorFlow,客戶端會話聯繫主節點,實際工作由工作節點實現,每個工作節點佔一臺設備(TensorFlow具體計算硬件抽象,CPU或GPU)。單機模式,客戶端、主節點、工作節點在同一臺服務器。分佈模式,可不同服務器。客戶端->主節點->工作節點/job:worker/task:0->/job:ps/task:0。
客戶端。建立TensorFlow計算圖,建立與集羣交互會話層。代碼包含Session()。一個客戶端可同時與多個服務端相連,一具服務端也可與多個客戶端相連。
服務端。運行tf.train.Server實例進程,TensroFlow執行任務集羣(cluster)一部分。有主節點服務(Master service)和工作節點服務(Worker service)。運行中,一個主節點進程和數個工作節點進程,主節點進程和工作接點進程通過接口通信。單機多卡和分佈式結構相同,只需要更改通信接口實現切換。
主節點服務。實現tensorflow::Session接口。通過RPC服務程序連接工作節點,與工作節點服務進程工作任務通信。TensorFlow服務端,task_index爲0作業(job)。
工作節點服務。實現worker_service.proto接口,本地設備計算部分圖。TensorFlow服務端,所有工作節點包含工作節點服務邏輯。每個工作節點負責管理一個或多個設備。工作節點可以是本地不同端口不同進程,或多臺服務多個進程。運行TensorFlow分佈式執行任務集,一個或多個作業(job)。每個作業,一個或多個相同目的任務(task)。每個任務,一個工作進程執行。作業是任務集合,集羣是作業集合。
分佈式機器學習框架,作業分參數作業(parameter job)和工作節點作業(worker job)。參數作業運行服務器爲參數服務器(parameter server,PS),管理參數存儲、更新。工作節點作業,管理無狀態主要從事計算任務。模型越大,參數越多,模型參數更新超過一臺機器性能,需要把參數分開到不同機器存儲更新。參數服務,多臺機器組成集羣,類似分佈式存儲架構,涉及數據同步、一致性,參數存儲爲鍵值對(key-value)。分佈式鍵值內存數據庫,加參數更新操作。李沐《Parameter Server for Distributed Machine Learning》http://www.cs.cmu.edu/~muli/file/ps.pdf
參數存儲更新在參數作業進行,模型計算在工作節點作業進行。TensorFlow分佈式實現作業間數據傳輸,參數作業到工作節點作業前向傳播,工作節點作業到參數作業反向傳播。
任務。特定TensorFlow服務器獨立進程,在作業中擁有對應序號。一個任務對應一個工作節點。集羣->作業->任務->工作節點。

客戶端、主節點、工作節點交互過程。單機多卡交互,客戶端->會話運行->主節點->執行子圖->工作節點->GPU0、GPU1。分佈式交互,客戶端->會話運行->主節點進程->執行子圖1->工作節點進程1->GPU0、GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》https://arxiv.org/abs/1603.04467v1

分佈式模式。

數據並行。https://www.tensorflow.org/tutorials/deep_cnn 。CPU負責梯度平均、參數更新,不同GPU訓練模型副本(model replica)。基於訓練樣例子集訓練,模型有獨立性。
步驟:不同GPU分別定義模型網絡結構。單個GPU從數據管道讀取不同數據塊,前向傳播,計算損失,計算當前變量梯度。所有GPU輸出梯度數據轉移到CPU,梯度求平均操作,模型變量更新。重複,直到模型變量收斂。
數據並行,提高SGD效率。SGD mini-batch樣本,切成多份,模型複製多份,在多個模型上同時計算。多個模型計算速度不一致,CPU更新變量有同步、異步兩個方案。

同步更新、異步更新。分佈式隨機梯度下降法,模型參數分佈式存儲在不同參數服務上,工作節點並行訓練數據,和參數服務器通信獲取模型參數。
同步隨機梯度下降法(Sync-SGD,同步更新、同步訓練),訓練時,每個節點上工作任務讀入共享參數,執行並行梯度計算,同步需要等待所有工作節點把局部梯度處好,將所有共享參數合併、累加,再一次性更新到模型參數,下一批次,所有工作節點用模型更新後參數訓練。優勢,每個訓練批次考慮所有工作節點訓練情部,損失下降穩定。劣勢,性能瓶頸在最慢工作節點。異楹設備,工作節點性能不同,劣勢明顯。
異步隨機梯度下降法(Async-SGD,異步更新、異步訓練),每個工作節點任務獨立計算局部梯度,異步更新到模型參數,不需執行協調、等待操作。優勢,性能不存在瓶頸。劣勢,每個工作節點計算梯度值發磅回參數服務器有參數更新衝突,影響算法收劍速度,損失下降過程抖動較大。
同步更新、異步更新實現區別於更新參數服務器參數策略。數據量小,各節點計算能力較均衡,用同步模型。數據量大,各機器計算性能參差不齊,用異步模式。
帶備份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz論文《Revisiting Distributed Synchronous SGD》https://arxiv.org/abs/1604.00981 。增加工作節點,解決部分工作節點計算慢問題。工作節點總數n+n*5%,n爲集羣工作節點數。異步更新設定接受到n個工作節點參數直接更新參數服務器模型參數,進入下一批次模型訓練。計算較慢節點訓練參數直接丟棄。
同步更新、異步更新有圖內模式(in-graph pattern)和圖間模式(between-graph pattern),獨立於圖內(in-graph)、圖間(between-graph)概念。
圖內複製(in-grasph replication),所有操作(operation)在同一個圖中,用一個客戶端來生成圖,把所有操作分配到集羣所有參數服務器和工作節點上。國內複製和單機多卡類似,擴展到多機多卡,數據分發還是在客戶端一個節點上。優勢,計算節點只需要調用join()函數等待任務,客戶端隨時提交數據就可以訓練。劣勢,訓練數據分發在一個節點上,要分發給不同工作節點,嚴重影響併發訓練速度。
圖間複製(between-graph replication),每一個工作節點創建一個圖,訓練參數保存在參數服務器,數據不分發,各個工作節點獨立計算,計算完成把要更新參數告訴參數服務器,參數服務器更新參數。優勢,不需要數據分發,各個工作節點都創建圖和讀取數據訓練。劣勢,工作節點既是圖創建者又是計算任務執行者,某個工作節點宕機影響集羣工作。大數據相關深度學習推薦使用圖間模式。

模型並行。切分模型,模型不同部分執行在不同設備上,一個批次樣本可以在不同設備同時執行。TensorFlow儘量讓相鄰計算在同一臺設備上完成節省網絡開銷。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》https://arxiv.org/abs/1603.04467v1

模型並行、數據並行,TensorFlow中,計算可以分離,參數可以分離。可以在每個設備上分配計算節點,讓對應參數也在該設備上,計算參數放一起。

分佈式API。https://www.tensorflow.org/deploy/distributed
創建集羣,每個任務(task)啓動一個服務(工作節點服務或主節點服務)。任務可以分佈不同機器,可以同一臺機器啓動多個任務,用不同GPU運行。每個任務完成工作:創建一個tf.train.ClusterSpec,對集羣所有任務進行描述,描述內容對所有任務相同。創建一個tf.train.Server,創建一個服務,運行相應作業計算任務。
TensorFlow分佈式開發API。tf.train.ClusterSpec({“ps”:ps_hosts,”worker”:worke_hosts})。創建TensorFlow集羣描述信息,ps、worker爲作業名稱,ps_phsts、worker_hosts爲作業任務所在節點地址信息。tf.train.ClusterSpec傳入參數,作業和任務間關係映射,映射關係任務通過IP地址、端口號表示。

結構 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
可用任務 /job:local/task:0、/job:local/task:1。
結構 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})
可用任務 /job:worker/task:0、 /job:worker/task:1、 /job:worker/task:2、 /job:ps/task:0、 /job:ps/task:1

tf.train.Server(cluster,job_name,task_index)。創建服務(主節點服務或工作節點服務),運行作業計算任務,運行任務在task_index指定機器啓動。

#任務0 
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server  = tr.train.Server(cluster,job_name="local",task_index=0) 
#任務1 
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server  = tr.train.Server(cluster,job_name="local",task_index=1)。

自動化管理節點、監控節點工具。集羣管理工具Kubernetes。
tf.device(device_name_or_function)。設定指定設備執行張量運算,批定代碼運行CPU、GPU。

#指定在task0所在機器執行Tensor操作運算 
with tf.device("/job:ps/task:0"):
  weights_1 = tf.Variable(…)
  biases_1 = tf.Variable(…)

分佈式訓練代碼框架。創建TensorFlow服務器集羣,在該集羣分佈式計算數據流圖。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/deploy/distributed.md

import argparse
import sys
import tensorflow as tf
FLAGS = None
def main(_):
  # 第1步:命令行參數解析,獲取集羣信息ps_hosts、worker_hosts
  # 當前節點角色信息job_name、task_index
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")
  # 第2步:創建當前任務節點服務器
  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)
  # 第3步:如果當前節點是參數服務器,調用server.join()無休止等待;如果是工作節點,執行第4步
  if FLAGS.job_name == "ps":
    server.join()
  # 第4步:構建要訓練模型,構建計算圖
  elif FLAGS.job_name == "worker":
    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):
      # Build model...
      loss = ...
      global_step = tf.contrib.framework.get_or_create_global_step()
      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)
    # The StopAtStepHook handles stopping after running given steps.
    # 第5步管理模型訓練過程
    hooks=[tf.train.StopAtStepHook(last_step=1000000)]
    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0),
                                           checkpoint_dir="/tmp/train_logs",
                                           hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.
        # mon_sess.run handles AbortedError in case of preempted PS.
        # 訓練模型
        mon_sess.run(train_op)
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  # Flags for defining the tf.train.ClusterSpec
  parser.add_argument(
      "--ps_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--worker_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--job_name",
      type=str,
      default="",
      help="One of 'ps', 'worker'"
  )
  # Flags for defining the tf.train.Server
  parser.add_argument(
      "--task_index",
      type=int,
      default=0,
      help="Index of task within the job"
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

分佈式最佳實踐。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py
MNIST數據集分佈式訓練。開設3個端口作分佈式工作節點部署,2222端口參數服務器,2223端口工作節點0,2224端口工作節點1。參數服務器執行參數更新任務,工作節點0、工作節點1執行圖模型訓練計算任務。參數服務器/job:ps/task:0 cocalhost:2222,工作節點/job:worker/task:0 cocalhost:2223,工作節點/job:worker/task:1 cocalhost:2224。
運行代碼。

python mnist_replica.py --job_name="ps" --task_index=0
python mnist_replica.py --job_name="worker" --task_index=0
python mnist_replica.py --job_name="worker" --task_index=1

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import sys
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 定義常量,用於創建數據流圖
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
                    "Directory for storing mnist data")
# 只下載數據,不做其他操作
flags.DEFINE_boolean("download_only", False,
                     "Only perform downloading of data; Do not proceed to "
                     "session preparation, model definition or training")
# task_index從0開始。0代表用來初始化變量的第一個任務
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task the performs the variable "
                     "initialization ")
# 每臺機器GPU個數,機器沒有GPU爲0
flags.DEFINE_integer("num_gpus", 1,
                     "Total number of gpus for each machine."
                     "If you don't use GPU, please set it to '0'")
# 同步訓練模型下,設置收集工作節點數量。默認工作節點總數
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update"
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
flags.DEFINE_integer("hidden_units", 100,
                     "Number of units in the hidden layer of the NN")
# 訓練次數
flags.DEFINE_integer("train_steps", 200,
                     "Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
# 使用同步訓練、異步訓練
flags.DEFINE_boolean("sync_replicas", False,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")
# 如果服務器已經存在,採用gRPC協議通信;如果不存在,採用進程間通信
flags.DEFINE_boolean(
    "existing_servers", False, "Whether servers already exists. If True, "
    "will use the worker hosts via their GRPC URLs (one client process "
    "per worker host). Otherwise, will create an in-process TensorFlow "
    "server.")
# 參數服務器主機
flags.DEFINE_string("ps_hosts","localhost:2222",
                    "Comma-separated list of hostname:port pairs")
# 工作節點主機
flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                    "Comma-separated list of hostname:port pairs")
# 本作業是工作節點還是參數服務器
flags.DEFINE_string("job_name", None,"job name: worker or ps")
FLAGS = flags.FLAGS
IMAGE_PIXELS = 28
def main(unused_argv):
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
  if FLAGS.download_only:
    sys.exit(0)
  if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
  if FLAGS.task_index is None or FLAGS.task_index =="":
    raise ValueError("Must specify an explicit `task_index`")
  print("job name = %s" % FLAGS.job_name)
  print("task index = %d" % FLAGS.task_index)
  #Construct the cluster and start the server
  # 讀取集羣描述信息
  ps_spec = FLAGS.ps_hosts.split(",")
  worker_spec = FLAGS.worker_hosts.split(",")
  # Get the number of workers.
  num_workers = len(worker_spec)
  # 創建TensorFlow集羣描述對象
  cluster = tf.train.ClusterSpec({
      "ps": ps_spec,
      "worker": worker_spec})
  # 爲本地執行任務創建TensorFlow Server對象。
  if not FLAGS.existing_servers:
    # Not using existing servers. Create an in-process server.
    # 創建本地Sever對象,從tf.train.Server這個定義開始,每個節點開始不同
    # 根據執行的命令的參數(作業名字)不同,決定這個任務是哪個任務
    # 如果作業名字是ps,進程就加入這裏,作爲參數更新的服務,等待其他工作節點給它提交參數更新的數據
    # 如果作業名字是worker,就執行後面的計算任務
    server = tf.train.Server(
        cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    # 如果是參數服務器,直接啓動即可。這裏,進程就會阻塞在這裏
    # 下面的tf.train.replica_device_setter代碼會將參數批定給ps_server保管
    if FLAGS.job_name == "ps":
      server.join()
  # 處理工作節點
  # 找出worker的主節點,即task_index爲0的點
  is_chief = (FLAGS.task_index == 0)
  # 如果使用gpu
  if FLAGS.num_gpus > 0:
    # Avoid gpu allocation conflict: now allocate task_num -> #gpu
    # for each worker in the corresponding machine
    gpu = (FLAGS.task_index % FLAGS.num_gpus)
    # 分配worker到指定gpu上運行
    worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
  # 如果使用cpu
  elif FLAGS.num_gpus == 0:
    # Just allocate the CPU to worker server
    # 把cpu分配給worker
    cpu = 0
    worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
  # The device setter will automatically place Variables ops on separate
  # parameter servers (ps). The non-Variable ops will be placed on the workers.
  # The ps use CPU and workers use corresponding GPU
  # 用tf.train.replica_device_setter將涉及變量操作分配到參數服務器上,使用CPU。將涉及非變量操作分配到工作節點上,使用上一步worker_device值。
  # 在這個with語句之下定義的參數,會自動分配到參數服務器上去定義。如果有多個參數服務器,就輪流循環分配
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):

    # 定義全局步長,默認值爲0
    global_step = tf.Variable(0, name="global_step", trainable=False)
    # Variables of the hidden layer
    # 定義隱藏層參數變量,這裏是全連接神經網絡隱藏層
    hid_w = tf.Variable(
        tf.truncated_normal(
            [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
            stddev=1.0 / IMAGE_PIXELS),
        name="hid_w")
    hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
    # Variables of the softmax layer
    # 定義Softmax 迴歸層參數變量
    sm_w = tf.Variable(
        tf.truncated_normal(
            [FLAGS.hidden_units, 10],
            stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
        name="sm_w")
    sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
    # Ops: located on the worker specified with FLAGS.task_index
    # 定義模型輸入數據變量
    x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, [None, 10])
    # 構建隱藏層
    hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
    hid = tf.nn.relu(hid_lin)
    # 構建損失函數和優化器
    y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
    cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
    # 異步訓練模式:自己計算完成梯度就去更新參數,不同副本之間不會去協調進度
    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    # 同步訓練模式
    if FLAGS.sync_replicas:
      if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
      else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate
      # 使用SyncReplicasOptimizer作優化器,並且是在圖間複製情況下
      # 在圖內複製情況下將所有梯度平均
      opt = tf.train.SyncReplicasOptimizer(
          opt,
          replicas_to_aggregate=replicas_to_aggregate,
          total_num_replicas=num_workers,
          name="mnist_sync_replicas")
    train_step = opt.minimize(cross_entropy, global_step=global_step)
    if FLAGS.sync_replicas:
      local_init_op = opt.local_step_init_op
      if is_chief:
        # 所有進行計算工作節點裏一個主工作節點(chief)
        # 主節點負責初始化參數、模型保存、概要保存
        local_init_op = opt.chief_init_op
      ready_for_local_init_op = opt.ready_for_local_init_op
      # Initial token and chief queue runners required by the sync_replicas mode
      # 同步訓練模式所需初始令牌、主隊列
      chief_queue_runner = opt.get_chief_queue_runner()
      sync_init_op = opt.get_init_tokens_op()
    init_op = tf.global_variables_initializer()
    train_dir = tempfile.mkdtemp()
    if FLAGS.sync_replicas:
      # 創建一個監管程序,用於統計訓練模型過程中的信息
      # lodger 是保存和加載模型路徑
      # 啓動就會去這個logdir目錄看是否有檢查點文件,有的話就自動加載
      # 沒有就用init_op指定初始化參數
      # 主工作節點(chief)負責模型參數初始化工作
      # 過程中,其他工作節點等待主節眯完成初始化工作,初始化完成後,一起開始訓練數據
      # global_step值是所有計算節點共享的
      # 在執行損失函數最小值時自動加1,通過global_step知道所有計算節點一共計算多少步
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          local_init_op=local_init_op,
          ready_for_local_init_op=ready_for_local_init_op,
          recovery_wait_secs=1,
          global_step=global_step)
    else:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          recovery_wait_secs=1,
          global_step=global_step)
    # 創建會話,設置屬性allow_soft_placement爲True
    # 所有操作默認使用被指定設置,如GPU
    # 如果該操作函數沒有GPU實現,自動使用CPU設備
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
    # The chief worker (task_index==0) session will prepare the session,
    # while the remaining workers will wait for the preparation to complete.
    # 主工作節點(chief),task_index爲0節點初始化會話
    # 其餘工作節點等待會話被初始化後進行計算
    if is_chief:
      print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
      print("Worker %d: Waiting for session to be initialized..." %
            FLAGS.task_index)
    if FLAGS.existing_servers:
      server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
      print("Using existing server at: %s" % server_grpc_url)
      # 創建TensorFlow會話對象,用於執行TensorFlow圖計算
      # prepare_or_wait_for_session需要參數初始化完成且主節點準備好後,纔開始訓練
      sess = sv.prepare_or_wait_for_session(server_grpc_url,
                                            config=sess_config)
    else:
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    print("Worker %d: Session initialization complete." % FLAGS.task_index)
    if FLAGS.sync_replicas and is_chief:
      # Chief worker will start the chief queue runner and call the init op.
      sess.run(sync_init_op)
      sv.start_queue_runners(sess, [chief_queue_runner])
    # Perform training
    # 執行分佈式模型訓練
    time_begin = time.time()
    print("Training begins @ %f" % time_begin)
    local_step = 0
    while True:
      # Training feed
      # 讀入MNIST訓練數據,默認每批次100張圖片
      batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
      train_feed = {x: batch_xs, y_: batch_ys}
      _, step = sess.run([train_step, global_step], feed_dict=train_feed)
      local_step += 1
      now = time.time()
      print("%f: Worker %d: training step %d done (global step: %d)" %
            (now, FLAGS.task_index, local_step, step))
      if step >= FLAGS.train_steps:
        break
    time_end = time.time()
    print("Training ends @ %f" % time_end)
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)
    # Validation feed
    # 讀入MNIST驗證數據,計算驗證的交叉熵
    val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
    val_xent = sess.run(cross_entropy, feed_dict=val_feed)
    print("After %d training step(s), validation cross entropy = %g" %
          (FLAGS.train_steps, val_xent))
if __name__ == "__main__":
  tf.app.run()

參考資料:
《TensorFlow技術解析與實戰》

歡迎推薦上海機器學習工作機會,我的微信:qingxingfengzi

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章