pytorch0.4使用注意

1.梯度
1.Variable()中,requires_grad=Fasle時不需要更新梯度, 適用於凍結某些層的梯度;
2.volatile=True相當於requires_grad=False,適用於測試階段,不需要反向傳播。在torch>=0.4中,這個現在已經取消了,使用with torch.no_grad()或者torch.set_grad_enable(grad_mode)來替代:

with torch.no_grad():
  test()
>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False
>>>
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

2.tensor與Variable
Tensor在0.4中,現在默認requires_grad=False的Variable了,相當於(tensor 等價於Variable(Tensor,requires_grad=Fasle)),
torch.Tensor和torch.autograd.Variable現在其實是同一個類! 沒有本質的區別! 所以也就是說, 現在已經沒有純粹的Tensor了, 是個Tensor, 它就支持自動求導! 你現在要不要給Tensor包一下Variable都沒有任何意義了
下面是0.4中一些新建tensor的方法

#0.4中建立一個tensor:
>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344,  0.8562, -1.2758],
        [ 0.8414,  1.7962,  1.0589],
        [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad  # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

3.requires_grad 已經是Tensor的一個屬性了
舉個例子:

>>> x = torch.ones(1)
>>> x.requires_grad #默認是False
False
這裏也說明了tensor就是一個requires_grad=False的Variable

4.不要隨便用.data
在torch0.3中,Variable分爲tensor和grad兩項,通過.data取出Variable中的Tensor,torch0.4變了.
torch0.4中,.data返回的是一個tensor,但是現在這個tensor是一個有requires_grad(可以自動求導)的tensor,而且現在.data取出的tensor和之前的Variable是內存共享,所以不安全.

y = x.data # x需要進行autograd的
# y和x是共享內存的,但是這裏y已經不需要grad了, 
# 所以會導致本來需要計算梯度的x也沒有梯度可以計算.從而x不會得到更新!

爲了解決上面的風險:所以, 推薦用x.detach(), 這個仍舊是共享內存的, 也是使得y的requires_grad爲False, 但是,如果x需要求導, 仍舊是可以自動求導的!

y = x.datach() # x需要進行autograd的
y和x也是共享內存,並且y的requires_grad爲False,但是,如果x需要求導, 仍舊是可以自動求導的!

5..item()
以前取tensor的值用.data,現在用.item()
比如爲了顯示loss到命令行:以前了累加loss(爲了看loss的大小)一般是用total_loss+=loss.data[0] , 比較詭異的是, 爲啥是.data[0]? 這是因爲, 這是因爲loss是一個Variable, 所以以後累加loss, 用loss.item().

6.棄用 volatile(同最開始的梯度解釋)

參考:https://www.itency.com/topic/show.do?id=494122

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章