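PyTorch gives wrong results when an operation mixes long and float values. In the session below, the float scalar 0.5 is effectively truncated to 0, so adding it to a LongTensor changes nothing: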
>>> import torch
>>> a = torch.tensor([1,2,3], dtype=torch.long)
>>> a + 0.5
1
2
3
[torch.LongTensor of size (3,)]
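A minimal sketch of the usual fix, assuming float arithmetic is what you actually want: cast the long tensor explicitly before mixing in a float scalar, so the result does not depend on how a given release promotes types. Recent releases promote long + 0.5 to float on their own, but the explicit cast behaves the same on old and new versions.

```python
import torch

a = torch.tensor([1, 2, 3], dtype=torch.long)

# Cast to float first so the 0.5 is not truncated to an integer
b = a.float() + 0.5
print(b)
print(b.dtype)
```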
The same problem shows up when dropout is applied to long data, which sometimes fails with the following error:
/usr/local/lib/python3.6/dist-packages/torch/nn/_functions/dropout.py in forward(cls, ctx, input, p, train, inplace)
38 ctx.noise.fill_(0)
39 else:
---> 40 ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
41 ctx.noise = ctx.noise.expand_as(input)
42 output.mul_(ctx.noise)
RuntimeError: invalid argument 3: divide by zero at /pytorch/aten/src/THC/generic/THCTensorMathPairwise.cu:88
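The divide-by-zero is consistent with the same truncation: div_() on the long noise tensor appears to convert its scalar argument 1 - p to an integer, and for any 0 < p < 1 that value becomes 0. A minimal pure-Python model of this, assuming the conversion truncates toward zero like a C float-to-integer cast (scalar_as_long is a hypothetical helper, not PyTorch code):

```python
def scalar_as_long(x: float) -> int:
    # Hypothetical model of casting a float scalar to long:
    # int() truncates toward zero, like a C float-to-integer conversion
    return int(x)

for p in (0.1, 0.5, 0.9):
    keep_prob = 1 - p                  # the value passed to div_() in the traceback
    divisor = scalar_as_long(keep_prob)
    print(p, keep_prob, divisor)       # divisor is 0 for every 0 < p < 1
```

Since every keep probability strictly between 0 and 1 truncates to 0, the division fails regardless of which dropout rate is chosen.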
Converting to float works around it for now:
>>> a
tensor([ 100, 3, 100])
>>> a.float()
tensor([ 100., 3., 100.])
>>> d
Dropout(p=0.5)
>>> d(a.float())
tensor([ 200., 0., 0.]) # works fine
>>> d(a)
Floating point exception # throws error
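A sketch of the float-cast route, assuming a recent PyTorch release; the seed only makes the run repeatable. It checks the property visible in the session output: every element is either zeroed or multiplied by 1 / (1 - p), which is 2 for p = 0.5.

```python
import torch

torch.manual_seed(0)
a = torch.tensor([100, 3, 100])   # integer literals default to torch.int64 (long)
d = torch.nn.Dropout(p=0.5)       # a freshly built module is in training mode
out = d(a.float())                # run dropout on a floating-point copy

# Each element is either dropped (0) or scaled by 1 / (1 - p) = 2
print(out)
```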
Temporary workaround:
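import torch

dropout_p = 0.1
dropout = torch.nn.Dropout(p=dropout_p)
if self.training:  # this snippet lives inside an nn.Module's forward()
    x_ = dropout(inputs.float())  # dropout on a float copy of the long input
    print(x_)
    # multiply by (1 - p) to undo dropout's 1/(1 - p) scaling;
    # round() removes floating-point error before casting back to long
    inputs = torch.round(x_.mul(1 - dropout_p)).long()
    print(inputs)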
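If the goal is only to zero out some long entries (which is what the workaround achieves by undoing the 1/(1 - p) scaling), the float round trip can be skipped by sampling the keep-mask directly. A minimal sketch, assuming a recent PyTorch release with torch.bool masks; integer_dropout is a hypothetical name, not a PyTorch API. It drops each element with probability p like nn.Dropout, but leaves kept values unscaled and keeps the input dtype.

```python
import torch

def integer_dropout(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Hypothetical helper: zero each element with probability p, no rescaling.

    Works directly on integer tensors, so no float conversion or rounding is needed.
    """
    if not training or p == 0.0:
        return x
    # torch.rand samples from [0, 1), so each entry is True with probability 1 - p
    keep = torch.rand(x.shape, device=x.device) >= p
    # masked_fill preserves the dtype of x (long stays long)
    return x.masked_fill(~keep, 0)

torch.manual_seed(0)
ids = torch.tensor([100, 3, 100])
dropped = integer_dropout(ids, p=0.5)
print(dropped)
```

Passing self.training as the training argument gives the same train/eval switch as the original snippet.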