Tensor must be 4-D with last dim 1, 3, or 4: bug notes

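Pitfalls while learning torch (1)
Writing down the first pitfall so I don't forget it.
After torchvision.utils.make_grid processes a batch of images, the result is a 3-D tensor (C x H x W). To pass it to tf.summary.image, the channel axis has to be moved last and the batch-size dimension has to be added back; the function for adding that dimension is tf.expand_dims.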

Notes:

1. torchvision.utils.make_grid
Input: Tensor of shape (B x C x H x W)
Output: one large image made by tiling the input images; it is 3-D, with no batch-size dimension
Like this, with batch size 16: [figure not available]

The following explanation is from the official documentation.
Link: http://pytorch.org/docs/master/torchvision/utils.html
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

Purpose: Make a grid of images.
Parameters:

    tensor (Tensor or list) – 4D mini-batch Tensor of shape (B x C x H x W) or a list of images all of the same size.
    nrow (int, optional) – Number of images displayed in each row of the grid. The Final grid size is (B / nrow, nrow). Default is 8.
    padding (int, optional) – amount of padding. Default is 2.
    normalize (bool, optional) – If True, shift the image to the range (0, 1), by subtracting the minimum and dividing by the maximum pixel value.
    range (tuple, optional) – tuple (min, max) where min and max are numbers, then these numbers are used to normalize the image. By default, min and max are computed from the tensor.
    scale_each (bool, optional) – If True, scale each image in the batch of images separately rather than the (min, max) over all images.
    pad_value (float, optional) – Value for the padded pixels.
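
To see why the result has no batch dimension, the grid's shape can be worked out by hand. The helper below is a hypothetical pure-Python sketch (grid_shape is not a torchvision function). It assumes more than one 3-channel image in the batch, each cell padded by `padding` pixels, and one extra padding border on the grid.

```python
import math


def grid_shape(b, c, h, w, nrow=8, padding=2):
    """Estimate the (C, H, W) shape of a grid built from a B x C x H x W batch.

    Hypothetical helper for illustration only; assumes b > 1.
    """
    cols = min(nrow, b)             # images per grid row
    rows = math.ceil(b / cols)      # number of grid rows
    grid_h = rows * (h + padding) + padding
    grid_w = cols * (w + padding) + padding
    return (c, grid_h, grid_w)


shape = grid_shape(16, 3, 32, 32)
print(shape)  # (3, 70, 274)
```

Sixteen 32 x 32 images therefore collapse into one 3-D image, which is why the batch axis has to be added back before the result can be logged as an image summary.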

2. For the record, another function, torchvision.utils.save_image:
torchvision.utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)
Save a given Tensor into an image file.
Parameters:

    tensor (Tensor or list) – Image to be saved. If given a mini-batch tensor, saves the tensor as a grid of images by calling make_grid.
    **kwargs – Other arguments are documented in make_grid.

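3. tf.summary.image:
Official documentation link:
http://www.tensorfly.cn/tfdoc/api_docs/python/train.html#image_summary
(The linked page documents the older name, tf.image_summary; in TensorFlow 1.x the same op is tf.summary.image.)

tf.image_summary(tag, tensor, max_images=None, collections=None, name=None)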

Outputs a Summary protocol buffer with images.

The summary has up to max_images summary values containing images. The images are built from tensor which must be 4-D with shape [batch_size, height, width, channels] and where channels can be:

    1: tensor is interpreted as Grayscale.
    3: tensor is interpreted as RGB.
    4: tensor is interpreted as RGBA.
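
The shape rule above can be written as a small check, which makes it easy to see which inputs trigger the error in the title. This is a hypothetical pure-Python sketch: check_image_summary_shape is not a TensorFlow function, and the message only imitates the real error text.

```python
def check_image_summary_shape(shape):
    """Raise ValueError unless shape is [batch, height, width, channels]
    with channels in {1, 3, 4}. Hypothetical helper for illustration."""
    shape = tuple(shape)
    if len(shape) != 4 or shape[-1] not in (1, 3, 4):
        raise ValueError(
            "Tensor must be 4-D with last dim 1, 3, or 4, not %s" % (list(shape),)
        )
    return shape


# A grid with a batch axis and channels last is accepted.
check_image_summary_shape((1, 70, 274, 3))

# A 3-D grid (no batch axis) is rejected.
try:
    check_image_summary_shape((70, 274, 3))
except ValueError as err:
    print(err)
```

A channels-first array such as (1, 3, 70, 274) fails the same check, because its last dimension is the width rather than the channel count.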

4. tf.expand_dims adds a dimension:
Example from the official documentation:

# 't' is a tensor of shape [2]
shape(expand_dims(t, 0)) ==> [1, 2]
shape(expand_dims(t, 1)) ==> [2, 1]
shape(expand_dims(t, -1)) ==> [2, 1]

# 't2' is a tensor of shape [2, 3, 5]
shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]

Args:
input: A Tensor.
dim: A Tensor. Must be one of the following types: int32, int64. 0-D (scalar). Specifies the dimension index at which to expand the shape of input.
name: A name for the operation (optional).

Returns:
A Tensor. Has the same type as input. Contains the same data as input, but its shape has an additional dimension of size 1 added.
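
The same reshaping can be tried without a TensorFlow session, since np.expand_dims follows the same axis convention for these cases. The sketch below uses NumPy only. to_summary_batch is a hypothetical helper, the random array stands in for a make_grid result, and the values are assumed to be in [0, 1] already.

```python
import numpy as np


def to_summary_batch(grid_chw):
    """Turn a (C, H, W) float grid in [0, 1] into a (1, H, W, C) uint8 array.

    Hypothetical helper for illustration only.
    """
    hwc = np.transpose(grid_chw, (1, 2, 0))               # move channels last
    hwc = np.clip(hwc * 255.0, 0, 255).astype(np.uint8)   # scale to 0..255
    return np.expand_dims(hwc, 0)                         # add batch axis in front


grid = np.random.rand(3, 70, 274).astype(np.float32)  # stand-in for make_grid output
batch = to_summary_batch(grid)
print(batch.shape, batch.dtype)  # (1, 70, 274, 3) uint8
```

The clip guards against values slightly outside [0, 1], which would otherwise wrap around when cast to uint8.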

Example (important!):

# coding=utf-8


from __future__ import print_function
from six.moves import range

import torch.backends.cudnn as cudnn
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import os
import time
import torchvision
from PIL import Image, ImageFont, ImageDraw
from copy import deepcopy
import tensorflow as tf
from torch.utils.data import DataLoader, Dataset
# from miscc.config import cfg
# from miscc.utils import mkdir_p
# from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

sess = tf.InteractiveSession()
# def test():
# torchvision datasets return PIL images; ToTensor converts them to tensors in [0, 1].
# Normalize then maps each channel to the range [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ])

# Training set: load the cifar-10-batches-py folder under ./test (50000 training images). With download=True the data would be downloaded and unpacked automatically.
trainset = torchvision.datasets.CIFAR10(root='./test', train=True, download=False, transform=transform)

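# Split the 50000 training images into mini-batches of 3 images each (batch_size=3). shuffle=False keeps the order fixed (shuffle=True would reshuffle every epoch). num_workers=2 loads the data with two worker subprocesses.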
trainloader = torch.utils.data.DataLoader(trainset, batch_size=3, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print ("len(trainset)",len(trainset))
print ("len(trainloader)",len(trainloader))
print("end1")

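# The code above is adapted from Jianshu: https://www.jianshu.com/p/8da9b24b2fb6
for i, data in enumerate(trainloader, 0):
    # data is [images, labels]; data[0] is the image batch of shape (B, C, H, W).
    # print(data[0][0])
    # img = transforms.ToPILImage()(data[0][0])
    # img.show()
    # break

    # normalize=True rescales the [-1, 1] pixels to [0, 1] so that * 255 fits in uint8.
    real_img_set = vutils.make_grid(data[0], normalize=True).numpy()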
    # print("real_img_set1",real_img_set)
    print("real_img_set_make_grid.shape",real_img_set.shape)
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    # print("real_img_set_transpose",real_img_set)
    print("real_img_set_transpose.shape",real_img_set.shape)
    real_img_set = real_img_set * 255
    # print("real_img_set255",real_img_set)
    print("real_img_set255.shape",real_img_set.shape)
    real_img_set = real_img_set.astype(np.uint8)
    print("real_img_setend",real_img_set)
    print("real_img_setend.shape",real_img_set.shape)

    super_real_img_set = tf.expand_dims(real_img_set, 0)
    print ("super_real_img_old", super_real_img_set)
    print ("super_real_img_old shape", super_real_img_set.shape)
    print ("super_real_img_old [0]", super_real_img_set[0])
    sup_real_img = tf.summary.image('real_img', super_real_img_set)
    print("sup_real_img", sup_real_img)
    print("sup_real_img shape", sup_real_img.shape)
    sup_real_img_new = sess.run(sup_real_img)
    # summary_writer.add_summary(sup_real_img_new, count)

    break