我們知道,機器學習模型的效果好壞很大程度上取決於超參的選取。人肉調參需要依賴經驗與直覺,且花費大量精力。PBT(Population based training)是DeepMind在論文《Population Based Training of Neural Networks》中提出的一種異步的自動超參數調節優化方法。以往的自動調節超參方法可分爲兩類:parallel search和sequential optimization。前者並行執行很多不同超參的優化任務,優點是可以並行利用計算資源更快找到最優解;後者需要利用之前的信息來進行下一步的超參優化,因此只能串行執行,但一般能得到更好的解。PBT完美地結合兩種方法,兼具兩者優點。它被應用於一些領域取得了不錯的效果。如DeepMind的論文《Human-level performance in first-person multiplayer games with population-based deep reinforcement learning》將之用於第一人稱多人遊戲使AI達到人類水平。還有今年UC Berkeley的論文《Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules》中用PBT來自動學習data augmentation策略,在幾個benchmark上達到了不錯的精度。另外,最近自動駕駛公司Waymo也稱將PBT應用於識別任務,與手工調參相比可以提高精度和加快訓練速度。
PBT開局與parallel search類似,會並行訓練一批隨機初始化的模型。過程中它會週期性地將表現好的模型替換表現不好的模型(exploitation),同時再加上隨機擾動(主要是爲了exploration)。PBT與其它方法的一個重要不同是它在訓練的過程中對超參進行調節,因此可以更快地發現超參和優異的schedule。論文《Population Based Training of Neural Networks》中的示意圖非常清楚地示意了整個過程,及與其它方法的區別:
PBT是一種很通用的方法,可以用於很多場景,其一般套路如下:
- Step:對模型訓練一步。至於一步是一次iteration還是一個epoch還是其它可以根據需要指定。
- Eval:在驗證集上做評估。
- Ready: 選取羣體中的一個模型來進行下面的exploit和explore操作(即perturbation)。這個模型一般是上次做過該操作後經過指定的時間(或迭代次數等)。
- Exploit: 將那些經過評估比較爛的模型用那些比較牛叉的模型替代。
- Explore: 對上一步產生的複製體模型加隨機擾動,如加上隨機值或重採樣。
Ray中實現了PBT算法。Ray中關於PBT有三個example:一個是learning rate搜索pbt_example.py,另一個是強化學習算法PPO的超參數搜索pbt_ppo_example.py。還有一個是pbt_tune_cifar10_with_keras.py。我們來看下最簡單的pbt_example.py
。其中的PBTBenchmarkExample
類繼承自Trainable
類,它是一個toy的模擬環境,假設在模型訓練過程中最優的learning rate是變化的,是accuracy的函數。目標是找到learning rate的schedule。它的核心函數是_train()
,這裏會模擬最優的learning rate。
然後看主函數,首先通過ray.init()
初始化ray,然後創建PopulationBasedTraining
對象,接着通過run()函數開始超參搜索過程。
pbt = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
mode="max",
perturbation_interval=20,
hyperparam_mutations={
# distribution for resampling
"lr": lambda: random.uniform(0.0001, 0.02),
# allow perturbations within this set of categorical values
"some_other_factor": [1, 2],
})
run(
PBTBenchmarkExample,
name="pbt_test",
scheduler=pbt,
reuse_actors=True,
verbose=False,
**{
"stop": {
"training_iteration": 2000,
},
"num_samples": 4,
"config": {
"lr": 0.0001,
# note: this parameter is perturbed but has no effect on
# the model training in this example
"some_other_factor": 1,
},
})
先看第一步,PopulationBasedTraining
的實現在python/ray/tune/schedulers/pbt.py
中。它繼承自FIFOScheduler
類。構造函數中幾個主要參數:
time_attr
: 用於定義訓練時長的測度,要求單調遞增,比如training_iteration
。metric
: 訓練結果衡量目標。mode
: 上面metric屬性是越高越好,還是越低越好。perturbation_interval
: 模型會以time_attr
爲間隔來進行perturbation。hyperparam_mutations
: 需要變異的超參。它是一個dict,對於每個key對應list或者function。如果沒設這個,就需要在custom_explore_fn
中指定。quantile_fraction
: 決定按多大比例將表現好的頭部模型克隆到尾部模型。resample_probability
: 當對超參進行exploration時從原分佈中重新採樣的概率,否則會根據現有的值調整。custom_explore_fn
: 自定義的exploration函數。
第二步中run()
函數實現在ray/python/ray/tune/tune.py
中:
def run(run_or_experiment, name=None, ...):
trial_executor = traial_executor or RayTrialExecutor(...)
experiment = run_or_experiment
if not isinstance(run_or_experiment, Experiment):
if not isinstance(run_or_experiment, Experiment):
experiment = Experiment(...)
...
runner = TrialRunner(
search_alg=search_alg or BasicVariantGenerator(),
scheduler=scheduler or FIFOScheduler(),
local_checkpoint_dir=experiment.checkpoint_dir,
remote_checkpoint_dir=experiment.remote_checkpoint_dir,
sync_to_cloud=sync_to_cloud,
checkpoint_period=global_checkpoint_period,
resume=resume,
launch_web_server=with_server,
server_port=server_port,
verbose=bool(verbose > 1),
trial_executor=trial_executor)
runner.add_experiment(experiment)
...
while not runner.is_finished():
runner.step()
...
wait_for_sync()
...
return ExperimentAnalysis(runner.checkpoint_file, trials=trials)
第一個參數run_or_experiment
是要訓練的目標任務,參數scheduler
就是上面創建的PopulationBasedTraining
,負責超參搜索時的調度。
其中幾個關鍵類關係如下圖:
SearchAlgorithm
的實現類BasicVariantGenerator
會根據給定的Experiment
產生參數變體。每個待訓練的參數變體會創建相應的Trial
對象。Trial
有PENDING, RUNNING, PAUSED, TERMINATED, ERROR幾種狀態。它會開始於PENDING狀態,開始訓練後轉爲RUNNING狀態,出錯了就到ERROR狀態,成功的話就是TERMINATED狀態。訓練中還可能被TrialScheduler
暫停(轉入PAUSED狀態)並釋放資源。
TrialRunner
是最核心的數據結構,它管理一系列的Trial
對象,並且執行一個事件循環,將這些任務通過TrialExecutor
的實現類RayTrialExecutor
提交到Ray cluster運行。RayTrialExecutor
會負責資源的管理。這裏通過Ray分佈執行的主要是Trainable
的實現類(上例中就是PBTBenchmarkExample
)中的_train()
函數。RayTrialExecutor
對象中的_running
維護了正在運行的Trial
。在循環中,TrialRunner
會通過TrialScheduler
的實現類PopulationBasedTraining
來進行調度。它的choose_trial_to_run()
函數從trial_runner
的queue中拿出狀態爲PENDING或者PAUSED的trial,並且選取離上次做perturbation最久的一個保證儘可能公平。
run
函數主要做以下幾步:
- 創建
RayTrailExecutor
對象(如果沒有傳入trial_executor
的話)。 - 如果目標任務不是以
Experiment
對象形式給出,會按照給定的其它參數構建Experiment
對象。 - 創建
TrialRunner
對象,它基於Ray來調度事件循環。- 創建搜索算法對象(如果沒給),默認爲
BasicVariantGenerator
(實現在basic_variant.py
)。它主要用於產生新的參數變體。 - 創建執行實驗的調度器(如果沒給),默認爲
FIFOScheduler
。上例中給定了PopulationBasedTraining
,所以這裏就不需要創建了。 - 創建
TrialRunner
對象(實現在trial_runner.py
)。並上面創建的Experiment
對象通過add_experiment()
函數加到TrialRunner
對象中。
- 創建搜索算法對象(如果沒給),默認爲
- 進入主循環,通過
TrialRunner
的is_finished()
函數判斷是否結束。如果沒有,就調用TrialRunner
的step()
函數執行一步。step()
函數的主要工作下面再細說。 - 收尾工作。如通過
wait_for_sync()
函數同步遠端目標,記錄沒有正常結束的trial,返回分析信息。
其中比較關鍵的是step()
函數,其主要流程如下:
當一個Trial
訓練結束返回結果時,TrialRunner
會調用PopulationBasedTraining
的on_trial_result()
函數。這裏就是PBT的精華了。結合文章開關的PBT一般套路,主要步驟如下:
- 如果離上次pertubation的時間還沒到指定間隔,則返回讓該
Trial
繼續訓練。 - 調用
_quantiles()
函數按設定的比例__quantile_fraction
得到所有Trial
中表現好的頭部和表現不好的尾部。 - 如果當前trial是比較牛的那一批,那趕緊存成checkpoint,等着被其它trial克隆學習。
- 如果很不幸地,當前trial屬於比較差的那一批,那就從牛的那批中隨機挑一個(爲
trial_to_clone
),然後調用_exploit()
函數。該函數會調用explore()
函數對trial_to_clone
進行擾動,然後將它的參數設置和checkpoint設置到當前trial。這樣,當前trial就“洗心革面”,重新出發了。 - 如果
TrialRunner
中有PENDING和PAUSED狀態的trial,則請求暫停當前trial,讓出資源。否則的話就繼續訓練着。
最後,總結下主要模塊間的大體流程: