轉自:https://spaces.ac.cn/archives/5879
今天我們來看一個小衆需求:自定義優化器。
細想之下,不管用什麼框架,自定義優化器這個需求可謂真的是小衆中的小衆。一般而言,對於大多數任務我們都可以無腦地直接上Adam,而調參煉丹高手一般會用SGD來調出更好的效果,換言之不管是高手新手,都很少會有自定義優化器的需求。
那這篇文章還有什麼價值呢?有些場景下會有一點點作用。比如通過學習Keras中的優化器寫法,你可以對梯度下降等算法有進一步的認識,你還可以順帶看到Keras的源碼是多麼簡潔優雅。此外,有時候我們可以通過自定義優化器來實現自己的一些功能,比如給一些簡單的模型(例如Word2Vec)重寫優化器(直接寫死梯度,而不是用自動求導),可以使得算法更快;自定義優化器還可以實現諸如“軟batch”的功能。
Keras優化器 #
我們首先來看Keras中自帶優化器的代碼,位於:
https://github.com/keras-team/keras/blob/master/keras/optimizers.py
簡單起見,我們可以先挑SGD來看。當然,Keras中的SGD算法已經把momentum、nesterov、decay等整合進去了,這使用起來方便,但不利於學習。所以我稍微簡化了一下,給出一個純粹的SGD算法的例子:
from keras.legacy import interfaces
from keras.optimizers import Optimizer
from keras import backend as K
class SGD(Optimizer):
"""Keras中簡單自定義SGD優化器
"""
def __init__(self, lr=0.01, **kwargs):
super(SGD, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.lr = K.variable(lr, name='lr')
@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):
"""主要的參數更新算法
"""
grads = self.get_gradients(loss, params) # 獲取梯度
self.updates = [K.update_add(self.iterations, 1)] # 定義賦值算子集合
self.weights = [self.iterations] # 優化器帶來的權重,在保存模型時會被保存
for p, g in zip(params, grads):
# 梯度下降
new_p = p - self.lr * g
# 如果有約束,對參數加上約束
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
# 添加賦值
self.updates.append(K.update(p, new_p))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr))}
base_config = super(SGD, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
應該不是解釋了吧?有沒有特別簡單的感覺?定義一個優化器也不是特別高大上的事情嘛~
實現“軟batch” #
現在來實現一個稍微複雜一點的功能,就是所謂的“軟batch”,不過我不大清楚是不是就叫這個名字,姑且先這樣叫着吧。大概的場景是:假如模型比較龐大,自己的顯卡最多也就能跑batch size=16,但我又想起到batch size=64的效果,那可以怎麼辦呢?一種可以考慮的方案是,每次算batch size=16,然後把梯度緩存起來,4個batch後才更新參數。也就是說,每個小batch都算梯度,但每4個batch才更新一次參數。
如果真的有這個需求,那麼就只能通過修改優化器來解決了。在前面的SGD的基礎上,參考代碼如下:
class MySGD(Optimizer):
"""Keras中簡單自定義SGD優化器
每隔一定的batch才更新一次參數
"""
def __init__(self, lr=0.01, steps_per_update=1, **kwargs):
super(MySGD, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.lr = K.variable(lr, name='lr')
self.steps_per_update = steps_per_update # 多少batch才更新一次
@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):
"""主要的參數更新算法
"""
shapes = [K.int_shape(p) for p in params]
sum_grads = [K.zeros(shape) for shape in shapes] # 平均梯度,用來梯度下降
grads = self.get_gradients(loss, params) # 當前batch梯度
self.updates = [K.update_add(self.iterations, 1)] # 定義賦值算子集合
self.weights = [self.iterations] + sum_grads # 優化器帶來的權重,在保存模型時會被保存
for p, g, sg in zip(params, grads, sum_grads):
# 梯度下降
new_p = p - self.lr * sg / float(self.steps_per_update)
# 如果有約束,對參數加上約束
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
cond = K.equal(self.iterations % self.steps_per_update, 0)
# 滿足條件才更新參數
self.updates.append(K.switch(cond, K.update(p, new_p), p))
# 滿足條件就要重新累積,不滿足條件直接累積
self.updates.append(K.switch(cond, K.update(sg, g), K.update(sg, sg+g)))
return self.updates
def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'steps_per_update': self.steps_per_update}
base_config = super(MySGD, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
應該也很容易理解吧。如果帶有動量的情況,寫起來複雜一點,但也是一樣的。重點就是引入多一個變量來儲存累積梯度,然後引入cond來控制是否更新,原來優化器要做的事情,都要在cond爲True的情況下才做(梯度改爲累積起來的梯度)。對比原始的SGD,改動並不大。
“侵入式”優化器 #
上面實現優化器的方案是標準的,也就是按Keras的設計規範來做的,所以做起來很輕鬆。然而我曾經想要實現的一個優化器,卻不能用這種方式來實現,經過閱讀源碼,得到了一種“侵入式”的寫法,這種寫法類似“外掛”的形式,可以實現我需要的功能,但不是標準的寫法,在此也跟大家分享一下。
原始需求來源於之前的文章《從動力學角度看優化算法(一):從SGD到動量加速》,裏邊指出梯度下降優化器可以看成是微分方程組的歐拉解法,進一步可以聯想到,微分方程組有很多比歐拉解法更高級的解法呀,能不能用到深度學習中?比如稍微高級一點的有“Heun方法”:
p̃ i+1=pi+1=pi+ϵg(pi)pi+12ϵ[g(pi)+g(p̃ i+1)]
其中p是參數(向量),g是梯度,pi表示p的第i次迭代時的結果。這個算法需要走兩步,大概意思就是普通的梯度下降先走一步(探路),然後根據探路的結果取平均,得到更精準的步伐,等價地可以改寫爲:
p̃ i+1=pi+1=pi+ϵg(pi)p̃ i+1+12ϵ[g(p̃ i+1)−g(pi)]
這樣就清楚顯示出後面這一步實際上是對梯度下降的微調。
但是實現這類算法卻有個難題,要計算兩次梯度,一次對參數g(pi)
,另一次對參數p̃ i+1
。而前面的優化器定義中get_updates這個方法卻只能執行一步(對應到tf框架中,就是執行一步sess.run,熟悉tf的朋友知道單單執行一步sess.run很難實現這個需求),因此實現不了這種算法。經過研究Keras模型的訓練源碼,我發現可以這樣寫:
class HeunOptimizer:
"""自定義Keras的侵入式優化器
"""
def __init__(self, lr):
self.lr = lr
def __call__(self, model):
"""需要傳入模型,直接修改模型的訓練函數,而不按常規流程使用優化器,所以稱爲“侵入式”
其實下面的大部分代碼,都是直接抄自keras的源碼:
https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L491
也就是keras中的_make_train_function函數。
"""
params = model._collected_trainable_weights
loss = model.total_loss
inputs = (model._feed_inputs +
model._feed_targets +
model._feed_sample_weights)
inputs += [K.learning_phase()]
with K.name_scope('training'):
with K.name_scope('heun_optimizer'):
old_grads = [[K.zeros(K.int_shape(p)) for p in params]]
update_functions = []
for i,step in enumerate([self.step1, self.step2]):
updates = (model.updates +
step(loss, params, old_grads) +
model.metrics_updates)
# 給每一步定義一個K.function
updates = K.function(inputs,
[model.total_loss] + model.metrics_tensors,
updates=updates,
name='train_function_%s'%i,
**model._function_kwargs)
update_functions.append(updates)
def F(ins):
# 將多個K.function封裝爲一個單獨的函數
# 一個K.function就是一次sess.run
for f in update_functions:
_ = f(ins)
return _
# 最後只需要將model的train_function屬性改爲對應的函數
model.train_function = F
def step1(self, loss, params, old_grads):
ops = []
grads = K.gradients(loss, params)
for p,g,og in zip(params, grads, old_grads[0]):
ops.append(K.update(og, g))
ops.append(K.update(p, p - self.lr * g))
return ops
def step2(self, loss, params, old_grads):
ops = []
grads = K.gradients(loss, params)
for p,g,og in zip(params, grads, old_grads[0]):
ops.append(K.update(p, p - 0.5 * self.lr * (g - og)))
return ops
用法是:
opt = HeunOptimizer(0.1)
opt(model) # model必須是Model型模型,而且已經compile過(compile的時候可以隨便指定一個優化器)
model.fit(x_train, y_train, epochs=100, batch_size=32)
其中關鍵思想在代碼中已經註釋了,主要是Keras的優化器最終都會被包裝爲一個train_function,所以我們只需要參照Keras的源碼設計好train_function,並在其中插入我們自己的操作。在這個過程中,需要留意到K.function所定義的操作相當於一次sess.run就行了。
注:類似地還可以實現RK23、RK45等算法。遺憾的是,這種優化器缺很容易過擬合,也就是很容易將訓練集的loss降到很低,但是驗證集的loss和準確率都很差~
優雅的Keras #
本文講了一個非常非常小衆的需求:自定義優化器,介紹了一般情況下Keras優化器的寫法,以及一種“侵入式”的寫法。如果真有這麼個特殊需求,可以參考使用。
通過Keras中優化器的分析研究,我們進一步可以觀察到Keras整體代碼實在是非常簡潔優雅,難以挑剔~