寫在前面的話
elephas是一個把python深度學習框架keras銜接到Spark集羣的第三方python包。由於這個版本並不穩定,並且沒有什麼資料,我打算剖析其源代碼。
分析代碼要從其主程序開始,就是spark_model.py,其網址在 https://github.com/maxpumperla/elephas/blob/master/elephas/spark_model.py。在這個博客裏,我暫且把發現的一些問題先記下來。由於我並不是一個喜歡“調格式“的人,更多喜歡不求甚解,所以先散亂地分析,之後我會系統化逐行代碼解析。
程序是從哪裏開始的
根據官網上面的程序,https://github.com/maxpumperla/elephas
- Create a local pyspark context
from pyspark import SparkContext, SparkConf
conf = SparkConf().setAppName('Elephas_App').setMaster('local[8]')
sc = SparkContext(conf=conf)
- Define and compile a Keras model
model = Sequential()
model.add(Dense(128, input_dim=784))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer=SGD())
- Create an RDD from numpy arrays
from elephas.utils.rdd_utils import to_simple_rdd
rdd = to_simple_rdd(sc, X_train, Y_train)
- A SparkModel is defined by passing Spark context and Keras model. Additionally, one has choose an optimizer used for updating the elephas model, an update frequency, a parallelization mode and the degree of parallelism, i.e. the number of workers.
from elephas.spark_model import SparkModel from elephas import optimizers as elephas_optimizers adagrad = elephas_optimizers.Adagrad() spark_model = SparkModel(sc,model, optimizer=adagrad, frequency='epoch', mode='asynchronous', num_workers=2) spark_model.train(rdd, nb_epoch=20, batch_size=32, verbose=0, validation_split=0.1, num_workers=8)
- Run your script using spark-submit
spark-submit --driver-memory 1G ./your_script.py
可以知道,程序是從spark_model 開始的,具體是從這個類的train方法開始的。
模型的流程
train的方法是啥?從train的方法中應該可以抓住模型的執行流程。
先貼代碼
def train(self, rdd, nb_epoch=10, batch_size=32,
verbose=0, validation_split=0.1):
'''
Train an elephas model.
'''
rdd = rdd.repartition(self.num_workers)
master_url = self.determine_master()
if self.mode in ['asynchronous', 'synchronous', 'hogwild']:
self._train(rdd, nb_epoch, batch_size, verbose, validation_split, master_url)
else:
print("""Choose from one of the modes: asynchronous, synchronous or hogwild""")
def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0,
validation_split=0.1, master_url='localhost:5000'):
'''
Protected train method to make wrapping of modes easier
'''
self.master_network.compile(optimizer=self.master_optimizer, loss=self.master_loss, metrics=self.master_metrics)
if self.mode in ['asynchronous', 'hogwild']:
self.start_server()
yaml = self.master_network.to_yaml()
train_config = self.get_train_config(nb_epoch, batch_size,
verbose, validation_split)
if self.mode in ['asynchronous', 'hogwild']:
worker = AsynchronousSparkWorker(
yaml, train_config, self.frequency, master_url,
self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects
)
rdd.mapPartitions(worker.train).collect()
new_parameters = get_server_weights(master_url)
elif self.mode == 'synchronous':
init = self.master_network.get_weights()
parameters = self.spark_context.broadcast(init)
worker = SparkWorker(yaml, parameters, train_config)
deltas = rdd.mapPartitions(worker.train).collect()
new_parameters = self.master_network.get_weights()
for delta in deltas:
constraints = self.master_network.constraints
new_parameters = self.optimizer.get_updates(self.weights, constraints, delta)
self.master_network.set_weights(new_parameters)
if self.mode in ['asynchronous', 'hogwild']:
self.stop_server()
從上面的代碼中,我們可以知道
(1)train 的方法,其實直接調用了_train(),當然這樣做的原因,也是一種程序的模塊化的思想。
(2)master_network 從這個類的初始化中,其實就是model,也就是keras中定義的model。self.master_network.compile 是把model編譯一下,這在一般的keras流程中就是這樣的,然後就是對model的fit了。
(3)start_server() 其實是在主節點開始了Flask的app流程,該程序可以使得不同節點可以進行參數之間的通信,實際上就是在master節點建立一個服務器。從而使的 slave節點可以通過 訪問url的方式 和master節點進行參數的交流。
def get_server_weights(master_url='localhost:5000'):
'''
Retrieve master weights from parameter server
'''
request = urllib2.Request('http://{0}/parameters'.format(master_url),
headers={'Content-Type': 'application/elephas'})
ret = urllib2.urlopen(request).read()
weights = pickle.loads(ret)
return weights
def put_deltas_to_server(delta, master_url='localhost:5000'):
'''
Update master parameters with deltas from training process
'''
request = urllib2.Request('http://{0}/update'.format(master_url),
pickle.dumps(delta, -1), headers={'Content-Type': 'application/elephas'})
return urllib2.urlopen(request).read()
這兩個寫在文件前面的函數,就是後面slave節點調用url 去訪問master節點進行參數傳遞。然後我們還可以看看到底服務器端的代碼是怎麼寫的,也就是那個開啓了flask的app的主節點。
def start_service(self):
''' Define service and run flask app'''
app = Flask(__name__)
self.app = app
@app.route('/')
def home():
return 'Elephas'
@app.route('/parameters', methods=['GET'])
def get_parameters():
if self.mode == 'asynchronous':
self.lock.acquire_read()
self.pickled_weights = pickle.dumps(self.weights, -1)
pickled_weights = self.pickled_weights
if self.mode == 'asynchronous':
self.lock.release()
return pickled_weights
@app.route('/update', methods=['POST'])
def update_parameters():
delta = pickle.loads(request.data)
if self.mode == 'asynchronous':
self.lock.acquire_write()
constraints = self.master_network.constraints
if len(constraints) == 0:
def empty(a): return a
constraints = [empty for x in self.weights]
self.weights = self.optimizer.get_updates(self.weights, constraints, delta)
if self.mode == 'asynchronous':
self.lock.release()
return 'Update done'
self.app.run(host='0.0.0.0', debug=True,
threaded=True, use_reloader=False)
在這裏,有一個問題,便是對其參數的改變時候,有一些lock之類的東東,這個是其其他的文件裏寫的,後面在分析,在這裏我感覺應該是某種保護參數的機制。
(4)接下倆一行,是yaml = self.master_network.to_yaml(). 爲什麼要反覆的地to yaml 然後有read from yaml?
一句話,爲了序列化的方便。首先這句話,是keras裏的知識,是把keras model 序列化爲字符串,除了to yaml,還可以to json。反正你知道yaml是一個字符串就好了。本人本着科學研究的精神,試了一把,請看下圖
In [4]: print model.to_yaml()
class_name: Model
config:
input_layers:
- [input_1, 0, 0]
layers:
- class_name: InputLayer
config:
batch_input_shape: !!python/tuple [null, 784]
input_dtype: float32
name: input_1
sparse: false
inbound_nodes: []
name: input_1
- class_name: Dense
config: {W_constraint: null, W_regularizer: null, activation: relu, activity_regularizer: null,
b_constraint: null, b_regularizer: null, bias: true, init: glorot_uniform, input_dim: null,
name: dense_1, output_dim: 64, trainable: true}
inbound_nodes:
- - [input_1, 0, 0]
name: dense_1
- class_name: Dense
config: {W_constraint: null, W_regularizer: null, activation: relu, activity_regularizer: null,
b_constraint: null, b_regularizer: null, bias: true, init: glorot_uniform, input_dim: null,
name: dense_2, output_dim: 64, trainable: true}
inbound_nodes:
- - [dense_1, 0, 0]
name: dense_2
- class_name: Dense
config: {W_constraint: null, W_regularizer: null, activation: softmax, activity_regularizer: null,
b_constraint: null, b_regularizer: null, bias: true, init: glorot_uniform, input_dim: null,
name: dense_3, output_dim: 10, trainable: true}
inbound_nodes:
- - [dense_2, 0, 0]
name: dense_3
name: model_1
output_layers:
- [dense_3, 0, 0]
keras_version: 1.1.1
至於,爲什麼yaml 這樣做是爲了序列化的方便在下面一點解釋。
(5)接下來,我們分析這個分支:if self.mode in ['asynchronous', 'hogwild']: 精彩內容來了:
if self.mode in ['asynchronous', 'hogwild']:
worker = AsynchronousSparkWorker(
yaml, train_config, self.frequency, master_url,
self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects
)
rdd.mapPartitions(worker.train).collect()
new_parameters = get_server_weights(master_url)
如果選的是asyn 就是同步更新的話,下面就是定義AsynchronousSparkWorker 這個是系統定義了另外一個僅次於spark model的第二個重要的類。寫到這個,讀者可能會問,到底分佈式在哪裏啊,好像目前都是在主節點一個人YY,沒有slave節點什麼事情。通過rdd mapPartitions,這個函數,其中對於每一個節點執行AsynchronousSparkWorker 的train方法。因爲這個設計到這個類在Spark集羣中的傳輸,所以上面的model 要to yaml以方便傳輸。
(6)最後,便是通過get_server_weights,獲得新的參數,
並將其給self.master_network.set_weights(new_parameters)。
-----------updated 2016.11.22 高鐵上
這兩天去北京參加了一個IBM在北京舉辦的機器學習峯會,今年IBM主推了智慧計算,和Spark。今天還見到了Spark的commiter之一 Nick Pentreath。感覺Spark應該在未來幾年都會是一個比較持續熱的分析框架。而在深度學習的火爆也將直接反映在機器學習框架keras,tensorflow的hot。而據我所知,目前而言,python仍然是數據分析者使用的主要語言。因此如何把一些python的第三包部署在spark集羣上面,將會是一個非常hot的話題。當然也就是我分析elephas的意義所在。
上面主要分析的是elephas的主文件spark_model.py中的master節點的主類,是所有代碼的入口。因此這個類的分析我放在這系列的第一篇。接下來,將繼續分析這個主類,主要是摳一些細節。
(1)爲什麼要import python 默認multiporocess?
我目前查到的信息,這是一個多線程的問題,在這個類中,通過把flask作爲multiporocess的一個進程函數進行管理。進一步查資料,發現這個multiporocess實現的 多線程和thread其實是不一樣的。不同的在於thread的多線程實際上用的還是一個核,而multiporocess可以在多核上運行。這個問題,我目前只有這些淺顯的理解,繼續完善吧。
(2)上面的類中,有兩條通信方式?
一個是Spark自身ssh通信,一個是基於flask,http通信協議,這個是用於解決訓練過程中,不同節點的參數管理。
(3)爲什麼import socket?
因爲,爲了確認
結束
好了,暫時第一篇就先到這裏了。我將分析這個文件中的第二類AsynchronousSparkWorker。這個類是處理slave節點的主類。
現在在高鐵上,睡一會,THE Night Is So Black!!!