What is the difference between view() and reshape() in PyTorch?

In PyTorch, both view() and reshape() change the shape of a tensor:

import torch
x = torch.arange(10)

x_2x5 = x.view(2, 5)      # reshape to 2 rows x 5 columns
print(x_2x5)
x_5x2 = x.reshape(5, 2)   # reshape to 5 rows x 2 columns
print(x_5x2)

The differences are:

  • view() requires the tensor to be stored in contiguous memory; calling it on a noncontiguous tensor raises an error.
  • reshape() has no such requirement: it returns a view when possible and makes a copy otherwise. On contiguous memory its performance is slightly worse than view().
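Whether view() will succeed can be checked up front with Tensor.is_contiguous(). A minimal sketch (the transpose here is only an illustration of how a noncontiguous tensor arises):

```python
import torch

x = torch.arange(10)
x_2x5 = x.view(2, 5)
print(x_2x5.is_contiguous())      # the freshly reshaped tensor is contiguous
print(x_2x5.t().is_contiguous())  # t() only swaps strides, so the result is not
```

t() returns a transposed view over the same storage without moving any data, which is exactly why the memory layout is no longer contiguous.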
import torch
x = torch.arange(10)
# contiguous memory: both calls succeed
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)

# noncontiguous memory: t() returns a transposed view, so view() fails
y = x_2x5.t()
y_1x10 = y.view(10)  # raises RuntimeError

Error message:

y_1x10 = y.view(10)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
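As the error message suggests, reshape() itself handles this case by silently making a copy. A minimal sketch:

```python
import torch

x = torch.arange(10)
y = x.view(2, 5).t()   # noncontiguous transposed view
# y.view(10) would raise the RuntimeError shown above
z = y.reshape(10)      # reshape() copies the data when a view is impossible
print(z)               # tensor([0, 5, 1, 6, 2, 7, 3, 8, 4, 9])
```

Note that the elements come out in the transposed order, because reshape() reads the tensor in its logical (row-major) order, not the order of the underlying storage.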

Fix: use Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor
to turn the noncontiguous tensor into a contiguous copy, then call view():

import torch
x = torch.arange(10)
# contiguous memory
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)

# noncontiguous memory
y = x_2x5.t()
y_1x10 = y.contiguous().view(10)
print(y_1x10.shape)

Output:

torch.Size([10])
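One practical consequence of the difference: view() always shares storage with the original tensor, while the contiguous() step above makes a copy, so the result no longer aliases the original. A minimal sketch:

```python
import torch

x = torch.arange(10)
v = x.view(2, 5)        # shares storage with x
v[0, 0] = 100
print(x[0])             # x is modified too: tensor(100)

y = v.t().contiguous()  # contiguous() copies the noncontiguous data
print(y.data_ptr() == x.data_ptr())  # False: y lives in new storage
```

So after contiguous().view(...), writes to the result no longer propagate back to the original tensor, which may or may not be what the caller wants.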
