Pytorch基础函数(二)变形函数——view()与view_as()

view()函数是用于对Tensor(张量)进行形状变化的函数,如一个Tensor的size是3x2,可以通过view()函数将其形状转为2x3。

 

但是需要注意的是进行操作的张量必须是contiguous()的,即在在内存中连续的。(不连续可以使用tensor.contiguous()操作转为连续)。

 

一、view()函数基本操作

函数定义:view(*args)   [Ps: 参数的乘积积必须与原Tensor的元素个数保持一致,也就是转换前后元素数量不变]

将进行view()操作的Tensor其形状根据参数args进行改变。

例如:

t1=torch.Tensor([[1,2,3],[4,5,6]]).long()
print(t1)
print("-------view()-------")
t2=t1.view(3,2)
print(t2)

输出:

二、view()函数:-1参数的应用

可以对view()中的一个参数设置为-1。 若是对目标张量的某一维度不明、待定、视情况而变、懒得计算等等,可以使用-1参数进行操作。函数会自动计算-1参数对应维度的值。

例如:

data=[[[1,2],[3,4],[5,6]],[[7,8],[9,0],[10,11]]]
t1=torch.Tensor(data).long()#size=2,3,2
print(t1)
print(t1.size())
print("-------view()-------")
t2=t1.view(3,2,-1)
print(t2)
print(t2.size())

结果:

可以看到,对于设置参数为-1的最后一维,自动计算为了2。

三、view_as()函数

函数定义:view_as(tensor) [参数为一个Tensor张量]

该函数的作用是将调用函数的变量,转变为同参数tensor同样的形状。

例如:

data1 = [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 0], [10, 11]]]
t1 = torch.Tensor(data1).long()  # size=2,3,2
data2 = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 0], [10, 11]]]
t2 = torch.Tensor(data2).long()
print(t1.size())
print(t2.size())
print("-------view_as()-------")
t2=t2.view_as(t1)
print(t2)
print(t2.size())

输出结果:

可以看出经过view_as()操作后,t2 Tensor转变为了与t1 相同的形状。(需要重新对t2赋值,这是因为不是进行的原地操作)

-------end------

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章