超參數自動優化方法PBT(Population Based Training)

我們知道,機器學習模型的效果好壞很大程度上取決於超參的選取。人肉調參需要依賴經驗與直覺,且花費大量精力。PBT(Population based training)是DeepMind在論文《Population Based Training of Neural Networks》中提出的一種異步的自動超參數調節優化方法。以往的自動調節超參方法可分爲兩類:parallel search和sequential optimization。前者並行執行很多不同超參的優化任務,優點是可以並行利用計算資源更快找到最優解;後者需要利用之前的信息來進行下一步的超參優化,因此只能串行執行,但一般能得到更好的解。PBT完美地結合兩種方法,兼具兩者優點。它被應用於一些領域取得了不錯的效果。如DeepMind的論文《Human-level performance in first-person multiplayer games with population-based deep reinforcement learning》將之用於第一人稱多人遊戲使AI達到人類水平。還有今年UC Berkeley的論文《Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules》中用PBT來自動學習data augmentation策略,在幾個benchmark上達到了不錯的精度。另外,最近自動駕駛公司Waymo也稱將PBT應用於識別任務,與手工調參相比可以提高精度和加快訓練速度。

PBT開局與parallel search類似,會並行訓練一批隨機初始化的模型。過程中它會週期性地將表現好的模型替換表現不好的模型(exploitation),同時再加上隨機擾動(主要是爲了exploration)。PBT與其它方法的一個重要不同是它在訓練的過程中對超參進行調節,因此可以更快地發現超參和優異的schedule。論文《Population Based Training of Neural Networks》中的示意圖非常清楚地示意了整個過程,及與其它方法的區別:
在這裏插入圖片描述

PBT是一種很通用的方法,可以用於很多場景,其一般套路如下:

  1. Step:對模型訓練一步。至於一步是一次iteration還是一個epoch還是其它可以根據需要指定。
  2. Eval:在驗證集上做評估。
  3. Ready: 選取羣體中的一個模型來進行下面的exploit和explore操作(即perturbation)。這個模型一般是上次做過該操作後經過指定的時間(或迭代次數等)。
  4. Exploit: 將那些經過評估比較爛的模型用那些比較牛叉的模型替代。
  5. Explore: 對上一步產生的複製體模型加隨機擾動,如加上隨機值或重採樣。

Ray中實現了PBT算法。Ray中關於PBT有三個example:一個是learning rate搜索pbt_example.py,另一個是強化學習算法PPO的超參數搜索pbt_ppo_example.py。還有一個是pbt_tune_cifar10_with_keras.py。我們來看下最簡單的pbt_example.py。其中的PBTBenchmarkExample類繼承自Trainable類,它是一個toy的模擬環境,假設在模型訓練過程中最優的learning rate是變化的,是accuracy的函數。目標是找到learning rate的schedule。它的核心函數是_train(),這裏會模擬最優的learning rate。

然後看主函數,首先通過ray.init()初始化ray,然後創建PopulationBasedTraining對象,接着通過run()函數開始超參搜索過程。

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=20,
        hyperparam_mutations={
            # distribution for resampling
            "lr": lambda: random.uniform(0.0001, 0.02),
            # allow perturbations within this set of categorical values
            "some_other_factor": [1, 2],
        })
        
    run(
        PBTBenchmarkExample,
        name="pbt_test",
        scheduler=pbt,
        reuse_actors=True,
        verbose=False,
        **{
            "stop": {
                "training_iteration": 2000,
            },
            "num_samples": 4,
            "config": {
                "lr": 0.0001,
                # note: this parameter is perturbed but has no effect on
                # the model training in this example
                "some_other_factor": 1,
            },
        })    

先看第一步,PopulationBasedTraining的實現在python/ray/tune/schedulers/pbt.py中。它繼承自FIFOScheduler類。構造函數中幾個主要參數:

  • time_attr: 用於定義訓練時長的測度,要求單調遞增,比如training_iteration
  • metric: 訓練結果衡量目標。
  • mode: 上面metric屬性是越高越好,還是越低越好。
  • perturbation_interval: 模型會以time_attr爲間隔來進行perturbation。
  • hyperparam_mutations: 需要變異的超參。它是一個dict,對於每個key對應list或者function。如果沒設這個,就需要在custom_explore_fn中指定。
  • quantile_fraction: 決定按多大比例將表現好的頭部模型克隆到尾部模型。
  • resample_probability: 當對超參進行exploration時從原分佈中重新採樣的概率,否則會根據現有的值調整。
  • custom_explore_fn: 自定義的exploration函數。

第二步中run()函數實現在ray/python/ray/tune/tune.py中:

