tensorflow與pytorch的移植轉換函數對比表

     相信有了這份表格對比,tensorflow與pytorch的基本移植轉換,應該是手到擒來。

名稱 tensorflow pytorch
二維卷積 tf.nn.conv2d(input_x, w, strides=[1, 1, 1, 1], padding='SAME') torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
relu激活函數 tf.nn.relu(input_x) torch.nn.ReLU()
填充函數 tf.pad(input_x, [(a0,b0), (a1,b1), (a2,b2), (a3,b3)])

pad_num = (a3,b3, a2, b2, a1, b1, a0, b0)

torch.nn.functional.pad(input_x, pad_num, mode='constant')

元素個數 tf.size(input_x) torch.numel(input_x)
展平 tf.reshape(input_x, (tf.size(input_x), -1)) input_x.view(torch.numel(input_x), -1)
softmax tf.nn.softmax(input_x, axis=1) torch.nn.functional.softmax(input_x, dim=1)
調整類型 tf.cast(input_x, tf.int32) input_x.type(torch.LongTensor)
除去維度爲1 tf.squeeze(input_x, squeeze_dims=1) torch.squeeze(input_x)
合併 tf.concat((input_x1, input_x2), axis=3) torch.cat((input_x1, input_x2), dim=3)
劃分成相同維度的塊 tf.split(input_x, axis=3, num_or_size_splits=2) torch.chunk(input_x, dim=3, chunks=2)
產生1的矩陣 tf.ones((a,b)) torch.ones(a,b)
重複 tf.tile() input_x.repeat()
交換維度 tf.transpose(input_x, [0,1,2,3]) input_x.permute((0,1,2,3))  注意這種用法:p01.permute((0, 2, 3, 1)).contiguous().view(int(np.prod(shape01)), -1)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章