pytorch torch.chunk(tensor, chunks, dim)

2. torch.chunk(tensor, chunks, dim)

說明:在給定的維度上講張量進行分塊。

參數

  • tensor(Tensor) -- 待分塊的輸入張量
  • chunks(int) -- 分塊的個數
  • dim(int) -- 維度,沿着此維度進行分塊
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 1.0103,  2.3358, -1.9236],
        [-0.3890,  0.6594,  0.6664],
        [ 0.5240, -1.4193,  0.1681]])
>>> torch.chunk(x, 3, dim=0)
(tensor([[ 1.0103,  2.3358, -1.9236]]), tensor([[-0.3890,  0.6594,  0.6664]]), tensor([[ 0.5240, -1.4193,  0.1681]]))
>>> torch.chunk(x, 3, dim=1)
(tensor([[ 1.0103],
        [-0.3890],
        [ 0.5240]]), tensor([[ 2.3358],
        [ 0.6594],
        [-1.4193]]), tensor([[-1.9236],
        [ 0.6664],
        [ 0.1681]]))
>>> torch.chunk(x, 2, dim=1)
(tensor([[ 1.0103,  2.3358],
        [-0.3890,  0.6594],
        [ 0.5240, -1.4193]]), tensor([[-1.9236],
        [ 0.6664],
        [ 0.1681]]))

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章