def run(run_or_experiment, name=None, ...):
    trial_executor = traial_executor or RayTrialExecutor(...)
    experiment = run_or_experiment
    if not isinstance(run_or_experiment, Experiment):
    	if not isinstance(run_or_experiment, Experiment):
    	experiment = Experiment(...)
    ...
    runner = TrialRunner(
        search_alg=search_alg or BasicVariantGenerator(),
        scheduler=scheduler or FIFOScheduler(),
        local_checkpoint_dir=experiment.checkpoint_dir,
        remote_checkpoint_dir=experiment.remote_checkpoint_dir,
        sync_to_cloud=sync_to_cloud,
        checkpoint_period=global_checkpoint_period,
        resume=resume,
        launch_web_server=with_server,
        server_port=server_port,
        verbose=bool(verbose > 1),
        trial_executor=trial_executor)
        
    runner.add_experiment(experiment)
    ...
    while not runner.is_finished():
       runner.step()
       ...
       
	wait_for_sync()
	...
	return ExperimentAnalysis(runner.checkpoint_file, trials=trials)

第一個參數run_or_experiment是要訓練的目標任務,參數scheduler就是上面創建的PopulationBasedTraining,負責超參搜索時的調度。

其中幾個關鍵類關係如下圖:
在這裏插入圖片描述
SearchAlgorithm的實現類BasicVariantGenerator會根據給定的Experiment產生參數變體。每個待訓練的參數變體會創建相應的Trial對象。Trial有PENDING, RUNNING, PAUSED, TERMINATED, ERROR幾種狀態。它會開始於PENDING狀態,開始訓練後轉爲RUNNING狀態,出錯了就到ERROR狀態,成功的話就是TERMINATED狀態。訓練中還可能被TrialScheduler暫停(轉入PAUSED狀態)並釋放資源。

TrialRunner是最核心的數據結構,它管理一系列的Trial對象,並且執行一個事件循環,將這些任務通過TrialExecutor的實現類RayTrialExecutor提交到Ray cluster運行。RayTrialExecutor會負責資源的管理。這裏通過Ray分佈執行的主要是Trainable的實現類(上例中就是PBTBenchmarkExample)中的_train()函數。RayTrialExecutor對象中的_running維護了正在運行的Trial。在循環中,TrialRunner會通過TrialScheduler的實現類PopulationBasedTraining來進行調度。它的choose_trial_to_run()函數從trial_runner的queue中拿出狀態爲PENDING或者PAUSED的trial,並且選取離上次做perturbation最久的一個保證儘可能公平。

run函數主要做以下幾步:

  1. 創建RayTrailExecutor對象(如果沒有傳入trial_executor的話)。
  2. 如果目標任務不是以Experiment對象形式給出,會按照給定的其它參數構建Experiment對象。
  3. 創建TrialRunner對象,它基於Ray來調度事件循環。
    1. 創建搜索算法對象(如果沒給),默認爲BasicVariantGenerator(實現在basic_variant.py)。它主要用於產生新的參數變體。
    2. 創建執行實驗的調度器(如果沒給),默認爲FIFOScheduler。上例中給定了PopulationBasedTraining,所以這裏就不需要創建了。
    3. 創建TrialRunner對象(實現在trial_runner.py)。並上面創建的Experiment對象通過add_experiment()函數加到TrialRunner對象中。
  4. 進入主循環,通過TrialRunneris_finished()函數判斷是否結束。如果沒有,就調用TrialRunnerstep()函數執行一步。step()函數的主要工作下面再細說。
  5. 收尾工作。如通過wait_for_sync()函數同步遠端目標,記錄沒有正常結束的trial,返回分析信息。

其中比較關鍵的是step()函數,其主要流程如下:
在這裏插入圖片描述

當一個Trial訓練結束返回結果時,TrialRunner會調用PopulationBasedTrainingon_trial_result()函數。這裏就是PBT的精華了。結合文章開關的PBT一般套路,主要步驟如下:

  1. 如果離上次pertubation的時間還沒到指定間隔,則返回讓該Trial繼續訓練。
  2. 調用_quantiles()函數按設定的比例__quantile_fraction得到所有Trial中表現好的頭部和表現不好的尾部。
  3. 如果當前trial是比較牛的那一批,那趕緊存成checkpoint,等着被其它trial克隆學習。
  4. 如果很不幸地,當前trial屬於比較差的那一批,那就從牛的那批中隨機挑一個(爲trial_to_clone),然後調用_exploit()函數。該函數會調用explore()函數對trial_to_clone進行擾動,然後將它的參數設置和checkpoint設置到當前trial。這樣,當前trial就“洗心革面”,重新出發了。
  5. 如果TrialRunner中有PENDING和PAUSED狀態的trial,則請求暫停當前trial,讓出資源。否則的話就繼續訓練着。

最後,總結下主要模塊間的大體流程:
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章