MXNET深度學習框架-29-語義分割(FCN)

          語義分割計算機視覺領域中的一個重要模塊,它與之前的圖像分類、目標檢測任務不同,它是一個精細到每個像素塊的一個圖像任務,當然,它包括分類和定位。更多關於語義分割的方法和原理請讀者自行搜索相關論文和博文,本文不做過多闡述。
在這裏插入圖片描述
1、數據集
          在計算機視覺領域,Pascal VOC 2012 數據集是比較經典的,該數據集包含了目標檢測、對象分割的數據及相關標籤,本文使用的語義分割數據集就是來自於VOC2012。首先,需要下載該數據集:官網下載鏈接,注意,該數據集大概有2個G的大小,建議提前下載。
在這裏插入圖片描述
在這裏插入圖片描述

2、讀取數據集
寫一段程序來讀取並顯示一下相關數據圖片及其標籤:

def show_images(imgs, num_rows, num_cols, scale=2):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i][j].imshow(imgs[i * num_cols + j].asnumpy())
            axes[i][j].axes.get_xaxis().set_visible(False)
            axes[i][j].axes.get_yaxis().set_visible(False)
    plt.show()


root='F:/test/VOC2012_dataset/VOCtrainval_11-May-2012/VOCdevkit/VOC2012'
# 1、將訓練圖片和標註圖標讀進內存
def read_images(root,train=True):
    if train:
        txt_fname=root+"/ImageSets/Segmentation/train.txt"
    else:
        txt_fname = root + "/ImageSets/Segmentation/val.txt"
    with open(txt_fname,'r') as f:
        images=f.read().split()

    features, labels = [None] * len(images), [None] * len(images)

    for i, fname in enumerate(images):
        features[i] = image_deal.imread('%s/JPEGImages/%s.jpg' % (root, fname))
        labels[i] = image_deal.imread(
            '%s/SegmentationClass/%s.png' % (root, fname))
    return features, labels

n=5 #顯示前5張圖片
train_features,train_labels=read_images(root)
show_images(train_features[0:n] + train_labels[0:n],2,5)
for im in train_features[0:n]:
    print(im.shape)

顯示結果:
在這裏插入圖片描述
在這裏插入圖片描述
          我們可以看到,每個圖片均對應一個分割的圖像,這個分割的圖像就是標籤(實際上是對一張圖片的所以像素塊分別賦予了一個label)。在打印的圖像大小信息中我們也可以看到,每張圖片的大小可能並不是一樣的,在CNN訓練中,爲了批量訓練,我們都會把它resize成同樣的大小,但是這裏不行,爲什麼?因爲它是對每個pixel做了標籤,如果resize(比如插值resize)之後,這樣會出來一個新的pixel,那麼新出來的pixel可能是介於兩邊的中間值,這就導致標籤信息無法匹配上,結果不準。
2、裁剪
          無法resize,那怎麼辦?我們可以注意到,每張圖片的寬度好像都是500,爲了能批量訓練,我們可以使用剪切的方法來使得它們成爲一樣的大小。

def rand_crop(image, label, height, width):
    data, rect = image_deal.random_crop(image, (width, height))
    label = image_deal.fixed_crop(label, *rect)
    return data, label
imgs=[]
for _ in range(3):
    imgs += rand_crop(train_features[0], train_labels[0], 200, 300)
show_images(imgs,3,2)

結果:
在這裏插入圖片描述
3、每個物體和背景對應的RGB值(官方網站給的)

colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
              [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
              [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
              [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
              [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
              [0, 64, 128]]
colorclass = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
            'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

4、查找pixel中每個類別的像素索引

colormap2label = nd.zeros(256 ** 3)
for i, color_map in enumerate(colormap):
    colormap2label[(color_map[0] * 256 + color_map[1]) * 256 + color_map[2]] = i


def label_indices(colormap):
    colormap = colormap.astype('int32')
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]


y = label_indices(train_labels[0])
print(y[105:115, 130:140])

結果:
          可以看到是飛機的那塊地方均被標記成了1,而背景則被標記成了0。爲什麼飛機是1?因爲飛機排在第2個,那麼索引就是1:在這裏插入圖片描述
在這裏插入圖片描述
5、圖像預處理

rgb_mean=nd.array([0.485,0.456,0.406]) #官方給的
rgb_std=nd.array([0.229,0.224,0.225])

def normlize(image):
    return ((image.astype("float32")/255)-rgb_mean)/rgb_std

6、自定義語義分割數據集類
          在mxnet中,可以通過繼承Gluon提供的Dataset類自定義了語義分割數據集類VOCSegDataset。通過實現__getitem__函數,可以任意訪問數據集中索引爲idx的輸入圖像及其每個像素的類別索引。由於數據集中有些圖像的尺寸可能小於隨機裁剪所指定的輸出尺寸,這些樣本需要通過自定義的filter函數所移除。

class VOCSegDataset(gn.data.Dataset):
    def __init__(self, is_train, crop_size, voc_dir):

        self.crop_size = crop_size
        data, labels = read_images(root=voc_dir, train=is_train)
        data=self.filter(data)
        self.data = [normlize(im) for im in data]
        self.labels = self.filter(labels)
        print('read ' + str(len(self.data)) + ' examples')

  #只保留大於裁剪尺寸的圖片
    def filter(self, images):
        return [img for img in images if (
            img.shape[0] >= self.crop_size[0] and
            img.shape[1] >= self.crop_size[1])]
  # 每讀一張圖片,就把它裁剪,並把label的圖片轉成標籤
    def __getitem__(self, idx):
        data, label = rand_crop(self.data[idx], self.labels[idx],
                                       *self.crop_size)
        data=data.transpose((2,0,1))
        label=label_indices(label)
        return data,label

    def __len__(self):
        return len(self.data)

看看這樣處理之後有多少張圖片,這裏設裁剪的圖片爲(h,w)=(320,480):

input_size=(320,480)
voc_train=VOCSegDataset(True,input_size,root)
voc_test=VOCSegDataset(False,input_size,root)

結果:在這裏插入圖片描述
          可以看到,訓練集有1114張,測試集有1078張,對於這樣的樣本集,使用深度學習方法從頭訓練顯然是不可能的,到了這裏,就要提前想好後面要用到什麼方法了——微調(Fine Turning)

7、定義批量讀取

batch_size=5 
train_data=gn.data.DataLoader(voc_train,batch_size,shuffle=True)
test_data=gn.data.DataLoader(voc_test,batch_size,shuffle=False)
# 查看一下數據維度
for data,label in train_data:
    print(data.shape," ",label.shape)
    break

結果:
在這裏插入圖片描述
可以看到,圖片的維度與我們常見的一樣,滿足NCHW的要求,而標籤則變成了3維的形狀。

全卷積神經網絡(Fully Convolutional Networks,FCN)

          從上面標籤是3維的我們可以知道,分割的任務與分類、檢測的任務完全不一樣,預測的標號不再是一個數字,而是每個pixel。那麼在預測的時候,輸出的結果也應該是一個3維的結果,因爲要與輸入的標籤一一對應,但是,我們知道,CNN都是把一個3維的數據變成一個一維的標量,這顯然與分割預測的結果不同,對於這種情況,轉置卷積(transposed convolution)出現了,全卷積網絡通過轉置卷積層將中間層特徵圖的高和寬變換回輸入圖像的尺寸,從而令預測結果與輸入圖像在空間維(高和寬)上一一對應:給定空間維上的位置,通道維的輸出即該位置對應像素的類別預測。(FCN就是在Forward時把維度變小,在Backward時把維度變大,卷積本身就是一個對偶函數,卷積的導數的導數還是卷積自己)

          舉個例子:一個(3,320,480)的圖片,通過步長爲2,padding爲1的卷積之後,它的形狀變成(3,160,240),那麼,通過轉置卷積之後,它的形狀變成(3,320,480),這樣就進行了還原。

1、轉置卷積
          gluon中已經實現了轉置卷積的函數:

# 除了替換輸出的通道數以外,其餘的參數都不變,可以將輸出還原爲輸入的大小
conv=gn.nn.Conv2D(channels=10,kernel_size=4,strides=2,padding=1)
conv_trans=gn.nn.Conv2DTranspose(channels=3,kernel_size=4,strides=2,padding=1) #圖片最開始輸入通道數爲3
conv.initialize()
conv_trans.initialize()
x=nd.random_normal(shape=(1,3,16,16))
print(conv(x).shape," ",conv_trans(conv(x)).shape)

結果:
在這裏插入圖片描述
          可以看到,我們定義了輸入是(1,3,16,16),通過卷積層之後,大小變成了(1,10,8,8),再將結果通過轉置卷積之後,結果變成了(1,3,16,16),與輸入的維度一樣。

          另外需要注意的是,在最後的卷積層我們同樣使用Flatten或GAP來使得數據偏平化,使其能輸入到FC中,而這樣操作會損害空間信息,這對語義分割非常重要,其中一個解決辦法是去掉不需要的池化層,並利用1X1的卷積層來替代FC。所以給定一個FCN,它需要做以下工作:
                              1)利用1X1的卷積替代FC;
                              2)去掉損失空間信息的池化層;
                              3)最後接上卷積轉置層來得到需要輸出的大小;
                              4)爲了訓練更快,可使用微調的辦法(數據量大可以忽略這一條)。

