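Apex installation and usage notes

I. Apex

What it is: NVIDIA's mixed-precision training extension for PyTorch

What it does: speeds up training on GPUs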

GitHub:https://github.com/NVIDIA/apex

API documentation: https://nvidia.github.io/apex

Requirements:

Python 3

CUDA 9 or newer

PyTorch 0.4 or newer. The CUDA and C++ extensions require pytorch 1.0 or newer.

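The latest stable release is recommended; see https://pytorch.org/.

We also test against the latest master branch, obtainable from https://github.com/pytorch/pytorch.

Using Apex in a Docker container is often convenient. Compatible options include: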

NVIDIA Pytorch containers from NGC, which come with Apex preinstalled. To use the latest Amp API, you may need to pip uninstall apex then reinstall Apex using the Quick Start commands below.
official Pytorch -devel Dockerfiles, e.g. docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7, in which you can install Apex using the Quick Start commands. 

How to install:
Linux:

For performance and full functionality, installing Apex with the CUDA and C++ extensions is recommended:

$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Apex also supports a Python-only build (required with PyTorch 0.4) via

$ pip install -v --no-cache-dir ./

Windows:

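Windows support is experimental; Linux is recommended.

If you are able to build PyTorch from source on your system, pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . may work.

pip install -v --no-cache-dir . (without the CUDA/C++ extensions) is more likely to work.

If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.

Related link: https://github.com/NVIDIA/apex/tree/master/examples/docker

How to use it after installation: see the documentation at https://nvidia.github.io/apex/amp.html

Example: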

# Declare model and optimizer as usual, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Allow Amp to perform casts as required by the opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
...
# loss.backward() becomes:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

 

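II. My installation process:

1. $ git clone https://github.com/NVIDIA/apex (done)
2. $ cd apex (done)
3. $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Step 3 failed with the error below. Many people have asked about this problem in the project's GitHub issues.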

Cleaning up...
  Removing source in /tmp/pip-req-build-v0deounv
Removed build tracker '/tmp/pip-req-tracker-3n3fyj4o'
ERROR: Command errored out with exit status 1: /users4/zsun/anaconda3/bin/python -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-v0deounv/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-v0deounv/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' --cpp_ext --cuda_ext install --record /tmp/pip-record-rce1cb4d/install-record.txt --single-version-externally-managed --compile Check the logs for full command output.
Exception information:
Traceback (most recent call last):
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/cli/base_command.py", line 153, in _main
    status = self.run(options, args)
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/commands/install.py", line 455, in run
    use_user_site=options.use_user_site,
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/req/__init__.py", line 62, in install_given_reqs
    **kwargs
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/req/req_install.py", line 888, in install
    cwd=self.unpacked_source_directory,
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/utils/subprocess.py", line 275, in runner
    spinner=spinner,
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/utils/subprocess.py", line 242, in call_subprocess
    raise InstallationError(exc_msg)
pip._internal.exceptions.InstallationError: Command errored out with exit status 1: /users4/zsun/anaconda3/bin/python -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-v0deounv/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-v0deounv/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' --cpp_ext --cuda_ext install --record /tmp/pip-record-rce1cb4d/install-record.txt --single-version-externally-managed --compile Check the logs for full command output.
1 location(s) to search for versions of pip:
* http://mirrors.aliyun.com/pypi/simple/pip/
Getting page http://mirrors.aliyun.com/pypi/simple/pip/
Found index url http://mirrors.aliyun.com/pypi/simple/
Starting new HTTP connection (1): mirrors.aliyun.com:80
http://mirrors.aliyun.com:80 "GET /pypi/simple/pip/ HTTP/1.1" 200 12139
Analyzing links from page http://mirrors.aliyun.com/pypi/simple/pip/
  Found link http://mirrors.aliyun.com/pypi/packages/18/ad/c0fe6cdfe1643a19ef027c7168572dac6283b80a384ddf21b75b921877da/pip-0.2.1.tar.gz#sha256=83522005c1266cc2de97e65072ff7554ac0f30ad369c3b02ff3a764b962048da (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.2.1
  Found link http://mirrors.aliyun.com/pypi/packages/3d/9d/1e313763bdfb6a48977b65829c6ce2a43eaae29ea2f907c8bbef024a7219/pip-0.2.tar.gz#sha256=88bb8d029e1bf4acd0e04d300104b7440086f94cc1ce1c5c3c31e3293aee1f81 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.2
  Found link http://mirrors.aliyun.com/pypi/packages/0a/bb/d087c9a1415f8726e683791c0b2943c53f2b76e69f527f2e2b2e9f9e7b5c/pip-0.3.1.tar.gz#sha256=34ce534f17065c78f980702928e988a6b6b2d8a9851aae5f1571a1feb9bb58d8 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.3.1
  Found link http://mirrors.aliyun.com/pypi/packages/17/05/f66144ef69b436d07f8eeeb28b7f77137f80de4bf60349ec6f0f9509e801/pip-0.3.tar.gz#sha256=183c72455cb7f8860ac1376f8c4f14d7f545aeab8ee7c22cd4caf79f35a2ed47 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.3
  Found link http://mirrors.aliyun.com/pypi/packages/cf/c3/153571aaac6cf999f4bb09c019b1ff379b7b599ea833813a41c784eec995/pip-0.4.tar.gz#sha256=28fc67558874f71fddda7168f73595f1650523dce3bc5bf189713ecdfc1e456e (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.4
  Found link 



  Found link http://mirrors.aliyun.com/pypi/packages/ac/95/a05b56bb975efa78d3557efa36acaf9cf5d2fd0ee0062060493687432e03/pip-9.0.3-py2.py3-none-any.whl#sha256=c3ede34530e0e0b2381e7363aded78e0c33291654937e7373032fda04e8803e5 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 9.0.3
  Found link http://mirrors.aliyun.com/pypi/packages/c4/44/e6b8056b6c8f2bfd1445cc9990f478930d8e3459e9dbf5b8e2d2922d64d3/pip-9.0.3.tar.gz#sha256=7bf48f9a693be1d58f49f7af7e0ae9fe29fd671cde8a55e6edca3581c4ef5796 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 9.0.3
Given no hashes to check 131 links for project 'pip': discarding no candidates
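The log above does not show the underlying compiler error, so the real cause stays unknown. One commonly reported reason for a failed --cuda_ext build is a mismatch between the CUDA toolkit that nvcc belongs to and the CUDA version PyTorch was built with; whether that applies here is an assumption, not something the log confirms. A minimal sketch of such a comparison follows. The helper names parse_nvcc_release and same_cuda_major_minor are hypothetical, and the sample string only illustrates the usual shape of `nvcc --version` output.

```python
import re


def parse_nvcc_release(nvcc_output):
    """Return (major, minor) from `nvcc --version` text, or None if not found."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def same_cuda_major_minor(nvcc_output, torch_cuda_version):
    """Compare the nvcc release with a string such as torch.version.cuda."""
    release = parse_nvcc_release(nvcc_output)
    if release is None or not torch_cuda_version:
        return False  # no nvcc found, or a CPU-only PyTorch build
    parts = torch_cuda_version.split(".")
    if len(parts) < 2:
        return False
    return release == (int(parts[0]), int(parts[1]))


# Illustrative input only. On a real machine the text would come from running
# `nvcc --version`, and the second argument from torch.version.cuda.
sample = "Cuda compilation tools, release 10.0, V10.0.130"
print(same_cuda_major_minor(sample, "10.0"))
print(same_cuda_major_minor(sample, "9.2"))
```

If the two versions differ, rebuilding against a matching toolkit is a reasonable first thing to try before digging further into the build log.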

4. Running

$ pip install -v --no-cache-dir ./

produced this output:

  copying build/lib/apex/RNN/RNNBackend.py -> build/bdist.linux-x86_64/wheel/apex/RNN
  copying build/lib/apex/RNN/cells.py -> build/bdist.linux-x86_64/wheel/apex/RNN
  creating build/bdist.linux-x86_64/wheel/apex/normalization
  copying build/lib/apex/normalization/__init__.py -> build/bdist.linux-x86_64/wheel/apex/normalization
  copying build/lib/apex/normalization/fused_layer_norm.py -> build/bdist.linux-x86_64/wheel/apex/normalization
  running install_egg_info
  running egg_info
  creating apex.egg-info
  writing apex.egg-info/PKG-INFO
  writing dependency_links to apex.egg-info/dependency_links.txt
  writing top-level names to apex.egg-info/top_level.txt
  writing manifest file 'apex.egg-info/SOURCES.txt'
  reading manifest file 'apex.egg-info/SOURCES.txt'
  writing manifest file 'apex.egg-info/SOURCES.txt'
  Copying apex.egg-info to build/bdist.linux-x86_64/wheel/apex-0.1-py3.6.egg-info
  running install_scripts
  creating build/bdist.linux-x86_64/wheel/apex-0.1.dist-info/WHEEL
done
  Created wheel for apex: filename=apex-0.1-cp36-none-any.whl size=136906 sha256=55830f559061fcb30ed616dd6879086c9b79926c3d3e0017a2dcf6c0e1aa8037
  Stored in directory: /tmp/pip-ephem-wheel-cache-m4cxipvx/wheels/6c/91/1a/143cfe0f99d10c8c415d1594024d1de93c5f8c03f5edfad2ba
  Removing source in /tmp/pip-req-build-yg8bljf6
Successfully built apex
Installing collected packages: apex
  Found existing installation: apex 0.1
    Uninstalling apex-0.1:
      Created temporary directory: /users4/zsun/anaconda3/lib/python3.6/site-packages/~pex-0.1.dist-info
      Removing file or directory /users4/zsun/anaconda3/lib/python3.6/site-packages/apex-0.1.dist-info/
      Created temporary directory: /users4/zsun/anaconda3/lib/python3.6/site-packages/~pex
      Removing file or directory /users4/zsun/anaconda3/lib/python3.6/site-packages/apex/
      Successfully uninstalled apex-0.1

Successfully installed apex-0.1
Cleaning up...
Removed build tracker '/tmp/pip-req-tracker-nm6wywoj'
1 location(s) to search for versions of pip:
* http://mirrors.aliyun.com/pypi/simple/pip/
Getting page http://mirrors.aliyun.com/pypi/simple/pip/
Found index url http://mirrors.aliyun.com/pypi/simple/
Starting new HTTP connection (1): mirrors.aliyun.com:80
http://mirrors.aliyun.com:80 "GET /pypi/simple/pip/ HTTP/1.1" 200 12139
Analyzing links from page http://mirrors.aliyun.com/pypi/simple/pip/
  Found link http://mirrors.aliyun.com/pypi/packages/18/ad/c0fe6cdfe1643a19ef027c7168572dac6283b80a384ddf21b75b921877da/pip-0.2.1.tar.gz#sha256=83522005c1266cc2de97e65072ff7554ac0f30ad369c3b02ff3a764b962048da (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.2.1
  Found link http://mirrors.aliyun.com/pypi/packages/3d/9d/1e313763bdfb6a48977b65829c6ce2a43eaae29ea2f907c8bbef024a7219/pip-0.2.tar.gz#sha256=88bb8d029e1bf4acd0e04d300104b7440086f94cc1ce1c5c3c31e3293aee1f81 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.2
  Found link http://mirrors.aliyun.com/pypi/packages/0a/bb/d087c9a1415f8726e683791c0b2943c53f2b76e69f527f2e2b2e9f9e7b5c/pip-0.3.1.tar.gz#sha256=34ce534f17065c78f980702928e988a6b6b2d8a9851aae5f1571a1feb9bb58d8 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.3.1
  Found link http://mirrors.aliyun.com/pypi/packages/17/05/f66144ef69b436d07f8eeeb28b7f77137f80de4bf60349ec6f0f9509e801/pip-0.3.tar.gz#sha256=183c72455cb7f8860ac1376f8c4f14d7f545aeab8ee7c22cd4caf79f35a2ed47 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.3
  Found link http://mirrors.aliyun.com/pypi/packages/cf/c3/153571aaac6cf999f4bb09c019b1ff379b7b599ea833813a41c784eec995/pip-0.4.tar.gz#sha256=28fc67558874f71fddda7168f73595f1650523dce3bc5bf189713ecdfc1e456e (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.4
  Found link http://mirrors.aliyun.com/pypi/packages/9a/aa/f536b6d14fe03343367da2ff44eee28f340ae650cd017ca088b6be13084a/pip-0.5.1.tar.gz#sha256=e27650538c41fe1007a41abd4cfd0f905b822622cbe1f8e7e09d1215af207694 (from http://mirrors.aliyun.com/pypi/simple/pi





  Found link http://mirrors.aliyun.com/pypi/packages/ac/95/a05b56bb975efa78d3557efa36acaf9cf5d2fd0ee0062060493687432e03/pip-9.0.3-py2.py3-none-any.whl#sha256=c3ede34530e0e0b2381e7363aded78e0c33291654937e7373032fda04e8803e5 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 9.0.3
  Found link http://mirrors.aliyun.com/pypi/packages/c4/44/e6b8056b6c8f2bfd1445cc9990f478930d8e3459e9dbf5b8e2d2922d64d3/pip-9.0.3.tar.gz#sha256=7bf48f9a693be1d58f49f7af7e0ae9fe29fd671cde8a55e6edca3581c4ef5796 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 9.0.3
Given no hashes to check 131 links for project 'pip': discarding no candidates

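Although the same output appeared at the end, "Successfully installed apex-0.1" appeared in the middle. (The trailing lines appear to be pip's own version lookup printed under -v rather than an Apex build failure.)

For now I treated the installation as successful and moved on.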
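Rather than inferring success from the pip log, one option is to ask the interpreter directly which pieces it can find. The sketch below is a hedged illustration: module_available and apex_install_report are hypothetical helper names, and amp_C is the extension module named in the Amp warning shown in step 6. It only locates the modules and does not import them.

```python
import importlib.util


def module_available(name):
    """True if `name` can be located on the current Python path."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


def apex_install_report():
    """Summarize which Apex pieces this interpreter can see."""
    return {
        "apex": module_available("apex"),    # the Python package
        "amp_C": module_available("amp_C"),  # extension module named in Amp's warning
    }


print(apex_install_report())
```

A Python-only install would be expected to report apex as found and amp_C as missing, which matches the warning that appears later in these notes.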

5. Following the example above and the one below, I changed my code. It was simple: only a few places needed changes.

if args.apex:
    from apex import amp
# Declare model and optimizer as usual, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Allow Amp to perform casts as required by the opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
...
# loss.backward() becomes:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
amp.load_state_dict(checkpoint['amp'])
...

6. On the first run, to see the effect, I got this output:

2019-11-27 14:42:14,362 INFO: Loading vocab,train and val dataset.Wait a second,please
#Params: 73.7M
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Warning:  multi_tensor_applier fused unscale kernel is unavailable, possibly because apex was installed without --cuda_ext --cpp_ext. Using Python fallback.  Original ImportError was: ModuleNotFoundError("No module named 'amp_C'",)
Traceback (most recent call last):
  File "main.py", line 582, in <module>
    train()
  File "main.py", line 388, in train
    with amp.scale_loss(loss, optimizer) as scaled_loss:
  File "/users4/zsun/anaconda3/lib/python3.6/contextlib.py", line 81, in __enter__
    return next(self.gen)
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/apex/amp/handle.py", line 111, in scale_loss
    optimizer._prepare_amp_backward()
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 219, in prepare_backward_no_master_weights
    self._amp_lazy_init()
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 309, in _amp_lazy_init
    self._lazy_init_maybe_master_weights()
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 210, in lazy_init_no_master_weights
    "Received {}".format(param.type()))
TypeError: Optimizer's parameters must be either torch.cuda.FloatTensor or torch.cuda.HalfTensor. Received torch.FloatTensor

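This shows that installing Apex without the C++/CUDA extensions in the earlier steps does have a cost: the fused unscale kernel is unavailable (this program does not need it). That is only a warning, not an error, so I fixed the actual error and ran again. This time it ran successfully.

(The error above occurred because my optimizer's parameters were on the CPU, while Apex requires them to be on the GPU. It has nothing to do with this step.)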
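The TypeError above can be caught before calling amp.initialize by checking where the optimizer's parameters live. The sketch below is an illustration, not part of Apex: non_cuda_params is a hypothetical helper that relies only on the param_groups list of a PyTorch optimizer and the is_cuda attribute of tensors. The stand-in objects exist solely so the snippet runs without a GPU.

```python
from types import SimpleNamespace


def non_cuda_params(optimizer):
    """Return (group_index, param_index) for every optimizer parameter not on a GPU."""
    offenders = []
    for g, group in enumerate(optimizer.param_groups):
        for p, param in enumerate(group["params"]):
            if not getattr(param, "is_cuda", False):
                offenders.append((g, p))
    return offenders


# Stand-ins for a real optimizer and tensors (hypothetical, for demonstration only):
# one parameter on the GPU and one left on the CPU.
fake_optimizer = SimpleNamespace(
    param_groups=[{"params": [SimpleNamespace(is_cuda=True),
                              SimpleNamespace(is_cuda=False)]}]
)
print(non_cuda_params(fake_optimizer))  # [(0, 1)]
```

With real objects, the usual fix is to move the model to the GPU (for example with model.cuda()) before constructing the optimizer, so that the optimizer holds the CUDA parameters.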

Before adding Apex:
  batch=8:  <3.5 h per eval, three evals per epoch, 4911 / 12196 MB | zsun(4901M)
  batch=10: <2.8 h per eval, three evals per epoch, 9723 / 12196 MB | zsun(9713M)
  batch=4:  <7.25 h per eval, three evals per epoch, 9863 / 16280 MB | zsun(9853M) (a different program, without CVAE)
  batch=16: runs out of memory after a while

After adding Apex:
  batch=16: <1.5 h per eval, three evals per epoch, 11517 / 12196 MB | zsun(11507M), no longer OOM
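To put the timings above into ratios, here is a small calculation. It assumes each eval interval covers the same amount of training data in every run, which the notes do not state explicitly, and it treats the "<" figures as if they were exact.

```python
# Hours per eval interval, taken as upper bounds from the notes above.
hours_per_eval = {
    "batch=8, no Apex": 3.5,
    "batch=10, no Apex": 2.8,
    "batch=16, Apex O1": 1.5,
}

baseline = hours_per_eval["batch=16, Apex O1"]
# How many times longer each run takes per eval than the Apex run.
slowdown_vs_apex = {name: round(h / baseline, 2) for name, h in hours_per_eval.items()}
print(slowdown_vs_apex)

# GPU memory used by the Apex run as a fraction of the 12196 MB card.
memory_fraction = round(11517 / 12196, 3)
print(memory_fraction)
```

Under that assumption the Apex run is roughly 1.9x to 2.3x faster per eval, while using about 94% of the card's memory at batch=16.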

7. However, the results got worse, and a warning message appeared:

2019-11-28 06:44:23,244 INFO: eval
2019-11-28 06:49:59,922 INFO: Epoch:  4 fmax: 0.246901 cur_max_f: 0.183399
2019-11-28 06:49:59,923 INFO: Epoch:  4 Min_Val_Loss: 0.318247 Cur_Val_Loss: 0.427435
2019-11-28 06:49:59,930 INFO:   [0] Cur_fmax: 0.183399 Cur_bound: 0.139370
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.00048828125
2019-11-28 08:17:46,401 INFO: eval
2019-11-28 08:23:15,293 INFO: Epoch:  4 fmax: 0.246901 cur_max_f: 0.215748
2019-11-28 08:23:15,294 INFO: Epoch:  4 Min_Val_Loss: 0.318247 Cur_Val_Loss: 0.474844
2019-11-28 08:23:15,304 INFO:   [0] Cur_fmax: 0.215748 Cur_bound: 0.016865
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.000244140625

Before, my program's output looked like this:

2019-11-28 00:03:36,134 INFO: eval
2019-11-28 00:06:15,011 INFO: Epoch:  5 fmax: 0.437465 cur_max_f: 0.437465
2019-11-28 00:06:15,012 INFO: Epoch:  5 Min_Val_Loss: 0.251108 Cur_Val_Loss: 0.251108
2019-11-28 00:06:15,019 INFO:   [0] Cur_fmax: 0.437465 Cur_bound: 0.231224

The loss is larger overall and very unstable, and the results are worse. And what does this message mean?

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.00048828125

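It means the gradients overflowed. Many people have raised this in the GitHub issues; the authors seem to have been collecting examples where it occurs, and it has not been resolved yet.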
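To make the message easier to read, here is a toy model of dynamic loss scaling: halve the scale and skip the step on overflow, and grow the scale again after a run of clean steps. ToyDynamicLossScaler is a hypothetical class written for these notes, not Apex's implementation, and the initial scale of 2**16, the factor of 2 and the window of 2000 steps are assumptions of the sketch. The min_scale argument mirrors the idea of the min_loss_scale parameter described in section III.

```python
import math


class ToyDynamicLossScaler:
    """Illustrative only: halve on overflow, double after `window` clean steps."""

    def __init__(self, init_scale=2.0 ** 16, factor=2.0, window=2000, min_scale=None):
        self.scale = init_scale
        self.factor = factor
        self.window = window
        self.min_scale = min_scale
        self.good_steps = 0

    def update(self, overflow):
        """Return True if the optimizer step should be applied."""
        if overflow:
            self.scale /= self.factor
            if self.min_scale is not None:
                self.scale = max(self.scale, self.min_scale)
            self.good_steps = 0
            return False  # the step is skipped
        self.good_steps += 1
        if self.good_steps == self.window:
            self.scale *= self.factor
            self.good_steps = 0
        return True


# Count how many net halvings lead from the assumed initial scale to the
# value printed in the log above.
scaler = ToyDynamicLossScaler()
halvings = 0
while scaler.scale > 0.00048828125:
    scaler.update(overflow=True)
    halvings += 1
print(halvings, math.log2(scaler.scale))
```

The scale in the log, 0.00048828125, is 2**-11. A scale that has dropped below 1 means overflows kept happening even with very little scaling, which may point to infinities or NaNs produced elsewhere in the model rather than to the scale being too large.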

 

 

 

 

 

 

 

 

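III. Other notes (all from the official documentation)

3.1. apex.amp

Note ⚠️:

Currently, the under-the-hood properties that govern pure or mixed precision training are the following:

cast_model_type: Casts the model's parameters and buffers to the desired type.
patch_torch_functions: Patches all Torch functions and Tensor methods so that Tensor Core-friendly ops such as GEMMs and convolutions run in FP16, and any ops that benefit from FP32 precision run in FP32.
keep_batchnorm_fp32: To enhance precision and enable cudnn batchnorm (which improves performance), it is often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16.
master_weights: Maintains FP32 master weights to accompany any FP16 model weights. The FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients.
loss_scale: If loss_scale is a float value, that value is used as the static (fixed) loss scale. If loss_scale is the string "dynamic", the loss scale is adjusted adaptively over time. Dynamic loss scale adjustments are performed by Amp automatically.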
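Why a loss scale helps can be seen with plain Python, since the struct module can round a float to IEEE 754 half precision. The sketch below is an illustration under stated assumptions: to_fp16 is a hypothetical helper, and the gradient value 1e-8 and the scale 1024 are arbitrary numbers chosen for the demonstration.

```python
import struct


def to_fp16(value):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


tiny_grad = 1e-8      # an arbitrary small gradient value
loss_scale = 1024.0   # an arbitrary static loss scale

unscaled = to_fp16(tiny_grad)              # too small for FP16, becomes 0.0
scaled = to_fp16(tiny_grad * loss_scale)   # large enough to survive in FP16
recovered = scaled / loss_scale            # unscaled again in higher precision
print(unscaled, recovered)

# The other direction: values above the FP16 maximum of 65504 cannot be stored.
try:
    to_fp16(70000.0)
    overflowed = False
except OverflowError:
    overflowed = True
print(overflowed)
```

Scaling the loss shifts small gradients into the range FP16 can represent, and the overflow case shows why the scale cannot simply be made arbitrarily large.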

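Again, you usually do not need to specify these properties by hand. Instead, select an opt_level, which sets them up for you. After selecting an opt_level, you can optionally pass property kwargs as manual overrides.

If you attempt to override a property that does not make sense for the selected opt_level, Amp will raise an error with an explanation. For example, selecting opt_level="O1" combined with the override master_weights=True does not make sense. O1 inserts casts around Torch functions rather than around model weights. Data, activations, and weights are recast out-of-place on the fly as they flow through the patched functions. The model weights themselves can (and should) therefore remain FP32, and there is no need to maintain separate FP32 master weights.

opt_levels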

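Recognized opt_levels are "O0", "O1", "O2", and "O3".

O0 and O3 are not true mixed precision, but they are useful for establishing accuracy and speed baselines, respectively.
O1 and O2 are different implementations of mixed precision. Try both, and see which gives the best speedup and accuracy for your model.

O0: FP32 training
Your incoming model should be FP32 already, so this is likely a no-op. O0 can be useful for establishing an accuracy baseline.

Default properties set by O0:
cast_model_type=torch.float32
patch_torch_functions=False
keep_batchnorm_fp32=None (effectively "not applicable"; everything is FP32)
master_weights=False
loss_scale=1.0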

O1: Mixed Precision (recommended for typical use)

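Patches all Torch functions and Tensor methods to cast their inputs according to a whitelist-blacklist model. Whitelist ops (for example, Tensor Core-friendly ops such as GEMMs and convolutions) are performed in FP16. Blacklist ops that benefit from FP32 precision (for example, softmax) are performed in FP32. O1 also uses dynamic loss scaling, unless overridden.

Default properties set by O1: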
cast_model_type=None (not applicable)
patch_torch_functions=True
keep_batchnorm_fp32=None (again, not applicable, all model weights remain FP32)
master_weights=None (not applicable, model weights remain FP32)
loss_scale="dynamic"


O2: “Almost FP16” Mixed Precision

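O2 casts the model weights to FP16, patches the model's forward method to cast input data to FP16, keeps batchnorms in FP32, maintains FP32 master weights, updates the optimizer's param_groups so that optimizer.step() acts directly on the FP32 weights (followed by FP32 master weight -> FP16 model weight copies if necessary), and implements dynamic loss scaling (unless overridden). Unlike O1, O2 does not patch Torch functions or Tensor methods.

Default properties set by O2: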
cast_model_type=torch.float16
patch_torch_functions=False
keep_batchnorm_fp32=True
master_weights=True
loss_scale="dynamic"


O3: FP16 training

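O3 may not achieve the stability of the true mixed precision options O1 and O2. However, it can be useful for establishing a speed baseline for your model, against which the performance of O1 and O2 can be compared. If your model uses batch normalization, to establish the "speed of light" you can try O3 with the additional property override keep_batchnorm_fp32=True (which enables cudnn batchnorm, as stated earlier).

Default properties set by O3: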
cast_model_type=torch.float16
patch_torch_functions=False
keep_batchnorm_fp32=False
master_weights=False
loss_scale=1.0
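The four sets of defaults listed above can be collected into one table for side-by-side comparison. This is only a restatement of those lists as a Python dict; OPT_LEVEL_DEFAULTS and differing_properties are hypothetical names, and the dtypes are written as strings so the snippet runs without PyTorch installed.

```python
# Defaults per opt_level, copied from the lists above (dtypes kept as strings).
OPT_LEVEL_DEFAULTS = {
    "O0": {"cast_model_type": "torch.float32", "patch_torch_functions": False,
           "keep_batchnorm_fp32": None, "master_weights": False, "loss_scale": 1.0},
    "O1": {"cast_model_type": None, "patch_torch_functions": True,
           "keep_batchnorm_fp32": None, "master_weights": None, "loss_scale": "dynamic"},
    "O2": {"cast_model_type": "torch.float16", "patch_torch_functions": False,
           "keep_batchnorm_fp32": True, "master_weights": True, "loss_scale": "dynamic"},
    "O3": {"cast_model_type": "torch.float16", "patch_torch_functions": False,
           "keep_batchnorm_fp32": False, "master_weights": False, "loss_scale": 1.0},
}


def differing_properties(level_a, level_b):
    """Names of the properties whose defaults differ between two opt_levels."""
    a, b = OPT_LEVEL_DEFAULTS[level_a], OPT_LEVEL_DEFAULTS[level_b]
    return sorted(name for name in a if a[name] != b[name])


print(differing_properties("O2", "O3"))
# ['keep_batchnorm_fp32', 'loss_scale', 'master_weights']
```

Seen this way, O2 is O3 plus the three safeguards (FP32 batchnorm, FP32 master weights, dynamic loss scaling), and O1 is the only level that patches Torch functions.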

Note ⚠️: amp.initialize should be called after you have finished constructing your model(s) and optimizer(s), but before you send your model through any DistributedDataParallel wrapper. Currently, amp.initialize should only be called once.

Parameters (amp.initialize):
models (torch.nn.Module or list of torch.nn.Modules) – Models to modify/cast.

optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers) – Optimizers to modify/cast. REQUIRED for training, optional for inference.

enabled (bool, optional, default=True) – If False, renders all Amp calls no-ops, so your script should run as if Amp were not present.

opt_level (str, optional, default="O1") – Pure or mixed precision optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail above.

cast_model_type (torch.dtype, optional, default=None) – Optional property override, see above.

patch_torch_functions (bool, optional, default=None) – Optional property override.

keep_batchnorm_fp32 (bool or str, optional, default=None) – Optional property override. If passed as a string, must be the string “True” or “False”.

master_weights (bool, optional, default=None) – Optional property override.

loss_scale (float or str, optional, default=None) – Optional property override. If passed as a string, must be a string representing a number, e.g., “128.0”, or the string “dynamic”.

cast_model_outputs (torch.dtype, optional, default=None) – Option to ensure that the outputs of your model(s) are always cast to a particular type regardless of opt_level.

num_losses (int, optional, default=1) – Option to tell Amp in advance how many losses/backward passes you plan to use. When used in conjunction with the loss_id argument to amp.scale_loss, enables Amp to use a different loss scale per loss/backward pass, which can improve stability. See “Multiple models/optimizers/losses” under Advanced Amp Usage for examples. If num_losses is left to 1, Amp will still support multiple losses/backward passes, but use a single global loss scale for all of them.

verbosity (int, default=1) – Set to 0 to suppress Amp-related output.

min_loss_scale (float, default=None) – Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored.

max_loss_scale (float, default=2.**24) – Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.

Returns
Model(s) and optimizer(s) modified according to the opt_level. If either the models or optimizers args were lists, the corresponding return value will also be a list.

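Checkpointing

To save and load Amp training state correctly, Apex provides amp.state_dict(), which contains all loss_scalers and their corresponding unskipped steps, and amp.load_state_dict() to restore these attributes.
Note that restoring the model with the same opt_level is recommended. Also note that load_state_dict should be called after amp.initialize.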

...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
amp.load_state_dict(checkpoint['amp'])
...

Advanced use cases:
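The unified Amp API supports gradient accumulation across iterations, multiple backward passes per iteration, multiple models/optimizers, custom/user-defined autograd functions, and custom data batch classes. Gradient clipping and GANs also require special treatment, but this treatment does not need to change for different opt_levels.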

Transition guide for old API users:

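We strongly encourage moving to the new Amp API, because it is more versatile, easier to use, and future-proof. The original FP16_Optimizer and the old "Amp" API are deprecated and may be removed at any time.
Functions formerly exposed through amp_handle are now accessible through the amp module. Any existing calls to amp_handle = amp.init() should be deleted.
See the documentation for details; they are not repeated here.

3.2. apex.optimizers

To be updated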
