2019年CVPR的文章EDVR: Video Restoration with Enhanced Deformable Convolutional Networks,做的是視頻處理(包括視頻幀的超分辨率技術與去模糊),從結構上看能夠處理應用於任意書品轉換的強監督任務;文中最亮眼的地方在於他提出了保證時序一致性(temporal consistency)的新方法,不是使用光流(optical flow),而是藉助可變形卷積對可追蹤點進行追蹤,成爲PCD模塊;以及提出了多幀處理時信息的融合的spatial-temporal維度的融合,成爲TSA模塊;代碼見:EDVR.
小編受委託於一個師姐,幫她調試了代碼。
(一)環境準備
老規矩哈,對每個項目新建一個虛擬環境,完了後再刪除。虛擬環境的新建見我的另一篇博客:vid2vid 代碼調試+訓練+測試(debug+train+test)(一)測試篇。
(二)下載工程
$ git clone https://github.com/xinntao/EDVR.git
$ cd EDVR/
目錄結構如下, 其中,experiments用來保存訓練的模型和驗證結果(checkpoints),tb_logger用來保存log日誌;codes是主要的,包括了各種代碼(包括train.py);datasets其實可以不管,用來存放數據集,但是後面你會發現數據集的引用是使用絕對路徑(Σ( ° △ °|||)︴)!
(三)數據集準備
這應該是本項目最繁瑣的一部分了。
先 pip 安裝 lmdb 。
準備好數據集後,修改代碼“codes/data_scripts/create_lmdb_mp.py”。建議數據集結構如下:
"""
datasets
|--inputs
|--0000 # clip
|--00000000.png
|--00000001.png
...
|--00000029.png # 30 frames for each clip in our datasets
|--0001
|--00000000.png
...
...
|--GT
|--0000 # clip
|--00000000.png
|--00000001.png
...
|--00000029.png # 30 frames for each clip in our datasets
|--0001
|--00000000.png
...
...
"""
指定一個mode後,修改或者新增 if -else 結構,修改input或者GT的目錄與保存路徑。
針對input與GT分別執行create_lmdb_mp.py一次。
這之後,我們可以看看對應的"train_xxxx_wval.lmdb(是一個文件夾)"包含了什麼。,其中的data.mdb就是圖像數據了,是以字節形式存儲,所以空間較小,讀取快!lock.mdb僅是數據庫中防止衝突的操作,當然代碼中對數據的讀取都是隻讀,所以影響不大。meata_info.pkl是一個字典,結構如下:(我們待會要使用它!)
"""
{
'name': 'xxxxxxxxxxxx',
'resolution': '3_720_1280',
'keys': ['0000_00000000', '0000_00000001', ...]
}
"""
僅僅如此還不夠!因爲meta_info.pkl的作用在於,幫助讀取mdb文件中對應的圖像數據。那麼,我們還需要一個keys來告訴datatset 有哪一些keys,因此我們需要將生成的兩個"train_xxxx_wval.lmdb"中的一個的meta_info.pkl複製到目錄“codes/data/”下(因爲input與GT的名字對應一般是相同的),注意到原先已經存在“REDS_trainval_keys.pkl”和“Vimeo90K_train_keys.pkl”了。我們這裏是重命名爲:EFRM_train_keys.pkl,待會在配置文件中的cache_keys指向它。
注意meta_info.pkl裏存儲的是字典,其中包含keys對應了文件名的列表;而原有的“REDS_trainval_keys.pkl”和“Vimeo90K_train_keys.pkl”則只是包含了列表;因此,我們還需要修改一下下面代碼:
/* codes/data/REDS_dataset.py: __init__() */
“cache_keys”是在配置文件中設置的,見(四), pickle是python讀寫pkl的一個包;我們載入剛剛的pkl後是一個字典,所以我們需要將在後面補上“[ 'keys' ]”表示只取keys鍵對應的值(文件名的列表)。
(四)設置配置文件
在路徑“./codes/options/train/”下有兩個“.yml”文件,這是參數配置文件。參數意義與配置如下:
#### general settings
name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S # 爲本次實驗命名
use_tb_logger: true # 是否要輸出和保存日誌(一般都是要的吧~)
model: VideoSR_base # 使用的模型(不用改,這是作者文章的模型,通過參數配置可以構造文章所有的模型)
distortion: sr
scale: 4 # 輸出大小是輸入的4倍(不用改,實際上就是網絡最後對應多了多少層上採樣)
gpu_ids: [3] # 可以使用單核gpu(特別適合小編這種窮人)
#### datasets
datasets:
train:
name: REDS
mode: REDS
interval_list: [1] # 相鄰幀:t-i, t, t+i
random_reverse: false # 是否隨機對幀序取反
border_mode: false
dataroot_GT: /home/xyy/ssd/xwp/__temp__/train_EFGT_wval.lmdb # GT數據的絕對路徑
dataroot_LQ: /home/xyy/ssd/xwp/__temp__/train_EF_wval.lmdb # 輸入數據的絕對路徑
cache_keys: EFRM_train_keys.pkl # 前面我們自定義的訓練數據的文件名彙總(存儲的是:List: ['0000_00000000', '0000_00000001', ..., '0001_00000000', ...])
N_frames: 5 # 輸入的幀數(中間幀爲key)
use_shuffle: true
n_workers: 3 # per GPU
batch_size: 8
GT_size: 256
LQ_size: 256 # 如果做得不是SR的任務,而是deblur/derain等輸入輸出的分辨率一樣的話,這裏要求設置:GT_size = LQ_size,具體數值不管;而如果是SR任務,則需要保證:GT_size/LQ_size = scale
use_flip: true # 隨機翻轉(水平/垂直)做數據增強
use_rot: true # 隨機旋轉
color: RGB
#### network structures
network_G:
which_model_G: EDVR
nf: 64 # 第一個conv的通道數
nframes: 5
groups: 8
front_RBs: 5
back_RBs: 10
predeblur: true # 是否使用一個預編碼層,它的作用是對輸入 HxW 經過下采樣得到 H/4xW/4 的feature,以便符合後面的網絡
HR_in: true # 很重要!!只要你的輸入與輸出是同樣分辨率,就要求設置爲true
w_TSA: true # 是否使用TSA模塊
#### path
path:
pretrain_model_G: ~ # 假如沒有與訓練的模型,設置爲~(表示None)
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 4e-4
lr_scheme: CosineAnnealingLR_Restart
beta1: 0.9
beta2: 0.99
niter: 600000
warmup_iter: -1 # -1: no warm up
T_period: [150000, 150000, 150000, 150000]
restarts: [150000, 300000, 450000]
restart_weights: [1, 1, 1]
eta_min: !!float 1e-7
pixel_criterion: cb
pixel_weight: 1.0
val_freq: !!float 2e3
manual_seed: 0
#### logger
logger:
print_freq: 10 # 每多少個iterations打印日誌
save_checkpoint_freq: !!float 2e3 # 沒多少個iterations保存模型
(五) 修改代碼
下面小編將展示次工程比較不友好的一個地方。
以“codes/data/REDS_dataset.py”爲例,在函數__getitem__(self, index)中有一個坑。
(本項目讀取數據的規則是:在前面部分將所有的數據封裝成lmdb的形式,需要通過key(圖片名,無後綴)進行讀取;在dataset的__getitem__中,是先將所有的keys讀入(就是前面我們需要自己準備的"XXX_keys.pkl"文件)),然後每次讀取連續的幾個keys,再經過_read_img_mc_BGR函數去獲取圖像數組。
這裏有幾個數值我們需要修改:
1)上面的兩個紅框,原本的數值是99;這是因爲作者用的訓練數據每一個clip中含有100幀(xxxx_00000000, xxxx_00000001, ..., xxxx_00000099) ,爲了保證不讀取到兩個clips的幀,需要對幀的索引做檢查。師姐的數據中每個clip的幀數是30,所以這裏要設置成29.
2)假如讀者使用與作者相同的命名格式:“xxxx_xxxxxxxx”,那麼底下的框就不需要修改;但假如不是,像師姐的命名是“xxxx_xxxxxx”,所以這裏就需要改成“{:06d}”而不是原來的“{:06d}”。
這裏最好奇的是,上面的“99”爲什麼不設置成一個超參數?
(六)訓練
python -m torch.distributed.launch --nproc_per_node=2 --master_port=21688 train.py -opt options/train/<我自己的配置文件>.yml --launcher pytorch
# 注意這裏的 master_port 不是固定的,根據自己服務器當前的端口使用,賦予一個沒有使用的端口即可;否則會發生系統錯誤,甚至無法fork出子進程