EDVR——代碼調試+訓練

2019年CVPR的文章EDVR: Video Restoration with Enhanced Deformable Convolutional Networks,做的是視頻處理(包括視頻幀的超分辨率技術與去模糊),從結構上看能夠處理應用於任意書品轉換的強監督任務;文中最亮眼的地方在於他提出了保證時序一致性(temporal consistency)的新方法,不是使用光流(optical flow),而是藉助可變形卷積對可追蹤點進行追蹤,成爲PCD模塊;以及提出了多幀處理時信息的融合的spatial-temporal維度的融合,成爲TSA模塊;代碼見:EDVR.

小編受委託於一個師姐,幫她調試了代碼。

(一)環境準備

老規矩哈,對每個項目新建一個虛擬環境,完了後再刪除。虛擬環境的新建見我的另一篇博客:vid2vid 代碼調試+訓練+測試(debug+train+test)(一)測試篇

(二)下載工程

$ git clone https://github.com/xinntao/EDVR.git
$ cd EDVR/

目錄結構如下, 其中,experiments用來保存訓練的模型和驗證結果(checkpoints),tb_logger用來保存log日誌;codes是主要的,包括了各種代碼(包括train.py);datasets其實可以不管,用來存放數據集,但是後面你會發現數據集的引用是使用絕對路徑(Σ( ° △ °|||)︴)!

(三)數據集準備

這應該是本項目最繁瑣的一部分了。

先 pip 安裝 lmdb 。

準備好數據集後,修改代碼“codes/data_scripts/create_lmdb_mp.py”。建議數據集結構如下:

"""
datasets
     |--inputs
           |--0000                        # clip
               |--00000000.png
               |--00000001.png
               ...
               |--00000029.png            # 30 frames for each clip in our datasets
           |--0001
               |--00000000.png
               ...
           ...
     |--GT
           |--0000                        # clip
               |--00000000.png
               |--00000001.png
               ...
               |--00000029.png            # 30 frames for each clip in our datasets
           |--0001
               |--00000000.png
               ...
           ...
"""

指定一個mode後,修改或者新增 if -else 結構,修改input或者GT的目錄與保存路徑。

針對input與GT分別執行create_lmdb_mp.py一次。

這之後,我們可以看看對應的"train_xxxx_wval.lmdb(是一個文件夾)"包含了什麼。,其中的data.mdb就是圖像數據了,是以字節形式存儲,所以空間較小,讀取快!lock.mdb僅是數據庫中防止衝突的操作,當然代碼中對數據的讀取都是隻讀,所以影響不大。meata_info.pkl是一個字典,結構如下:(我們待會要使用它!)

"""
{
    'name': 'xxxxxxxxxxxx',
    'resolution': '3_720_1280',
    'keys': ['0000_00000000', '0000_00000001', ...]
}
"""

僅僅如此還不夠!因爲meta_info.pkl的作用在於,幫助讀取mdb文件中對應的圖像數據。那麼,我們還需要一個keys來告訴datatset 有哪一些keys,因此我們需要將生成的兩個"train_xxxx_wval.lmdb"中的一個的meta_info.pkl複製到目錄“codes/data/”下(因爲input與GT的名字對應一般是相同的),注意到原先已經存在“REDS_trainval_keys.pkl”和“Vimeo90K_train_keys.pkl”了。我們這裏是重命名爲:EFRM_train_keys.pkl,待會在配置文件中的cache_keys指向它。

注意meta_info.pkl裏存儲的是字典,其中包含keys對應了文件名的列表;而原有的“REDS_trainval_keys.pkl”和“Vimeo90K_train_keys.pkl”則只是包含了列表;因此,我們還需要修改一下下面代碼:

/* codes/data/REDS_dataset.py: __init__() */

“cache_keys”是在配置文件中設置的,見(四), pickle是python讀寫pkl的一個包;我們載入剛剛的pkl後是一個字典,所以我們需要將在後面補上“[ 'keys' ]”表示只取keys鍵對應的值(文件名的列表)。

(四)設置配置文件

在路徑“./codes/options/train/”下有兩個“.yml”文件,這是參數配置文件。參數意義與配置如下:

#### general settings
name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S                   # 爲本次實驗命名
use_tb_logger: true                                                    # 是否要輸出和保存日誌(一般都是要的吧~)
model: VideoSR_base                                                    # 使用的模型(不用改,這是作者文章的模型,通過參數配置可以構造文章所有的模型)
distortion: sr                         
scale: 4                                                               # 輸出大小是輸入的4倍(不用改,實際上就是網絡最後對應多了多少層上採樣)
gpu_ids: [3]                                                           # 可以使用單核gpu(特別適合小編這種窮人)