2、下載預訓練模型(ResNet18,可自行選擇預訓練模型)

pretrained_model=models.resnet18_v2(pretrained=True)
print(pretrained_model.features[-4:],pretrained_model.output) #打印看看最後幾層

結果:
在這裏插入圖片描述
從上圖的結果來看,我們是不需要GAP和FC的,所以要把它替換掉。

3、添加訓練好的權重信息,並修改GAP和FC

net=gn.nn.Sequential()
for layer in pretrained_model.features[:-2]:
    net.add(layer)
m=nd.random_normal(shape=(1,3,320,480))
print("input shape:",m.shape)
print("out shape:",net(m).shape)

# 添加1X1卷積和轉置卷積
num_class=len(colorclass)# 幾個類別
with net.name_scope():
    net.add(gn.nn.Conv2D(num_class,1,1),
            gn.nn.Conv2DTranspose(num_class,kernel_size=64,padding=16,strides=32)) # kernel_size最好大於32,padding=(kernel_size-strides)/2

結果:
在這裏插入圖片描述
          可以看到,一張(3,320,480)通過預訓練模型之後,輸出的維度是(512,10,15),其中,寬高均縮小了32倍,爲什麼要明確到32倍?以爲最後轉置卷積還原的時候會用到這個係數。

4、訓練
          因爲上面我們把轉置卷積的核大小設成了64,所以很難訓練,而轉置卷積我們可以把它看成是插值的操作,在實際操作中發現,把轉置卷積初始化從雙線性插值函數可以使得訓練更加容易。

# 雙線性插值
def bilinear_kernel(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return nd.array(weight)

# 初始化
net[-2].initialize(init=init.Xavier()) # 1X1 conv
net[-1].initialize(init.Constant(bilinear_kernel(in_channels=num_class,out_channels=num_class,kernel_size=64)),ctx=ctx)

之前做分類的時候,使用了gn.loss.SoftmaxCrossEntropyLoss()這個函數,它默認會把結果Flatten化,所以,這裏要加入axis=1的命令,其它的訓練都是一樣的。

cross_loss=gn.loss.SoftmaxCrossEntropyLoss(axis=1)

net.collect_params().reset_ctx(ctx)
trainer=gn.Trainer(net.collect_params(),'sgd',{"learning_rate":0.1,"wd":1e-3})

# 定義準確率
def accuracy(output,label):
    return nd.mean(output.argmax(axis=1)==label).asscalar()

def evaluate_accuracy(data_iter, net,ctx):
    acc_sum, n = 0.0, 0
    for features,label in data_iter:
        features = features.as_in_context(ctx)
        label = label.as_in_context(ctx)
        output=net(features)
        acc_sum+=accuracy(output,label)
        n += 1
    return acc_sum/ n

def train(train_iter, test_iter, net, cross_loss, trainer, ctx, num_epochs):

    print('training on', ctx)
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n,start = 0.0, 0.0, 0, time.time()
        for Xs, ys in train_iter:
            Xs=Xs.as_in_context(ctx)
            ys=ys.as_in_context(ctx)
            with ag.record():
                output = net(Xs)
                loss = nd.mean(cross_loss(output, ys))
            loss.backward()
            trainer.step(batch_size)
            train_l_sum += loss.asscalar()
            train_acc_sum+=accuracy(output,ys)
            n += 1
        test_acc = evaluate_accuracy(test_iter, net,ctx)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f ''time %.1f sec'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n,test_acc,
                 time.time() - start))
train(train_data,test_data,net,cross_loss,trainer,ctx,10)

訓練結果:
在這裏插入圖片描述
訓練完了別忘了保存一下模型,好不容易訓練的別讓它丟了:

net.save_parameters("FCN.params")# 保存模型

5、預測
預測的時候與圖像分類類似,主要的區別在於我們需要在axis=1上做argmax,同時我們定義img2label的反函數,它將預測值轉成圖片。

net.load_parameters("FCN.params") #讀取參數
def predict(image):
    data=normlize(image)
    data=data.transpose((2,0,1)).expand_dims(axis=0)
    y_hat=net(data.as_in_context(ctx))
    pred=nd.argmax(y_hat,axis=1)
    return pred.reshape((pred.shape[1],pred.shape[2]))

