TensorFlowOnSpark源碼解析

  前言

  這兩天琢磨了下spark-deep-learning和spark-sklearn兩個項目,但是感覺都不盡人如意。在training時,都需要把數據broadcast到各個節點進行並行訓練,基本就失去實用價值了(tranning數據都會大於單節點內存的好麼),而且spark-deep-learning目前還沒有實現和tf cluster的結合。所以這個時候轉向了開源已久的yahoo的TensorFlowOnSpark項目。簡單了過了下他的源碼,大致理清楚了原理,這裏算是記錄下來,也希望能幫到讀者。

  TensorFlowOnSpark 代碼運行剖析

  從項目中打開examples/mnist/spark/mnist_spark/mnist_dist.py,

  第一步通過pyspark創建SparkContext,這個過程其實就啓動了Spark cluster,至於如何通過python啓動spark 並且進行相互通訊

  第二步是接受一些命令行參數,這個我就不貼了。

  第三步是使用標準的pyspark API 從HDFS獲取圖片數據,構成一個

  接着就是開始進入正題,啓動tf cluster了:

cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.SPARK)

  TFCluster.run 裏的sc 就是sparkcontext,mnist_dist.map_fun函數則包含了你的tensorflow業務代碼,在這個示例裏就是minist的模型代碼,模型代碼具體細節代碼我們會晚點說。我們先看看TFCluster.run方法:

  上面是確定parameter server和worker的數目,這兩個概念是和tf相關的。

  接着會啓動一個Server:

  在driver端啓動一個Server,主要是爲了監聽待會spark executor端啓動的tf worker,進行協調。

  上面的代碼獲取完整的啓動tf cluster所需要的信息。建議大家可以去google下如何手動配置tf cluster,然後就能更深入理解TensorFlowOnSpark是如何預先收集好哪些參數。

  上面的第一段代碼其實是爲了確保啓動cluster_size個task,每個task對應一個partition,每個partition其實只有一個元素,就是worker的編號。通過對partition進行foreatch來啓動對應的tf worker(包含ps)。倒數第二行代碼我們又看到了,前面的那個server了,它會阻塞代碼往下執行,直到所有tf worker都啓動爲止。

  到這裏我們也可以看到,一個spark executor可能會啓動多個tf worker。

  現在我們進入 TFSparkNode.run看看,這裏麪包含了具體如何啓動tf worker的邏輯,記得這些代碼已經在executor執行了。

  首先定義了一個函數_mapfn,他的參數是一個iter,這個iter 沒啥用,就是前面的worker編號,只有一個元素。該函數裏主要作用其實就是啓動tf worker(PS)的,並且運行用戶的代碼的:

  啓動的過程中會啓動一個client,連接我們前面說的Server,報告自己成功啓動了。

  這裏會判斷是ps還是worker。如果是後臺運行,則通過multiprocessing.Process直接運行我們前年提到的mnist_dist.map_fun方法,而mnist_dist.map_fun其實包含了tf session的邏輯代碼。當然這個時候模型雖然啓動了,但是因爲在獲取數據時使用了queue.get(block=True) 時,這個時候還沒有數據進來,所以會被阻塞住。值得注意的是,這裏的代碼會發送給spark起的python worker裏執行。

  在獲得cluster對象後,我們就可以調用train方法做真實的訓練了,本質上就是開始喂數據:

  進入 cluster.train看下,會進入如下代碼:

  這裏會把數據按partition的方式餵給每個TF worker(通過調用train方法):

  這裏會拿到tf的queue,然後通過iter(也就是實際的spark rdd包含的訓練數據)往裏面放,如果放滿了就會阻塞。

  直至,大致流程就完成了。現在我們回過頭來看我們的業務代碼mnist_dist.map_fun,該方法其實是在每個tf worker上執行的:

  簡單的做了判定,如果是ps則停止在這,否則執行構建模型的工作。在with tf.device.. 裏面就是開始定義模型什麼的了,標準的tf 代碼了:

  當然,在TensorFlowOnSpark的示例代碼裏,使用了Supervisor:

  TFNode.DataFeed提供了一個便捷的獲取批量數據的方式,讓你不用操心queue的事情。

  在訓練達到必要的數目後,你可以停止訓練:

  現在整個流程應該是比較清晰了。



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章