Pytorch中的view()和reshape()有何不同?

Pytorch中的view()和reshape()的功能都是reshape tensor:

import torch
x = torch.arange(10)

x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)

其区别是:

  • view()要求tensor必须是Contiguous Memory,遇到noncontiguous memory会报错!
  • reshape()没有上述要求,在操作Contiguous Memory时,性能比view()稍差
import torch
x = torch.arange(10)
# contiguous memory
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)

# noncontiguous memory
y = x_2x5.t()
y_1x10 = y.view(10)

报错信息:

y_1x10 = y.view(10)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

解决方式:用Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor
方法,将noncontiguous memory变成contiguous memory,然后再用view()

import torch
x = torch.arange(10)
# contiguous memory
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)

# noncontiguous memory
y = x_2x5.t()
y_1x10 = y.contiguous().view(10)
print(y_1x10.shape)

执行结果:

torch.Size([10])

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章