def label2img(pred,colormap):
    cm = nd.array(colormap, ctx=ctx, dtype='uint8')
    X = pred.astype('int32')
    return cm[X, :]

# 讀取測試集前幾張圖片做預測
test_image,test_label=read_images(root,False)

image_6=[]
for i in range(6):
    x=test_image[i]
    pred=label2img(predict(x),colormap)
    image_6+=[x,pred,test_label[i]]
show_images(image_6,6,3)

結果:
在這裏插入圖片描述
上圖中,中間的是預測結果,最右邊的是真實結果,預測結果離真實結果還有一段距離,可以多加幾個epoch跑跑看。

下面附上所有源碼:

import mxnet.ndarray as nd
import mxnet.gluon as gn
import mxnet.autograd as ag
import mxnet.initializer as init
import mxnet.image as image_deal
import matplotlib.pyplot as plt
import numpy as np
from mxnet.gluon.model_zoo import vision as models
import time
import mxnet as mx

# 顯示圖片
def show_images(imgs, num_rows, num_cols, scale=2):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i][j].imshow(imgs[i * num_cols + j].asnumpy())
            axes[i][j].axes.get_xaxis().set_visible(False)
            axes[i][j].axes.get_yaxis().set_visible(False)
    plt.show()


root = 'F:/test/VOC2012_dataset/VOCtrainval_11-May-2012/VOCdevkit/VOC2012'


# 1、將訓練圖片和標註圖標讀進內存
def read_images(root, train=True):
    if train:
        txt_fname = root + "/ImageSets/Segmentation/train.txt"
    else:
        txt_fname = root + "/ImageSets/Segmentation/val.txt"
    with open(txt_fname, 'r') as f:
        images = f.read().split()

    features, labels = [None] * len(images), [None] * len(images)

    for i, fname in enumerate(images):
        features[i] = image_deal.imread('%s/JPEGImages/%s.jpg' % (root, fname))
        labels[i] = image_deal.imread(
            '%s/SegmentationClass/%s.png' % (root, fname))
    return features, labels


n = 5  # 顯示前5張圖片
train_features, train_labels = read_images(root)
show_images(train_features[0:n] + train_labels[0:n], 2, 5)
# for im in train_features:
#     print(im.shape)


# 2、裁剪
def rand_crop(image, label, height, width):
    data, rect = image_deal.random_crop(image, (width, height))
    label = image_deal.fixed_crop(label, *rect)
    return data, label


imgs = []
for _ in range(3):
    imgs += rand_crop(train_features[0], train_labels[0], 200, 300)
# print(imgs)
show_images(imgs, 3, 2)

