pytorch+tensorboard可视化最简单例子

尽管pytorch 已经集成了tensorboard的接口,但是你还要下载安装tensorboard工具。

下载tensorboard: pip install tensorboard.    不行的话,再安装 pip install tensorboardX。

tensorboard用网页的方式把很多的信息都展现出来,比较方便。上方image和graph分别代表你训练的数据和你的深度学习网络结构图。

看最简单使用例子:

定义一个学习网络,来分类FashionMNIST,在SummaryWriter的时候,就开始用tensorboard了。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter

def get_num_correct(preds,labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        self.fc1=nn.Linear(in_features=12*4*4,out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
    def forward(self, t):
        t=F.relu(self.conv1(t))
        t=F.max_pool2d(t,kernel_size=2,stride=2)

        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t,kernel_size=2,stride=2)

        t=t.flatten(start_dim=1)
        t=F.relu(self.fc1(t))

        t=F.relu(self.fc2(t))
        t=self.out(t)

        return t

if __name__ == '__main__':
    train_set=torchvision.datasets.FashionMNIST(
        root='./data-source',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor()
        ])
    )

    train_loader=torch.utils.data.DataLoader(train_set,batch_size=100,shuffle=True)

    #tensor board
    tb=SummaryWriter()
    network=Network()
#取出训练用图
    images,labels=next(iter(train_loader))
    grid=torchvision.utils.make_grid(images)
#想用tensorboard看什么,你就tb.add什么。image、graph、scalar等
    tb.add_image('images', grid)
    tb.add_graph(model=network,input_to_model=images)
    tb.close()
    exit(0)

写好代码之后,运行一遍,看有没有错误,有错误的地方tensorboard不会储存也不会显示。

运行之后这个目录下会出现runs目录,里面储存量tensorboard要显示的数据。

然后在这个目录下cmd,指定吧runs目录下的数据在tensorboard显示。

tensorboard --logdir=runs

其就会出现这个,然后直接浏览器访问就行了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章