#### datasets
datasets:
  train:
    name: REDS
    mode: REDS
    interval_list: [1]                                                 # 相鄰幀:t-i, t, t+i
    random_reverse: false                                              # 是否隨機對幀序取反
    border_mode: false
    dataroot_GT: /home/xyy/ssd/xwp/__temp__/train_EFGT_wval.lmdb       # GT數據的絕對路徑 
    dataroot_LQ: /home/xyy/ssd/xwp/__temp__/train_EF_wval.lmdb         # 輸入數據的絕對路徑
    cache_keys: EFRM_train_keys.pkl                                    # 前面我們自定義的訓練數據的文件名彙總(存儲的是:List: ['0000_00000000', '0000_00000001', ..., '0001_00000000', ...])

    N_frames: 5                                                        # 輸入的幀數(中間幀爲key)
    use_shuffle: true
    n_workers: 3  # per GPU
    batch_size: 8
    GT_size: 256                                                       
    LQ_size: 256                                                       # 如果做得不是SR的任務,而是deblur/derain等輸入輸出的分辨率一樣的話,這裏要求設置:GT_size = LQ_size,具體數值不管;而如果是SR任務,則需要保證:GT_size/LQ_size = scale
    use_flip: true                                                     # 隨機翻轉(水平/垂直)做數據增強
    use_rot: true                                                      # 隨機旋轉
    color: RGB

#### network structures
network_G:
  which_model_G: EDVR
  nf: 64                                                               # 第一個conv的通道數
  nframes: 5
  groups: 8
  front_RBs: 5
  back_RBs: 10
  predeblur: true                                                      # 是否使用一個預編碼層,它的作用是對輸入 HxW 經過下采樣得到 H/4xW/4 的feature,以便符合後面的網絡
  HR_in: true                                                          # 很重要!!只要你的輸入與輸出是同樣分辨率,就要求設置爲true
  w_TSA: true                                                          # 是否使用TSA模塊

#### path
path:
  pretrain_model_G: ~                                                  # 假如沒有與訓練的模型,設置爲~(表示None)
  strict_load: true
  resume_state: ~

#### training settings: learning rate scheme, loss
train:
  lr_G: !!float 4e-4
  lr_scheme: CosineAnnealingLR_Restart
  beta1: 0.9
  beta2: 0.99
  niter: 600000
  warmup_iter: -1  # -1: no warm up
  T_period: [150000, 150000, 150000, 150000]
  restarts: [150000, 300000, 450000]
  restart_weights: [1, 1, 1]
  eta_min: !!float 1e-7

  pixel_criterion: cb
  pixel_weight: 1.0
  val_freq: !!float 2e3

  manual_seed: 0

#### logger
logger:
  print_freq: 10                                                       # 每多少個iterations打印日誌
  save_checkpoint_freq: !!float 2e3                                    # 沒多少個iterations保存模型

(五) 修改代碼

下面小編將展示次工程比較不友好的一個地方。

以“codes/data/REDS_dataset.py”爲例,在函數__getitem__(self, index)中有一個坑。

(本項目讀取數據的規則是:在前面部分將所有的數據封裝成lmdb的形式,需要通過key(圖片名,無後綴)進行讀取;在dataset的__getitem__中,是先將所有的keys讀入(就是前面我們需要自己準備的"XXX_keys.pkl"文件)),然後每次讀取連續的幾個keys,再經過_read_img_mc_BGR函數去獲取圖像數組。

這裏有幾個數值我們需要修改:

1)上面的兩個紅框,原本的數值是99;這是因爲作者用的訓練數據每一個clip中含有100幀(xxxx_00000000, xxxx_00000001, ..., xxxx_00000099) ,爲了保證不讀取到兩個clips的幀,需要對幀的索引做檢查。師姐的數據中每個clip的幀數是30,所以這裏要設置成29.

2)假如讀者使用與作者相同的命名格式:“xxxx_xxxxxxxx”,那麼底下的框就不需要修改;但假如不是,像師姐的命名是“xxxx_xxxxxx”,所以這裏就需要改成“{:06d}”而不是原來的“{:06d}”。

這裏最好奇的是,上面的“99”爲什麼不設置成一個超參數?

(六)訓練

python -m torch.distributed.launch --nproc_per_node=2 --master_port=21688 train.py -opt options/train/<我自己的配置文件>.yml --launcher pytorch

# 注意這裏的 master_port 不是固定的,根據自己服務器當前的端口使用,賦予一個沒有使用的端口即可;否則會發生系統錯誤,甚至無法fork出子進程

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章