# 3、每個物體和背景對應的RGB值
colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
              [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
              [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
              [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
              [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
              [0, 64, 128]]
colorclass = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
            'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

print(len(colorclass))  # 共21類

# 4、查找pixel中每個類別的像素索引
colormap2label = nd.zeros(256 ** 3)
for i, color_map in enumerate(colormap):
    colormap2label[(color_map[0] * 256 + color_map[1]) * 256 + color_map[2]] = i


def label_indices(colormap):
    colormap = colormap.astype('int32')
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]


y = label_indices(train_labels[0])
print(y[105:115, 130:140])

#5、圖像預處理
rgb_mean=nd.array([0.485,0.456,0.406]) #官方給的
rgb_std=nd.array([0.229,0.224,0.225])

def normlize(image):
    return ((image.astype("float32")/255)-rgb_mean)/rgb_std

# for im in train_features:
#     print(normlize(im))
#6、自定義語義分割數據集類
class VOCSegDataset(gn.data.Dataset):
    def __init__(self, is_train, crop_size, voc_dir):

        self.crop_size = crop_size
        data, labels = read_images(root=voc_dir, train=is_train)
        data=self.filter(data)
        self.data = [normlize(im) for im in data]
        self.labels = self.filter(labels)
        print('read ' + str(len(self.data)) + ' examples')

  #只保留大於裁剪尺寸的圖片
    def filter(self, images):
        return [img for img in images if (
            img.shape[0] >= self.crop_size[0] and
            img.shape[1] >= self.crop_size[1])]
  # 每讀一張圖片,就把它裁剪,並把label的圖片轉成標籤
    def __getitem__(self, idx):
        data, label = rand_crop(self.data[idx], self.labels[idx],
                                       *self.crop_size)
        data=data.transpose((2,0,1))
        label=label_indices(label)
        return data,label

    def __len__(self):
        return len(self.data)
input_size=(320,480)
voc_train=VOCSegDataset(True,input_size,root)
voc_test=VOCSegDataset(False,input_size,root)

#7、定義批量讀取
batch_size=5
train_data=gn.data.DataLoader(voc_train,batch_size,shuffle=True)
test_data=gn.data.DataLoader(voc_test,batch_size,shuffle=False)
# 查看一下數據維度
for data,label in train_data:
    print(data.shape," ",label.shape)
    break

'''---------FCN---------'''
#1、轉置卷積實例
# 除了替換輸出的通道數以外,其餘的參數都不變,可以將輸出還原爲輸入的大小
conv=gn.nn.Conv2D(channels=10,kernel_size=4,strides=2,padding=1)
conv_trans=gn.nn.Conv2DTranspose(channels=3,kernel_size=4,strides=2,padding=1) #圖片最開始輸入通道數爲3
conv.initialize()
conv_trans.initialize()
x=nd.random_normal(shape=(1,3,16,16))
print(conv(x).shape," ",conv_trans(conv(x)).shape)

# 2、下載預訓練模型resnet18
ctx=mx.gpu(0)
pretrained_model=models.resnet18_v2(pretrained=True,ctx=ctx)
print(pretrained_model.features[-4:],pretrained_model.output)

#3、添加訓練好的權重信息,並修改GAP和FC
net=gn.nn.Sequential()
for layer in pretrained_model.features[:-2]:
    net.add(layer)
m=nd.random_normal(shape=(1,3,320,480),ctx=ctx)
print("input shape:",m.shape)
print("out shape:",net(m).shape)

# 添加1X1卷積和轉置卷積
num_class=len(colorclass)# 幾個類別
with net.name_scope():
    net.add(gn.nn.Conv2D(num_class,1,1),
            gn.nn.Conv2DTranspose(num_class,kernel_size=64,padding=16,strides=32)) # kernel_size最好大於32,padding=(kernel_size-strides)/2

# 4、訓練
# 雙線性插值
def bilinear_kernel(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return nd.array(weight)

# 初始化
net[-2].initialize(init=init.Xavier(),ctx=ctx) # 1X1 conv

net[-1].initialize(init.Constant(bilinear_kernel(in_channels=num_class,out_channels=num_class,
                                                 kernel_size=64)),ctx=ctx)

cross_loss=gn.loss.SoftmaxCrossEntropyLoss(axis=1)

net.collect_params().reset_ctx(ctx)
trainer=gn.Trainer(net.collect_params(),'sgd',{"learning_rate":0.1,"wd":1e-3})

# 定義準確率
def accuracy(output,label):
    return nd.mean(output.argmax(axis=1)==label).asscalar()

def evaluate_accuracy(data_iter, net,ctx):
    acc_sum, n = 0.0, 0
    for features,label in data_iter:
        features = features.as_in_context(ctx)
        label = label.as_in_context(ctx)
        output=net(features)
        acc_sum+=accuracy(output,label)
        n += 1
    return acc_sum/ n

def train(train_iter, test_iter, net, cross_loss, trainer, ctx, num_epochs):

    print('training on', ctx)
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n,start = 0.0, 0.0, 0, time.time()
        for Xs, ys in train_iter:
            Xs=Xs.as_in_context(ctx)
            ys=ys.as_in_context(ctx)
            with ag.record():
                output = net(Xs)
                loss = nd.mean(cross_loss(output, ys))
            loss.backward()
            trainer.step(batch_size)
            train_l_sum += loss.asscalar()
            train_acc_sum+=accuracy(output,ys)
            n += 1
        test_acc = evaluate_accuracy(test_iter, net,ctx)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f ''time %.1f sec'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n,test_acc,
                 time.time() - start))
# train(train_data,test_data,net,cross_loss,trainer,ctx,10)
# net.save_parameters("FCN.params")# 保存模型

# 5、預測
net.load_parameters("FCN.params") #讀取參數
def predict(image):
    data=normlize(image)
    data=data.transpose((2,0,1)).expand_dims(axis=0)
    y_hat=net(data.as_in_context(ctx))
    pred=nd.argmax(y_hat,axis=1)
    return pred.reshape((pred.shape[1],pred.shape[2]))

def label2img(pred,colormap):
    cm = nd.array(colormap, ctx=ctx, dtype='uint8')
    X = pred.astype('int32')
    return cm[X, :]

# 讀取測試集前幾張圖片做預測
test_image,test_label=read_images(root,False)

image_6=[]
for i in range(6):
    x=test_image[i]
    pred=label2img(predict(x),colormap)
    image_6+=[x,pred,test_label[i]]
show_images(image_6,6,3)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章