〖TensorFlow 2.0 Notes 21〗Image classification on a custom dataset (the Pokémon dataset), plus a supplement on tf.where!

Image classification on a custom dataset (the Pokémon dataset), plus a supplement on tf.where!

1. Dataset Introduction and Loading

1.1. A brief description of the dataset

  • We collected video clips of the Pokémon anime and extracted images of 5 Pokémon, each in a variety of forms: 234 images of Pikachu, 239 of Mewtwo, 223 of Squirtle, 238 of Charmander, and 234 of Bulbasaur.
  • Dataset split: all of the images, across every class, are divided according to the ratio below. Note that the ratio is not applied per class (i.e. not 60% of each class); it is applied to the overall pool of 1,168 images (so the 60% training split is 600-odd images).

Note: if the test split is only 10% of the data (i.e. very few test samples), test performance will fluctuate a lot. At 10% here that would be roughly 100-plus images (20-odd per class), which is still workable, but on a very small dataset the test results swing heavily. To make evaluation more reliable, we deliberately enlarge the test and validation splits to 20% each. With roughly 1,000-plus images in total, this can be regarded as a small-to-medium dataset.

1.2. Implementation steps

  • Overall, the work that follows breaks down into 4 major steps:
  • Loading the dataset (important).
  • Building the model.
  • Training, validation, and testing.
  • Transfer learning (mainly for small datasets: if you have little data but still want good performance, transfer learning is very useful. By sharing knowledge learned in another domain, it lets you reach decent performance in this domain with only a small amount of data).

1.3. Format of the loaded data

  • The data is loaded as follows

**Note:** TensorFlow's `map` function is especially important for data preprocessing; here it is what turns the stored image paths into actual images, using TensorFlow's built-in parsing functions.

1.4. Processing data with the map function

  • The core of the processing is shown below:
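A minimal sketch of the map-based pipeline described above (the paths and labels below are hypothetical; in the real pipeline they come from images.csv, and the full preprocess function appears in section 2.2):

```python
import tensorflow as tf

# Hypothetical paths and labels; in the real pipeline they come from images.csv.
paths = ['pokemon/pikachu/00000001.png', 'pokemon/mewtwo/00000002.png']
labels = [3, 1]

def parse(path, label):
    # map() hands us tensors: read the file bytes and decode them into an image.
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [244, 244])
    return img, label

db = tf.data.Dataset.from_tensor_slices((paths, labels))
db = db.map(parse)   # each (path, label) pair becomes an (image, label) pair, lazily
```

Note that map is lazy: no file is actually read until the dataset is iterated, which is why the paths can be parsed into images on the fly during training.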

1.5. The processing flow for a custom dataset

  • The core of the processing is shown below:
  • Encode the classes by name: each class name is mapped to a code in 0-4, and the mapping is stored in a dictionary.
  • The format of the images.csv file: each row holds an image path and its label, e.g. `pokemon\bulbasaur\00000000.png, 0`. Once images.csv has been generated, we never need to rebuild it; later runs simply parse images.csv directly, turning the path in the first column into the image itself.
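The name-to-code step above can be sketched as follows (the folder names are written out here for illustration; in the real code in section 2.2 they come from listing the dataset directory):

```python
# Each class folder name is assigned the next integer code, in sorted order.
name2label = {}
for name in sorted(['pikachu', 'mewtwo', 'squirtle', 'charmander', 'bulbasaur']):
    name2label[name] = len(name2label)

print(name2label)
# → {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
```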

2. Preprocessing the Dataset

2.1. Data Augmentation

  • With data augmentation of this kind, the number of training images can be multiplied. In principle, data augmentation can generate unlimited variations, since the images can be transformed in many different ways. Generally, a left-right flip is a safe choice, and a slight crop can remove some of the border region.
  • For normalization, as above, we rescale values from 0~255 down to 0~1; we have also previously tried rescaling to -1~1. In practice, though, there is a more effective normalization scheme for image datasets.

2.2. The preprocessing code in full

import  os, glob
import  random, csv
import tensorflow as tf

def load_csv(root, filename, name2label):
    """ Load (or first build) the CSV index file.
    :param root:            root directory of the dataset
    :param filename:        name of the csv file
    :param name2label:      class-name -> label encoding table
    :return:
    """
    # Only build the .csv file if it does not exist yet.
    if not os.path.exists(os.path.join(root, filename)):
        images = []
        for name in name2label.keys():
            # e.g. 'pokemon\\mewtwo\\00001.png'
            images += glob.glob(os.path.join(root, name, '*.png'))
            images += glob.glob(os.path.join(root, name, '*.jpg'))
            images += glob.glob(os.path.join(root, name, '*.jpeg'))

        # 1167, 'pokemon\\bulbasaur\\00000000.png'
        print(len(images), images)

        random.shuffle(images)
        with open(os.path.join(root, filename), mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:  # 'pokemon\\bulbasaur\\00000000.png'
                name = img.split(os.sep)[-2]
                label = name2label[name]
                # write the image path and its label: 'pokemon\\bulbasaur\\00000000.png', 0
                writer.writerow([img, label])
            print('written into csv file:', filename)

    # read from csv file
    images, labels = [], []
    with open(os.path.join(root, filename)) as f:
        reader = csv.reader(f)
        for row in reader:
            # 'pokemon\\bulbasaur\\00000000.png', 0
            img, label = row
            label = int(label)

            images.append(img)
            labels.append(label)

    assert len(images) == len(labels)

    return images, labels

# Utility for loading the pokemon dataset.
def load_pokemon(root, mode='train'):
    """ Utility for loading the pokemon dataset.
    :param root:    directory where the dataset is stored
    :param mode:    which split to load: train, val, or test
    :return:
    """
    # Build the integer encoding table, range 0-4;
    name2label = {}  # "squirtle...": 0   class name -> class label. The root holds 5 folders, i.e. 5 classes, coded 0-4.
    for name in sorted(os.listdir(os.path.join(root))):     # list all entries
        if not os.path.isdir(os.path.join(root, name)):
            continue
        # assign the next integer code to this class
        name2label[name] = len(name2label.keys())

    # Read the label information; keep the index file images.csv:
    # [file1, file2, ...] with matching labels [3, 1, ...], two aligned lists.
    # Walking the directories, every image path and its class are stored in the CSV file.
    images, labels = load_csv(root, 'images.csv', name2label)

    # Split the images: 70% train, 15% val, 15% test.
    if mode == 'train':                                                     # 70% training set
        images = images[:int(0.7 * len(images))]
        labels = labels[:int(0.7 * len(labels))]
    elif mode == 'val':                                                     # 15% = 70% -> 85%  validation set
        images = images[int(0.7 * len(images)):int(0.85 * len(images))]
        labels = labels[int(0.7 * len(labels)):int(0.85 * len(labels))]
    else:                                                                   # 15% = 85% -> 100%  test set
        images = images[int(0.85 * len(images)):]
        labels = labels[int(0.85 * len(labels)):]

    return images, labels, name2label

# Data normalization.
# Where do these mean and variance values come from? They are statistics over all of
# the ImageNet images (several million). They are meaningful because, in essence, the
# distribution of most image datasets matches ImageNet's; these 6 numbers are
# more or less universal and easy to find online.
img_mean = tf.constant([0.485, 0.456, 0.406])
img_std = tf.constant([0.229, 0.224, 0.225])
def normalize(x, mean=img_mean, std=img_std):
    # x shape: [224, 224, 3]
    # mean/std have shape [3]; broadcasting applies. Aligning shapes from the right:
    # mean: [3] -> [1, 1, 3]            first insert leading 1s
    # mean: [1, 1, 3] -> [224, 224, 3]  then broadcast up to 224
    x = (x - mean) / std
    return x

# After normalizing, we also need the inverse operation, e.g. to undo
# normalization when visualizing the data.
def denormalize(x, mean=img_mean, std=img_std):
    x = x * std + mean
    return x

def preprocess(x, y):
    # x: path of the image
    # y: integer label of the image
    x = tf.io.read_file(x)                  # read the image from its path
    x = tf.image.decode_jpeg(x, channels=3) # force 3 channels: some images also have an alpha (transparency) channel
    x = tf.image.resize(x, [244, 244])      # resize; 244 leaves room for a 224x224 random crop below (224 matches ResNet's usual input)

    # Data augmentation, still in 0~255. This must happen before normalization (it operates on the image).
    # x = tf.image.random_flip_up_down(x)   # random up-down flip; if every image were flipped nothing would be gained, so only a random subset is flipped
    x = tf.image.random_flip_left_right(x)  # random left-right flip
    # x = tf.image.random_crop(x, [224, 224, 3]) # random crop to 224x224; the resize above must then be larger than 224 (e.g. 244 or 250), otherwise the crop does nothing

    # x: [0,255] => 0~1 (or -0.5~0.5). Then: normalization.
    x = tf.cast(x, dtype=tf.float32) / 255.
    # 0~1 => roughly zero mean, unit variance per channel
    x = normalize(x)

    y = tf.convert_to_tensor(y)

    return x, y

def main():
    import time
    images, labels, table = load_pokemon('/home/zhangkf/tf/TF2/TF2_8_data/pokeman', 'train')
    # image paths
    print('images', len(images), images)
    # image labels
    print('labels', len(labels), labels)
    # the encoding table mapping class names to codes
    print(table)
    # build the dataset
    db = tf.data.Dataset.from_tensor_slices((images, labels))
    # preprocess the dataset
    db = db.shuffle(1000).map(preprocess).batch(32)
    # visualize the images in TensorBoard
    writer = tf.summary.create_file_writer('log')

    for step, (x, y) in enumerate(db):
        # x: [32, 244, 244, 3] here (the random crop above is commented out)
        # y: [32]
        with writer.as_default():
            tf.summary.image('img', x, step=step, max_outputs=9)  # record 9 images at a time
            time.sleep(5)                                         # pause 5 seconds per batch so the display is not too fast

if __name__ == '__main__':
    main()

3. Building the Network Model

3.1. The keras interface in TensorFlow 2.0

  • Creating a model in TensorFlow 2.0 is quite simple, thanks to the keras interface: every model class just inherits from Model, sub-unit classes are built on top of it, and the forward pass simply calls those sub-units in turn to carry out each layer's computation.
  • We have written a ResNet before, in particular resnet18, but that was a trimmed-down version: it may not have had the full 18 layers, and its input was certainly not 224×224; the input, output, and channel counts were all pruned. The version below is comparatively standard. ResNet itself has no single canonical implementation, and some of the intermediate hyperparameters can be tuned from experience.
import  os
import  tensorflow as tf
import  numpy as np
from    tensorflow import keras
from    tensorflow.keras import layers

tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

class ResnetBlock(keras.Model):

    def __init__(self, channels, strides=1):
        super(ResnetBlock, self).__init__()

        self.channels = channels
        self.strides = strides

        # explicit padding [[batch], [height], [width], [channel]]: 1 pixel on each
        # side preserves the spatial size under a 3x3 kernel (like 'same' when strides=1)
        self.conv1 = layers.Conv2D(channels, 3, strides=strides,
                                   padding=[[0,0],[1,1],[1,1],[0,0]])
        self.bn1 = keras.layers.BatchNormalization()
        self.conv2 = layers.Conv2D(channels, 3, strides=1,
                                   padding=[[0,0],[1,1],[1,1],[0,0]])
        self.bn2 = keras.layers.BatchNormalization()

        # when the block downsamples, the identity path needs a matching 1x1 conv
        if strides != 1:
            self.down_conv = layers.Conv2D(channels, 1, strides=strides, padding='valid')
            self.down_bn = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None):
        residual = inputs

        x = self.conv1(inputs)
        x = tf.nn.relu(x)
        x = self.bn1(x, training=training)
        x = self.conv2(x)
        x = tf.nn.relu(x)
        x = self.bn2(x, training=training)

        # residual connection; downsample the identity path when strides != 1
        if self.strides!=1:
            residual = self.down_conv(inputs)
            residual = tf.nn.relu(residual)
            residual = self.down_bn(residual, training=training)

        x = x + residual
        x = tf.nn.relu(x)
        return x

# Implementation of resnet18.
class ResNet(keras.Model):

    def __init__(self, num_classes, initial_filters=16, **kwargs):
        super(ResNet, self).__init__(**kwargs)

        self.stem = layers.Conv2D(initial_filters, 3, strides=3, padding='valid')

        # 8 ResnetBlocks in total: 16 conv layers, plus the stem layer and the output layer.
        # The blocks form 4 groups; the first block of each group reduces height and width,
        # while the second (strides=1) keeps the dimensions unchanged.
        self.blocks = keras.models.Sequential([
            ResnetBlock(initial_filters * 2, strides=3),
            ResnetBlock(initial_filters * 2, strides=1),
            # layers.Dropout(rate=0.5),

            ResnetBlock(initial_filters * 4, strides=3),
            ResnetBlock(initial_filters * 4, strides=1),

            ResnetBlock(initial_filters * 8, strides=2),
            ResnetBlock(initial_filters * 8, strides=1),

            ResnetBlock(initial_filters * 16, strides=2),
            ResnetBlock(initial_filters * 16, strides=1),
        ])

        self.final_bn = layers.BatchNormalization()
        self.avg_pool = layers.GlobalMaxPool2D()        # global pooling (max pooling, despite the attribute name)
        self.fc = layers.Dense(num_classes)             # fully connected classification layer

    def call(self, inputs, training=None):
        # print('x:',inputs.shape)
        out = self.stem(inputs)  # stem layer
        out = tf.nn.relu(out)

        # print('stem:',out.shape)

        out = self.blocks(out, training=training)
        # print('res:',out.shape)

        out = self.final_bn(out, training=training)
        # out = tf.nn.relu(out)

        out = self.avg_pool(out)

        # print('avg_pool:',out.shape)
        out = self.fc(out)  # the classification layer produces the output
        # print('out:',out.shape)
        return out

def main():
    num_classes = 5

    resnet18 = ResNet(num_classes)
    resnet18.build(input_shape=(4, 224, 224, 3))
    resnet18.summary()

if __name__ == '__main__':
    main()

3.2. Number of parameters in the model

  • Output: the network has about 2.8 million parameters, split into trainable and non-trainable. The non-trainable parameters live mainly in the BatchNormalization layers: those layers keep some running statistics that cannot be trained; they are accumulated at run time and take no part in the network's backpropagation. So how do we use this model next?
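A small sketch of why BatchNormalization contributes non-trainable parameters (a standalone layer, separate from the model above):

```python
import tensorflow as tf
from tensorflow.keras import layers

bn = layers.BatchNormalization()
bn.build(input_shape=(None, 224, 224, 64))

# gamma and beta are updated by gradient descent; the moving mean and variance
# are accumulated from batch statistics at run time, outside backpropagation.
print(len(bn.trainable_variables))      # → 2  (gamma, beta)
print(len(bn.non_trainable_variables))  # → 2  (moving_mean, moving_variance)
```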

4. Training the Network

4.1. Small datasets make networks hard to train

  • Training from scratch: train_scratch.py
import os
import tensorflow as tf
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping

tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

# import the concrete utilities written earlier
from pokemon import load_pokemon, normalize, denormalize
from resnet import ResNet                   # import the model

# the preprocessing function, copied over
def preprocess(x, y):
    # x: path of the image; y: integer label of the image
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3) # force 3 channels (some images have an alpha channel)
    x = tf.image.resize(x, [244, 244])

    x = tf.image.random_flip_left_right(x)
    # x = tf.image.random_flip_up_down(x)
    x = tf.image.random_crop(x, [224, 224, 3])

    # x: [0,255] => 0~1, then ImageNet-normalized
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)

    return x, y

batchsz = 8

# create train db; only the training split needs shuffling, the others do not.
images, labels, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman',mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))  # turn into a Dataset object
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz) # map turns image paths into image contents
# create validation db
images2, labels2, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman',mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# create test db
images3, labels3, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman',mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)


# The training set is tiny and ResNet's capacity is very high; later we swap in the small 4-layer network below.
# resnet = keras.Sequential([
#     layers.Conv2D(16,5,3),
#     layers.MaxPool2D(3,3),
#     layers.ReLU(),
#     layers.Conv2D(64,5,3),
#     layers.MaxPool2D(2,2),
#     layers.ReLU(),
#     layers.Flatten(),
#     layers.Dense(64),
#     layers.ReLU(),
#     layers.Dense(5)
# ])

# 首先創建Resnet18
resnet = ResNet(5)
resnet.build(input_shape=(batchsz, 224, 224, 3))
resnet.summary()

# EarlyStopping monitor: fires once val_accuracy has failed to improve by
# min_delta for `patience` (here 20) consecutive epochs.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=20
)

# compile the network
resnet.compile(optimizer=optimizers.Adam(learning_rate=1e-4),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])

# The standard train/val/test logic: model selection must go through db_val,
# which is what the early-stopping technique provides.
resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=1000,
           callbacks=[early_stopping])   # validate every epoch; once the monitor fires, training stops early
resnet.evaluate(db_test)
  • Training results
ssh://zhangkf@192.168.136.64:22/home/zhangkf/anaconda3/envs/tf2b/bin/python -u /home/zhangkf/johnCodes/TF2/TF2_8_data/train_scratch.py
WARNING: Logging before flag parsing goes to stderr.
W0828 10:51:31.044445 139752945207040 deprecation.py:323] From /home/zhangkf/anaconda3/envs/tf2b/lib/python3.7/site-packages/tensorflow/python/data/util/random_seed.py:58: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "res_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              multiple                  448       
_________________________________________________________________
sequential (Sequential)      multiple                  2797280   
_________________________________________________________________
batch_normalization_20 (Batc multiple                  1024      
_________________________________________________________________
global_max_pooling2d (Global multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  1285      
=================================================================
Total params: 2,800,037
Trainable params: 2,794,725
Non-trainable params: 5,312
_________________________________________________________________
Epoch 1/1000
88/88 [==============================] - 37s 425ms/step - loss: 2.3255 - accuracy: 0.2472 - val_loss: 1.6504 - val_accuracy: 0.1845
Epoch 2/1000
88/88 [==============================] - 24s 277ms/step - loss: 0.3819 - accuracy: 0.8250 - val_loss: 1.7054 - val_accuracy: 0.1760
Epoch 3/1000
88/88 [==============================] - 24s 273ms/step - loss: 0.0643 - accuracy: 0.9928 - val_loss: 1.7032 - val_accuracy: 0.1931
Epoch 4/1000
88/88 [==============================] - 24s 271ms/step - loss: 0.0198 - accuracy: 0.9949 - val_loss: 1.7639 - val_accuracy: 0.2318
Epoch 5/1000
88/88 [==============================] - 21s 242ms/step - loss: 0.0069 - accuracy: 1.0000 - val_loss: 1.8334 - val_accuracy: 0.2103
Epoch 6/1000
88/88 [==============================] - 23s 266ms/step - loss: 0.0046 - accuracy: 1.0000 - val_loss: 1.8315 - val_accuracy: 0.3004
Epoch 7/1000
88/88 [==============================] - 24s 271ms/step - loss: 0.0036 - accuracy: 1.0000 - val_loss: 1.9369 - val_accuracy: 0.3090
Epoch 8/1000
88/88 [==============================] - 23s 264ms/step - loss: 0.0029 - accuracy: 1.0000 - val_loss: 2.0374 - val_accuracy: 0.2918
Epoch 9/1000
88/88 [==============================] - 24s 268ms/step - loss: 0.0024 - accuracy: 1.0000 - val_loss: 2.0581 - val_accuracy: 0.3219
Epoch 10/1000
88/88 [==============================] - 24s 269ms/step - loss: 0.0021 - accuracy: 1.0000 - val_loss: 2.0795 - val_accuracy: 0.3219
Epoch 11/1000
88/88 [==============================] - 24s 268ms/step - loss: 0.0018 - accuracy: 1.0000 - val_loss: 2.0932 - val_accuracy: 0.3047
Epoch 12/1000
88/88 [==============================] - 24s 272ms/step - loss: 0.0016 - accuracy: 1.0000 - val_loss: 2.0997 - val_accuracy: 0.3047
Epoch 13/1000
88/88 [==============================] - 21s 243ms/step - loss: 0.0014 - accuracy: 1.0000 - val_loss: 2.1042 - val_accuracy: 0.3090
Epoch 14/1000
88/88 [==============================] - 23s 265ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 2.1075 - val_accuracy: 0.3090
Epoch 15/1000
88/88 [==============================] - 24s 270ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 2.1106 - val_accuracy: 0.3090
Epoch 16/1000
88/88 [==============================] - 23s 260ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 2.1136 - val_accuracy: 0.3090
Epoch 17/1000
88/88 [==============================] - 23s 267ms/step - loss: 9.1804e-04 - accuracy: 1.0000 - val_loss: 2.1162 - val_accuracy: 0.3090
Epoch 18/1000
88/88 [==============================] - 24s 268ms/step - loss: 8.3833e-04 - accuracy: 1.0000 - val_loss: 2.1183 - val_accuracy: 0.3090
Epoch 19/1000
88/88 [==============================] - 24s 272ms/step - loss: 7.6828e-04 - accuracy: 1.0000 - val_loss: 2.1211 - val_accuracy: 0.3047
Epoch 20/1000
88/88 [==============================] - 24s 269ms/step - loss: 7.0653e-04 - accuracy: 1.0000 - val_loss: 2.1231 - val_accuracy: 0.3047
Epoch 21/1000
88/88 [==============================] - 24s 268ms/step - loss: 6.5178e-04 - accuracy: 1.0000 - val_loss: 2.1257 - val_accuracy: 0.3047
Epoch 22/1000
88/88 [==============================] - 23s 262ms/step - loss: 6.0310e-04 - accuracy: 1.0000 - val_loss: 2.1278 - val_accuracy: 0.3047
Epoch 23/1000
88/88 [==============================] - 24s 268ms/step - loss: 5.5923e-04 - accuracy: 1.0000 - val_loss: 2.1298 - val_accuracy: 0.3047
Epoch 24/1000
88/88 [==============================] - 23s 267ms/step - loss: 5.1979e-04 - accuracy: 1.0000 - val_loss: 2.1319 - val_accuracy: 0.3047
Epoch 25/1000
88/88 [==============================] - 23s 259ms/step - loss: 4.8411e-04 - accuracy: 1.0000 - val_loss: 2.1338 - val_accuracy: 0.3047
Epoch 26/1000

4.2. How to deal with hard-to-train small datasets

  • The problem: small datasets are hard to train on. Above we observe that training accuracy easily reaches about 100%, while validation accuracy stays around 20%, which means the network has not really been trained at all. Why is that? Once validation accuracy fails to improve by min_delta for the configured number of epochs, EarlyStopping fires and the program stops.

  • The validation accuracy of ResNet18 is so low that the network is essentially not working. The main reason is that ResNet is too big for this task: the network is deep and has a large number of parameters. What can be done about it? The first and most fundamental fix is to enlarge the dataset: ImageNet has several million images, while this Pokémon dataset has only 200-odd images per class.

  • The second option is to constrain the network: reduce the number of layers and parameters, or add regularization and other anti-overfitting measures. All of these are worth trying.
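One way to realize this second option, sketched on the same kind of small 4-layer CNN used later in this section (an illustration, not the author's code): L2 weight decay on the conv kernels plus a Dropout layer before the dense head.

```python
import tensorflow as tf
from tensorflow.keras import layers, regularizers

# Small CNN with two anti-overfitting measures added:
# L2 weight decay on the conv kernels and Dropout before the dense head.
model = tf.keras.Sequential([
    layers.Conv2D(16, 5, 3, kernel_regularizer=regularizers.l2(1e-4)),
    layers.MaxPool2D(3, 3),
    layers.ReLU(),
    layers.Conv2D(64, 5, 3, kernel_regularizer=regularizers.l2(1e-4)),
    layers.MaxPool2D(2, 2),
    layers.ReLU(),
    layers.Flatten(),
    layers.Dropout(0.5),          # randomly zeroes activations during training only
    layers.Dense(64),
    layers.ReLU(),
    layers.Dense(5),
])
model.build(input_shape=(None, 224, 224, 3))
```

The L2 terms are added to the loss automatically at compile time, so the rest of the training script stays unchanged.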

  • But our images here are simply too few, so we take another route and switch directly to a fairly small network. When the data is insufficient, a small network often performs surprisingly well, and a small network is easy to train. Let's test it.

  • After switching networks, the results are as follows: test accuracy reaches roughly 87%.

ssh://zhangkf@192.168.136.64:22/home/zhangkf/anaconda3/envs/tf2b/bin/python -u /home/zhangkf/johnCodes/TF2/TF2_8_data/train_scratch.py
WARNING: Logging before flag parsing goes to stderr.
W0828 10:33:44.677829 139788463359744 deprecation.py:323] From /home/zhangkf/anaconda3/envs/tf2b/lib/python3.7/site-packages/tensorflow/python/data/util/random_seed.py:58: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              multiple                  1216      
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
re_lu (ReLU)                 multiple                  0         
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  25664     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple                  0         
_________________________________________________________________
re_lu_1 (ReLU)               multiple                  0         
_________________________________________________________________
flatten (Flatten)            multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  36928     
_________________________________________________________________
re_lu_2 (ReLU)               multiple                  0         
_________________________________________________________________
dense_1 (Dense)              multiple                  325       
=================================================================
Total params: 64,133
Trainable params: 64,133
Non-trainable params: 0
_________________________________________________________________
Epoch 1/1000
88/88 [==============================] - 23s 266ms/step - loss: 1.4197 - accuracy: 0.3153 - val_loss: 1.1087 - val_accuracy: 0.6223
Epoch 2/1000
88/88 [==============================] - 22s 251ms/step - loss: 0.9204 - accuracy: 0.6749 - val_loss: 0.7547 - val_accuracy: 0.7940
Epoch 3/1000
88/88 [==============================] - 23s 266ms/step - loss: 0.6312 - accuracy: 0.8405 - val_loss: 0.5803 - val_accuracy: 0.8326
Epoch 4/1000
88/88 [==============================] - 19s 221ms/step - loss: 0.4844 - accuracy: 0.8674 - val_loss: 0.4981 - val_accuracy: 0.8369
Epoch 5/1000
88/88 [==============================] - 22s 253ms/step - loss: 0.4089 - accuracy: 0.8874 - val_loss: 0.4607 - val_accuracy: 0.8455
Epoch 6/1000
88/88 [==============================] - 23s 259ms/step - loss: 0.3629 - accuracy: 0.9137 - val_loss: 0.4383 - val_accuracy: 0.8584
Epoch 7/1000
88/88 [==============================] - 21s 243ms/step - loss: 0.3281 - accuracy: 0.9312 - val_loss: 0.4283 - val_accuracy: 0.8498
Epoch 8/1000
88/88 [==============================] - 24s 269ms/step - loss: 0.3019 - accuracy: 0.9329 - val_loss: 0.4155 - val_accuracy: 0.8541
Epoch 9/1000
88/88 [==============================] - 22s 249ms/step - loss: 0.2772 - accuracy: 0.9367 - val_loss: 0.4091 - val_accuracy: 0.8541
Epoch 10/1000
88/88 [==============================] - 23s 265ms/step - loss: 0.2553 - accuracy: 0.9413 - val_loss: 0.4041 - val_accuracy: 0.8541
Epoch 11/1000
88/88 [==============================] - 23s 259ms/step - loss: 0.2344 - accuracy: 0.9513 - val_loss: 0.4001 - val_accuracy: 0.8584
Epoch 12/1000
88/88 [==============================] - 22s 252ms/step - loss: 0.2156 - accuracy: 0.9560 - val_loss: 0.3960 - val_accuracy: 0.8584
Epoch 13/1000
88/88 [==============================] - 23s 257ms/step - loss: 0.1981 - accuracy: 0.9651 - val_loss: 0.3915 - val_accuracy: 0.8627
Epoch 14/1000
88/88 [==============================] - 23s 258ms/step - loss: 0.1804 - accuracy: 0.9686 - val_loss: 0.3909 - val_accuracy: 0.8627
Epoch 15/1000
88/88 [==============================] - 21s 237ms/step - loss: 0.1655 - accuracy: 0.9717 - val_loss: 0.3849 - val_accuracy: 0.8627
Epoch 16/1000
88/88 [==============================] - 24s 269ms/step - loss: 0.1501 - accuracy: 0.9758 - val_loss: 0.3862 - val_accuracy: 0.8627
Epoch 17/1000
88/88 [==============================] - 18s 206ms/step - loss: 0.1362 - accuracy: 0.9784 - val_loss: 0.3825 - val_accuracy: 0.8627
Epoch 18/1000
88/88 [==============================] - 19s 212ms/step - loss: 0.1225 - accuracy: 0.9788 - val_loss: 0.3804 - val_accuracy: 0.8627
Epoch 19/1000
88/88 [==============================] - 22s 254ms/step - loss: 0.1107 - accuracy: 0.9856 - val_loss: 0.3789 - val_accuracy: 0.8712
Epoch 20/1000
88/88 [==============================] - 23s 265ms/step - loss: 0.0996 - accuracy: 0.9883 - val_loss: 0.3785 - val_accuracy: 0.8755
Epoch 21/1000
88/88 [==============================] - 22s 249ms/step - loss: 0.0897 - accuracy: 0.9896 - val_loss: 0.3850 - val_accuracy: 0.8670
Epoch 22/1000
88/88 [==============================] - 22s 251ms/step - loss: 0.0808 - accuracy: 0.9935 - val_loss: 0.3869 - val_accuracy: 0.8670
Epoch 23/1000
88/88 [==============================] - 23s 258ms/step - loss: 0.0726 - accuracy: 0.9937 - val_loss: 0.3926 - val_accuracy: 0.8670
Epoch 24/1000
88/88 [==============================] - 23s 266ms/step - loss: 0.0652 - accuracy: 0.9937 - val_loss: 0.3958 - val_accuracy: 0.8627
Epoch 25/1000
88/88 [==============================] - 23s 260ms/step - loss: 0.0583 - accuracy: 0.9937 - val_loss: 0.3974 - val_accuracy: 0.8670
Epoch 26/1000
88/88 [==============================] - 20s 222ms/step - loss: 0.0521 - accuracy: 0.9958 - val_loss: 0.4033 - val_accuracy: 0.8670
Epoch 27/1000
88/88 [==============================] - 22s 255ms/step - loss: 0.0471 - accuracy: 0.9970 - val_loss: 0.4051 - val_accuracy: 0.8670
Epoch 28/1000
88/88 [==============================] - 23s 260ms/step - loss: 0.0424 - accuracy: 0.9986 - val_loss: 0.4089 - val_accuracy: 0.8627
Epoch 29/1000
88/88 [==============================] - 20s 231ms/step - loss: 0.0380 - accuracy: 0.9994 - val_loss: 0.4088 - val_accuracy: 0.8584
Epoch 30/1000
88/88 [==============================] - 20s 223ms/step - loss: 0.0342 - accuracy: 0.9994 - val_loss: 0.4104 - val_accuracy: 0.8627
Epoch 31/1000
88/88 [==============================] - 23s 264ms/step - loss: 0.0311 - accuracy: 0.9994 - val_loss: 0.4137 - val_accuracy: 0.8627
Epoch 32/1000
88/88 [==============================] - 19s 215ms/step - loss: 0.0278 - accuracy: 0.9994 - val_loss: 0.4183 - val_accuracy: 0.8627
Epoch 33/1000
88/88 [==============================] - 22s 253ms/step - loss: 0.0251 - accuracy: 0.9994 - val_loss: 0.4170 - val_accuracy: 0.8627
Epoch 34/1000
88/88 [==============================] - 22s 253ms/step - loss: 0.0226 - accuracy: 0.9994 - val_loss: 0.4236 - val_accuracy: 0.8584
Epoch 35/1000
88/88 [==============================] - 22s 251ms/step - loss: 0.0205 - accuracy: 0.9994 - val_loss: 0.4269 - val_accuracy: 0.8584
Epoch 36/1000
88/88 [==============================] - 23s 257ms/step - loss: 0.0187 - accuracy: 0.9994 - val_loss: 0.4276 - val_accuracy: 0.8584
Epoch 37/1000
88/88 [==============================] - 23s 264ms/step - loss: 0.0169 - accuracy: 0.9994 - val_loss: 0.4359 - val_accuracy: 0.8584
Epoch 38/1000
88/88 [==============================] - 23s 259ms/step - loss: 0.0153 - accuracy: 0.9994 - val_loss: 0.4345 - val_accuracy: 0.8584
Epoch 39/1000
88/88 [==============================] - 23s 262ms/step - loss: 0.0142 - accuracy: 0.9994 - val_loss: 0.4405 - val_accuracy: 0.8584
Epoch 40/1000
88/88 [==============================] - 21s 244ms/step - loss: 0.0129 - accuracy: 0.9994 - val_loss: 0.4409 - val_accuracy: 0.8627
30/30 [==============================] - 5s 154ms/step - loss: 0.4972 - accuracy: 0.8755

Process finished with exit code 0

5. Deep Transfer Learning

5.1. Transfer learning: introduction + hands-on

  • From the test results above we can see that the small network does much better. This really is because ResNet's representational capacity is too strong and the dataset is too small, so training it directly works poorly. For a small dataset like ours, when we still want good results (and want to be able to use a deep architecture), training from scratch sometimes simply fails to get off the ground; a quick remedy for this is transfer learning.
  • Hands-on code
import os
import tensorflow as tf
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
# from tensorflow.keras.callbacks import EarlyStopping

tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

# import the concrete utilities written earlier
from pokemon import load_pokemon, normalize, denormalize
from resnet import ResNet                               # import the model

# the preprocessing function, copied over
def preprocess(x, y):
    # x: path of the image; y: integer label of the image
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)             # force 3 channels (some images have an alpha channel)
    x = tf.image.resize(x, [256, 256])

    x = tf.image.random_flip_left_right(x)
    x = tf.image.random_flip_up_down(x)
    x = tf.image.random_crop(x, [224, 224, 3])

    # x: [0,255] => 0~1
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)

    return x, y

batchsz = 8

# create train db; only the training split needs shuffling, the others do not.
images, labels, table = load_pokemon('/home/zhangkf/tf/TF2/TF2_8_data/pokeman',mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))     # turn into a Dataset object
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)    # map turns image paths into image contents
# create validation db
images2, labels2, table = load_pokemon('/home/zhangkf/tf/TF2/TF2_8_data/pokeman',mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# create test db
images3, labels3, table = load_pokemon('/home/zhangkf/tf/TF2/TF2_8_data/pokeman',mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)

# Import an already-trained network and its parameters: keras.applications provides
# several classic architectures together with weights pre-trained on ImageNet.
# Here we use VGG19 and load its weights; ImageNet has 1000 classes, so we drop the output layers.
net = keras.applications.VGG19(weights='imagenet', include_top=False,
                               pooling='max')

net.trainable = False                                   # freeze the old part of the network: it takes no part in backprop updates

# stack a new 5-way classification layer on top of the frozen base
newnet = keras.Sequential([net, layers.Dense(5)])

newnet.build(input_shape=(batchsz, 224, 224, 3))
newnet.summary()

# early_stopping monitor: fires once val_accuracy has failed to improve by
# min_delta for 10 consecutive epochs, terminating training.
early_stopping = keras.callbacks.EarlyStopping(monitor='val_accuracy',
                                               min_delta=0.00001,
                                               patience=10, verbose=1)

# reduce_lr: when val_accuracy stops improving for 10 epochs, multiply the
# learning rate by 0.1, never going below min_lr.
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.1,
                                              patience=10, min_lr=0.000001, verbose=1)

# compile the network
newnet.compile(optimizer=optimizers.Adam(learning_rate=1e-4), loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])

# The standard train/val/test logic: model selection must go through db_val,
# which is what the early-stopping technique provides.
newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=500,
           callbacks=[early_stopping, reduce_lr])   # validate every epoch; once the monitor fires, training stops early
newnet.evaluate(db_test)

  • Run results
ssh://zhangkf@192.168.136.55:22/home/zhangkf/anaconda3/envs/tf2c/bin/python -u /home/zhangkf/tf/TF2/TF2_8_data/train_transfer.py
WARNING:tensorflow:From /home/zhangkf/anaconda3/envs/tf2c/lib/python3.7/site-packages/tensorflow_core/python/data/util/random_seed.py:58: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg19 (Model)                (None, 512)               20024384  
_________________________________________________________________
dense (Dense)                (None, 5)                 2565      
=================================================================
Total params: 20,026,949
Trainable params: 2,565
Non-trainable params: 20,024,384
_________________________________________________________________
Epoch 1/500
102/102 [==============================] - 23s 229ms/step - loss: 1.9337 - accuracy: 0.2712 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/500
102/102 [==============================] - 20s 200ms/step - loss: 1.5340 - accuracy: 0.3313 - val_loss: 1.5425 - val_accuracy: 0.3429
Epoch 3/500
102/102 [==============================] - 19s 190ms/step - loss: 1.4445 - accuracy: 0.4000 - val_loss: 1.4317 - val_accuracy: 0.3886
Epoch 4/500
102/102 [==============================] - 21s 206ms/step - loss: 1.3586 - accuracy: 0.4356 - val_loss: 1.3326 - val_accuracy: 0.4686
Epoch 5/500
102/102 [==============================] - 21s 207ms/step - loss: 1.2486 - accuracy: 0.5387 - val_loss: 1.2367 - val_accuracy: 0.5314
Epoch 6/500
102/102 [==============================] - 21s 207ms/step - loss: 1.1888 - accuracy: 0.5472 - val_loss: 1.1698 - val_accuracy: 0.5657
Epoch 7/500
102/102 [==============================] - 19s 190ms/step - loss: 1.1045 - accuracy: 0.6160 - val_loss: 1.0988 - val_accuracy: 0.5771
Epoch 8/500
102/102 [==============================] - 20s 195ms/step - loss: 1.0337 - accuracy: 0.6528 - val_loss: 1.0467 - val_accuracy: 0.6343
Epoch 9/500
102/102 [==============================] - 21s 208ms/step - loss: 0.9766 - accuracy: 0.6675 - val_loss: 0.9844 - val_accuracy: 0.6571
Epoch 10/500
102/102 [==============================] - 21s 207ms/step - loss: 0.9453 - accuracy: 0.6957 - val_loss: 0.9329 - val_accuracy: 0.6686
Epoch 11/500
102/102 [==============================] - 22s 211ms/step - loss: 0.8816 - accuracy: 0.7387 - val_loss: 0.8985 - val_accuracy: 0.6629
Epoch 12/500
102/102 [==============================] - 21s 203ms/step - loss: 0.8345 - accuracy: 0.7583 - val_loss: 0.8516 - val_accuracy: 0.7314
Epoch 13/500
102/102 [==============================] - 20s 198ms/step - loss: 0.7950 - accuracy: 0.7779 - val_loss: 0.8211 - val_accuracy: 0.7714
Epoch 14/500
102/102 [==============================] - 19s 187ms/step - loss: 0.7744 - accuracy: 0.7840 - val_loss: 0.7846 - val_accuracy: 0.7771
Epoch 15/500
102/102 [==============================] - 19s 188ms/step - loss: 0.7308 - accuracy: 0.7877 - val_loss: 0.7560 - val_accuracy: 0.7886
Epoch 16/500
102/102 [==============================] - 21s 210ms/step - loss: 0.7201 - accuracy: 0.8049 - val_loss: 0.7321 - val_accuracy: 0.7943
Epoch 17/500
102/102 [==============================] - 19s 189ms/step - loss: 0.6888 - accuracy: 0.8184 - val_loss: 0.7135 - val_accuracy: 0.8000
Epoch 18/500
102/102 [==============================] - 21s 204ms/step - loss: 0.6566 - accuracy: 0.8233 - val_loss: 0.6860 - val_accuracy: 0.8286
Epoch 19/500
102/102 [==============================] - 21s 208ms/step - loss: 0.6666 - accuracy: 0.8172 - val_loss: 0.6698 - val_accuracy: 0.8286
Epoch 20/500
102/102 [==============================] - 21s 203ms/step - loss: 0.6201 - accuracy: 0.8356 - val_loss: 0.6488 - val_accuracy: 0.8343
Epoch 21/500
102/102 [==============================] - 21s 206ms/step - loss: 0.5972 - accuracy: 0.8564 - val_loss: 0.6338 - val_accuracy: 0.8229
Epoch 22/500
102/102 [==============================] - 21s 210ms/step - loss: 0.5730 - accuracy: 0.8589 - val_loss: 0.6138 - val_accuracy: 0.8343
Epoch 23/500
102/102 [==============================] - 20s 192ms/step - loss: 0.5612 - accuracy: 0.8589 - val_loss: 0.6027 - val_accuracy: 0.8571
Epoch 24/500
102/102 [==============================] - 19s 186ms/step - loss: 0.5578 - accuracy: 0.8528 - val_loss: 0.5859 - val_accuracy: 0.8514
Epoch 25/500
102/102 [==============================] - 21s 203ms/step - loss: 0.5541 - accuracy: 0.8724 - val_loss: 0.5740 - val_accuracy: 0.8457
Epoch 26/500
102/102 [==============================] - 18s 175ms/step - loss: 0.5192 - accuracy: 0.8712 - val_loss: 0.5615 - val_accuracy: 0.8400
Epoch 27/500
102/102 [==============================] - 21s 209ms/step - loss: 0.5069 - accuracy: 0.8748 - val_loss: 0.5524 - val_accuracy: 0.8514
Epoch 28/500
102/102 [==============================] - 19s 186ms/step - loss: 0.4829 - accuracy: 0.8834 - val_loss: 0.5423 - val_accuracy: 0.8571
Epoch 29/500
102/102 [==============================] - 20s 194ms/step - loss: 0.4975 - accuracy: 0.8773 - val_loss: 0.5335 - val_accuracy: 0.8571
Epoch 30/500
102/102 [==============================] - 19s 188ms/step - loss: 0.4687 - accuracy: 0.8847 - val_loss: 0.5202 - val_accuracy: 0.8514
Epoch 31/500
102/102 [==============================] - 21s 205ms/step - loss: 0.4637 - accuracy: 0.8834 - val_loss: 0.5124 - val_accuracy: 0.8571
Epoch 32/500
102/102 [==============================] - 21s 203ms/step - loss: 0.4791 - accuracy: 0.8687 - val_loss: 0.5027 - val_accuracy: 0.8571
Epoch 33/500
102/102 [==============================] - 21s 208ms/step - loss: 0.4606 - accuracy: 0.8724 - val_loss: 0.4952 - val_accuracy: 0.8629
Epoch 34/500
102/102 [==============================] - 21s 207ms/step - loss: 0.4491 - accuracy: 0.8798 - val_loss: 0.4883 - val_accuracy: 0.8514
Epoch 35/500
102/102 [==============================] - 19s 187ms/step - loss: 0.4408 - accuracy: 0.8871 - val_loss: 0.4812 - val_accuracy: 0.8629
Epoch 36/500
102/102 [==============================] - 21s 209ms/step - loss: 0.4296 - accuracy: 0.8982 - val_loss: 0.4754 - val_accuracy: 0.8571
Epoch 37/500
102/102 [==============================] - 22s 214ms/step - loss: 0.4021 - accuracy: 0.9117 - val_loss: 0.4693 - val_accuracy: 0.8629
Epoch 38/500
102/102 [==============================] - 21s 203ms/step - loss: 0.4055 - accuracy: 0.9080 - val_loss: 0.4641 - val_accuracy: 0.8571
Epoch 39/500
102/102 [==============================] - 21s 205ms/step - loss: 0.3998 - accuracy: 0.9117 - val_loss: 0.4572 - val_accuracy: 0.8686
Epoch 40/500
102/102 [==============================] - 21s 210ms/step - loss: 0.4020 - accuracy: 0.8982 - val_loss: 0.4535 - val_accuracy: 0.8686
Epoch 41/500
102/102 [==============================] - 20s 199ms/step - loss: 0.3919 - accuracy: 0.9166 - val_loss: 0.4447 - val_accuracy: 0.8743
Epoch 42/500
102/102 [==============================] - 22s 213ms/step - loss: 0.3676 - accuracy: 0.9141 - val_loss: 0.4423 - val_accuracy: 0.8800
Epoch 43/500
102/102 [==============================] - 20s 201ms/step - loss: 0.3720 - accuracy: 0.9092 - val_loss: 0.4341 - val_accuracy: 0.8743
Epoch 44/500
102/102 [==============================] - 21s 204ms/step - loss: 0.3682 - accuracy: 0.9104 - val_loss: 0.4324 - val_accuracy: 0.8857
Epoch 45/500
102/102 [==============================] - 21s 206ms/step - loss: 0.3680 - accuracy: 0.9166 - val_loss: 0.4234 - val_accuracy: 0.8857
Epoch 46/500
102/102 [==============================] - 19s 191ms/step - loss: 0.3553 - accuracy: 0.9141 - val_loss: 0.4211 - val_accuracy: 0.8914
Epoch 47/500
102/102 [==============================] - 18s 172ms/step - loss: 0.3507 - accuracy: 0.9190 - val_loss: 0.4184 - val_accuracy: 0.8914
Epoch 48/500
102/102 [==============================] - 21s 210ms/step - loss: 0.3640 - accuracy: 0.9141 - val_loss: 0.4158 - val_accuracy: 0.8971
Epoch 49/500
102/102 [==============================] - 21s 210ms/step - loss: 0.3378 - accuracy: 0.9239 - val_loss: 0.4075 - val_accuracy: 0.8971
Epoch 50/500
102/102 [==============================] - 21s 209ms/step - loss: 0.3480 - accuracy: 0.9129 - val_loss: 0.4031 - val_accuracy: 0.8914
Epoch 51/500
102/102 [==============================] - 21s 209ms/step - loss: 0.3298 - accuracy: 0.9325 - val_loss: 0.3978 - val_accuracy: 0.8971
Epoch 52/500
102/102 [==============================] - 18s 175ms/step - loss: 0.3354 - accuracy: 0.9227 - val_loss: 0.3940 - val_accuracy: 0.9029
Epoch 53/500
102/102 [==============================] - 21s 210ms/step - loss: 0.3168 - accuracy: 0.9239 - val_loss: 0.3900 - val_accuracy: 0.8971
Epoch 54/500
102/102 [==============================] - 21s 202ms/step - loss: 0.3190 - accuracy: 0.9264 - val_loss: 0.3909 - val_accuracy: 0.9086
Epoch 55/500
102/102 [==============================] - 21s 206ms/step - loss: 0.3206 - accuracy: 0.9264 - val_loss: 0.3866 - val_accuracy: 0.9086
Epoch 56/500
102/102 [==============================] - 19s 184ms/step - loss: 0.3071 - accuracy: 0.9227 - val_loss: 0.3831 - val_accuracy: 0.9029
Epoch 57/500
102/102 [==============================] - 22s 211ms/step - loss: 0.2999 - accuracy: 0.9362 - val_loss: 0.3784 - val_accuracy: 0.9029
Epoch 58/500
102/102 [==============================] - 22s 212ms/step - loss: 0.2993 - accuracy: 0.9276 - val_loss: 0.3777 - val_accuracy: 0.9029
Epoch 59/500
102/102 [==============================] - 21s 208ms/step - loss: 0.3060 - accuracy: 0.9239 - val_loss: 0.3744 - val_accuracy: 0.9143
Epoch 60/500
102/102 [==============================] - 21s 206ms/step - loss: 0.2913 - accuracy: 0.9362 - val_loss: 0.3762 - val_accuracy: 0.9086
Epoch 61/500
102/102 [==============================] - 21s 208ms/step - loss: 0.2801 - accuracy: 0.9325 - val_loss: 0.3692 - val_accuracy: 0.9143
Epoch 62/500
102/102 [==============================] - 21s 204ms/step - loss: 0.3024 - accuracy: 0.9288 - val_loss: 0.3635 - val_accuracy: 0.9029
Epoch 63/500
102/102 [==============================] - 21s 208ms/step - loss: 0.2828 - accuracy: 0.9350 - val_loss: 0.3649 - val_accuracy: 0.9143
Epoch 64/500
102/102 [==============================] - 21s 203ms/step - loss: 0.2768 - accuracy: 0.9448 - val_loss: 0.3578 - val_accuracy: 0.9086
Epoch 65/500
102/102 [==============================] - 19s 187ms/step - loss: 0.2821 - accuracy: 0.9362 - val_loss: 0.3578 - val_accuracy: 0.9143
Epoch 66/500
102/102 [==============================] - 20s 192ms/step - loss: 0.2714 - accuracy: 0.9387 - val_loss: 0.3557 - val_accuracy: 0.9143
Epoch 67/500
102/102 [==============================] - 21s 208ms/step - loss: 0.2644 - accuracy: 0.9448 - val_loss: 0.3569 - val_accuracy: 0.9257
Epoch 68/500
102/102 [==============================] - 22s 211ms/step - loss: 0.2700 - accuracy: 0.9350 - val_loss: 0.3539 - val_accuracy: 0.9143
Epoch 69/500
102/102 [==============================] - 21s 202ms/step - loss: 0.2668 - accuracy: 0.9448 - val_loss: 0.3459 - val_accuracy: 0.9143
Epoch 70/500
102/102 [==============================] - 21s 202ms/step - loss: 0.2727 - accuracy: 0.9288 - val_loss: 0.3489 - val_accuracy: 0.9143
Epoch 71/500
102/102 [==============================] - 20s 194ms/step - loss: 0.2658 - accuracy: 0.9227 - val_loss: 0.3445 - val_accuracy: 0.9029
Epoch 72/500
102/102 [==============================] - 21s 201ms/step - loss: 0.2586 - accuracy: 0.9399 - val_loss: 0.3421 - val_accuracy: 0.9143
Epoch 73/500
102/102 [==============================] - 21s 207ms/step - loss: 0.2546 - accuracy: 0.9399 - val_loss: 0.3439 - val_accuracy: 0.9086
Epoch 74/500
102/102 [==============================] - 21s 208ms/step - loss: 0.2602 - accuracy: 0.9399 - val_loss: 0.3392 - val_accuracy: 0.9143
Epoch 75/500
102/102 [==============================] - 17s 171ms/step - loss: 0.2507 - accuracy: 0.9423 - val_loss: 0.3401 - val_accuracy: 0.9143
Epoch 76/500
102/102 [==============================] - 18s 177ms/step - loss: 0.2480 - accuracy: 0.9411 - val_loss: 0.3362 - val_accuracy: 0.9257
Epoch 77/500
102/102 [==============================] - 21s 208ms/step - loss: 0.2381 - accuracy: 0.9436 - val_loss: 0.3354 - val_accuracy: 0.9143
Epoch 78/500
102/102 [==============================] - 21s 209ms/step - loss: 0.2550 - accuracy: 0.9362 - val_loss: 0.3333 - val_accuracy: 0.9143
Epoch 79/500
102/102 [==============================] - 21s 208ms/step - loss: 0.2428 - accuracy: 0.9423 - val_loss: 0.3319 - val_accuracy: 0.9143
Epoch 80/500
102/102 [==============================] - 20s 200ms/step - loss: 0.2451 - accuracy: 0.9374 - val_loss: 0.3309 - val_accuracy: 0.9143
Epoch 81/500
102/102 [==============================] - 21s 209ms/step - loss: 0.2368 - accuracy: 0.9534 - val_loss: 0.3325 - val_accuracy: 0.9086
Epoch 82/500
102/102 [==============================] - 20s 193ms/step - loss: 0.2211 - accuracy: 0.9497 - val_loss: 0.3303 - val_accuracy: 0.9143
Epoch 83/500
102/102 [==============================] - 19s 182ms/step - loss: 0.2301 - accuracy: 0.9436 - val_loss: 0.3286 - val_accuracy: 0.9143
Epoch 84/500
102/102 [==============================] - 21s 208ms/step - loss: 0.2339 - accuracy: 0.9534 - val_loss: 0.3254 - val_accuracy: 0.9086
Epoch 85/500
102/102 [==============================] - 20s 197ms/step - loss: 0.2253 - accuracy: 0.9436 - val_loss: 0.3277 - val_accuracy: 0.9143
Epoch 86/500
102/102 [==============================] - 21s 203ms/step - loss: 0.2361 - accuracy: 0.9411 - val_loss: 0.3255 - val_accuracy: 0.9086
Epoch 87/500
102/102 [==============================] - 21s 210ms/step - loss: 0.2299 - accuracy: 0.9399 - val_loss: 0.3215 - val_accuracy: 0.9086
Epoch 88/500
102/102 [==============================] - 21s 206ms/step - loss: 0.2198 - accuracy: 0.9607 - val_loss: 0.3256 - val_accuracy: 0.9143
Epoch 89/500
102/102 [==============================] - 21s 203ms/step - loss: 0.2258 - accuracy: 0.9509 - val_loss: 0.3195 - val_accuracy: 0.9086
Epoch 90/500
102/102 [==============================] - 17s 167ms/step - loss: 0.2184 - accuracy: 0.9521 - val_loss: 0.3158 - val_accuracy: 0.9143
Epoch 91/500
102/102 [==============================] - 21s 207ms/step - loss: 0.2180 - accuracy: 0.9485 - val_loss: 0.3214 - val_accuracy: 0.9143
Epoch 92/500
102/102 [==============================] - 21s 206ms/step - loss: 0.2150 - accuracy: 0.9485 - val_loss: 0.3144 - val_accuracy: 0.9143
Epoch 93/500
102/102 [==============================] - 19s 182ms/step - loss: 0.2085 - accuracy: 0.9534 - val_loss: 0.3179 - val_accuracy: 0.9143
Epoch 94/500
102/102 [==============================] - 19s 183ms/step - loss: 0.2297 - accuracy: 0.9497 - val_loss: 0.3131 - val_accuracy: 0.9143
Epoch 95/500
102/102 [==============================] - 19s 184ms/step - loss: 0.2074 - accuracy: 0.9583 - val_loss: 0.3131 - val_accuracy: 0.9143
Epoch 96/500
102/102 [==============================] - 17s 169ms/step - loss: 0.2145 - accuracy: 0.9411 - val_loss: 0.3157 - val_accuracy: 0.9086
Epoch 97/500
101/102 [============================>.] - ETA: 0s - loss: 0.2023 - accuracy: 0.9604
Epoch 00097: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-06.
102/102 [==============================] - 21s 209ms/step - loss: 0.2023 - accuracy: 0.9607 - val_loss: 0.3101 - val_accuracy: 0.9143
Epoch 00097: early stopping
22/22 [==============================] - 3s 158ms/step - loss: 0.3629 - accuracy: 0.8971

Process finished with exit code 0

六. Further improvements of my own

6.1. Techniques for further improvement

import os
import tensorflow as tf
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')      # make sure the TensorFlow version starts with '2.'

# import project-specific helpers
from pokemon import load_pokemon, normalize, denormalize
from resnet import ResNet                               # import the ResNet model (VGG19 is actually used below)

# preprocessing function, copied from the earlier section.
def preprocess(x,y):
    # x: path to the image file, y: integer label of the image
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)             # decode as 3-channel RGB
    x = tf.image.resize(x, [256, 256])

    x = tf.image.random_flip_left_right(x)
    # x = tf.image.random_flip_up_down(x)
    x = tf.image.random_brightness(x, max_delta=0.5)    # randomly adjust brightness within a range
    x = tf.image.random_contrast(x, 0.1, 0.6)           # randomly adjust contrast within a range
    x = tf.image.random_crop(x, [224,224,3])

    # x: [0, 255] => [0, 1]
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)

    return x, y

###########################################################################################################
batchsz = 16

# create train db; only the training set needs shuffling.
images, labels, table = load_pokemon('/home/zhangkf/tf/TF2/TF2_8_data/pokeman',mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))     # wrap as a Dataset object
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)    # map turns image paths into image tensors
# create validation db
images2, labels2, table = load_pokemon('/home/zhangkf/tf/TF2/TF2_8_data/pokeman',mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# create test db
images3, labels3, table = load_pokemon('/home/zhangkf/tf/TF2/TF2_8_data/pokeman',mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)

###########################################################################################################
# Load a pretrained network from keras.applications, which ships classic architectures with pretrained weights.
# Here we use VGG19 with its ImageNet weights; include_top=False drops the 1000-class ImageNet output layer.
net = keras.applications.VGG19(weights='imagenet',
                               include_top=False,
                               pooling='max')

# net.trainable = False would freeze the whole backbone. For better adaptation, leave the last 4 layers trainable:
for i in range(len(net.layers)-4):                  # VGG19 here has 23 layers; freeze all but the last 4
    net.layers[i].trainable = False

model = keras.Sequential([net, layers.Dense(5)])

model.build(input_shape=(None, 224, 224, 3))
model.summary()

# early_stopping: stop training when val_accuracy has not improved by at least min_delta for 30 consecutive epochs.
early_stopping = EarlyStopping(monitor='val_accuracy',
                               min_delta=0.00001,
                               patience=30, verbose=1)

# reduce_lr: when val_accuracy plateaus for 30 epochs, multiply the learning rate by factor=0.02 (never below min_lr).
reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', factor=0.02,
                              patience=30, min_lr=0.0000001, verbose=1)

###########################################################################################################
model.compile(optimizer=optimizers.Adam(lr=1e-4),
              loss=losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])  # loss computed on raw logits

model.fit(db_train, validation_data=db_val, validation_freq=1, epochs=1000,
          initial_epoch=0, callbacks=[early_stopping, reduce_lr])                           # validate once per epoch

model.evaluate(db_test)

  • Run result:
ssh://zhangkf@192.168.136.55:22/home/zhangkf/anaconda3/envs/tf2c/bin/python -u /home/zhangkf/tf/TF2/TF2_8_data/train_transfer.py
WARNING:tensorflow:From /home/zhangkf/anaconda3/envs/tf2c/lib/python3.7/site-packages/tensorflow_core/python/data/util/random_seed.py:58: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg19 (Model)                (None, 512)               20024384  
_________________________________________________________________
dense (Dense)                (None, 5)                 2565      
=================================================================
Total params: 20,026,949
Trainable params: 4,722,181
Non-trainable params: 15,304,768
_________________________________________________________________
Epoch 1/1000
44/44 [==============================] - 23s 515ms/step - loss: 0.8220 - accuracy: 0.7182 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/1000
44/44 [==============================] - 19s 424ms/step - loss: 0.2022 - accuracy: 0.9385 - val_loss: 0.3446 - val_accuracy: 0.8798
Epoch 3/1000
44/44 [==============================] - 18s 401ms/step - loss: 0.1196 - accuracy: 0.9642 - val_loss: 0.2623 - val_accuracy: 0.9313
Epoch 4/1000
44/44 [==============================] - 19s 436ms/step - loss: 0.0686 - accuracy: 0.9871 - val_loss: 0.2385 - val_accuracy: 0.9185
Epoch 5/1000
44/44 [==============================] - 20s 443ms/step - loss: 0.0537 - accuracy: 0.9857 - val_loss: 0.2634 - val_accuracy: 0.9227
Epoch 6/1000
44/44 [==============================] - 19s 427ms/step - loss: 0.0738 - accuracy: 0.9886 - val_loss: 0.2965 - val_accuracy: 0.9185
Epoch 7/1000
44/44 [==============================] - 20s 461ms/step - loss: 0.0663 - accuracy: 0.9900 - val_loss: 0.3312 - val_accuracy: 0.9099
Epoch 8/1000
44/44 [==============================] - 19s 428ms/step - loss: 0.0757 - accuracy: 0.9871 - val_loss: 0.2587 - val_accuracy: 0.9270
Epoch 9/1000
44/44 [==============================] - 21s 467ms/step - loss: 0.0618 - accuracy: 0.9886 - val_loss: 0.2082 - val_accuracy: 0.9442
Epoch 10/1000
44/44 [==============================] - 20s 456ms/step - loss: 0.0597 - accuracy: 0.9843 - val_loss: 0.3669 - val_accuracy: 0.9185
Epoch 11/1000
44/44 [==============================] - 20s 444ms/step - loss: 0.0968 - accuracy: 0.9800 - val_loss: 0.2210 - val_accuracy: 0.9399
Epoch 12/1000
44/44 [==============================] - 19s 440ms/step - loss: 0.0702 - accuracy: 0.9871 - val_loss: 0.2665 - val_accuracy: 0.9356
Epoch 13/1000
44/44 [==============================] - 18s 420ms/step - loss: 0.0507 - accuracy: 0.9886 - val_loss: 0.2004 - val_accuracy: 0.9399
Epoch 14/1000
44/44 [==============================] - 19s 421ms/step - loss: 0.0483 - accuracy: 0.9900 - val_loss: 0.2526 - val_accuracy: 0.9270
Epoch 15/1000
44/44 [==============================] - 19s 442ms/step - loss: 0.0334 - accuracy: 0.9886 - val_loss: 0.2460 - val_accuracy: 0.9227
Epoch 16/1000
44/44 [==============================] - 20s 457ms/step - loss: 0.0555 - accuracy: 0.9900 - val_loss: 0.6035 - val_accuracy: 0.8670
Epoch 17/1000
44/44 [==============================] - 19s 438ms/step - loss: 0.0333 - accuracy: 0.9886 - val_loss: 0.2176 - val_accuracy: 0.9442
Epoch 18/1000
44/44 [==============================] - 18s 416ms/step - loss: 0.0679 - accuracy: 0.9914 - val_loss: 0.2387 - val_accuracy: 0.9399
Epoch 19/1000
44/44 [==============================] - 20s 454ms/step - loss: 0.0813 - accuracy: 0.9857 - val_loss: 0.3338 - val_accuracy: 0.9227
Epoch 20/1000
44/44 [==============================] - 18s 402ms/step - loss: 0.0326 - accuracy: 0.9914 - val_loss: 0.3881 - val_accuracy: 0.8970
Epoch 21/1000
44/44 [==============================] - 19s 437ms/step - loss: 0.0356 - accuracy: 0.9914 - val_loss: 0.3823 - val_accuracy: 0.9227
Epoch 22/1000
44/44 [==============================] - 18s 413ms/step - loss: 0.0394 - accuracy: 0.9914 - val_loss: 0.2497 - val_accuracy: 0.9571
Epoch 23/1000
44/44 [==============================] - 20s 458ms/step - loss: 0.0553 - accuracy: 0.9871 - val_loss: 0.2874 - val_accuracy: 0.9313
Epoch 24/1000
44/44 [==============================] - 18s 411ms/step - loss: 0.0331 - accuracy: 0.9914 - val_loss: 0.2256 - val_accuracy: 0.9442
Epoch 25/1000
44/44 [==============================] - 20s 456ms/step - loss: 0.0344 - accuracy: 0.9928 - val_loss: 0.2680 - val_accuracy: 0.9313
Epoch 26/1000
44/44 [==============================] - 20s 452ms/step - loss: 0.0261 - accuracy: 0.9928 - val_loss: 0.2897 - val_accuracy: 0.9313
Epoch 27/1000
44/44 [==============================] - 20s 459ms/step - loss: 0.0496 - accuracy: 0.9886 - val_loss: 0.3291 - val_accuracy: 0.9227
Epoch 28/1000
44/44 [==============================] - 20s 457ms/step - loss: 0.0770 - accuracy: 0.9871 - val_loss: 0.3221 - val_accuracy: 0.9056
Epoch 29/1000
44/44 [==============================] - 20s 449ms/step - loss: 0.0324 - accuracy: 0.9943 - val_loss: 0.1766 - val_accuracy: 0.9614
Epoch 30/1000
44/44 [==============================] - 18s 408ms/step - loss: 0.0417 - accuracy: 0.9900 - val_loss: 0.2819 - val_accuracy: 0.9227
Epoch 31/1000
44/44 [==============================] - 19s 428ms/step - loss: 0.0350 - accuracy: 0.9900 - val_loss: 0.1817 - val_accuracy: 0.9528
Epoch 32/1000
44/44 [==============================] - 19s 438ms/step - loss: 0.0346 - accuracy: 0.9914 - val_loss: 0.2838 - val_accuracy: 0.9270
Epoch 33/1000
44/44 [==============================] - 19s 430ms/step - loss: 0.0441 - accuracy: 0.9900 - val_loss: 0.2502 - val_accuracy: 0.9313
Epoch 34/1000
44/44 [==============================] - 18s 418ms/step - loss: 0.0187 - accuracy: 0.9928 - val_loss: 0.2004 - val_accuracy: 0.9571
Epoch 35/1000
44/44 [==============================] - 17s 397ms/step - loss: 0.0319 - accuracy: 0.9943 - val_loss: 0.4355 - val_accuracy: 0.9099
Epoch 36/1000
44/44 [==============================] - 20s 447ms/step - loss: 0.0373 - accuracy: 0.9886 - val_loss: 0.1846 - val_accuracy: 0.9571
Epoch 37/1000
44/44 [==============================] - 19s 426ms/step - loss: 0.0275 - accuracy: 0.9943 - val_loss: 0.2332 - val_accuracy: 0.9442
Epoch 38/1000
44/44 [==============================] - 20s 454ms/step - loss: 0.0203 - accuracy: 0.9914 - val_loss: 0.2743 - val_accuracy: 0.9356
Epoch 39/1000
44/44 [==============================] - 19s 430ms/step - loss: 0.0399 - accuracy: 0.9900 - val_loss: 0.2395 - val_accuracy: 0.9356
Epoch 40/1000
44/44 [==============================] - 20s 457ms/step - loss: 0.0305 - accuracy: 0.9914 - val_loss: 0.2900 - val_accuracy: 0.9185
Epoch 41/1000
44/44 [==============================] - 18s 409ms/step - loss: 0.0238 - accuracy: 0.9943 - val_loss: 0.1827 - val_accuracy: 0.9571
Epoch 42/1000
44/44 [==============================] - 18s 411ms/step - loss: 0.0279 - accuracy: 0.9886 - val_loss: 0.2681 - val_accuracy: 0.9399
Epoch 43/1000
44/44 [==============================] - 20s 449ms/step - loss: 0.0192 - accuracy: 0.9914 - val_loss: 0.2340 - val_accuracy: 0.9313
Epoch 44/1000
44/44 [==============================] - 20s 453ms/step - loss: 0.0418 - accuracy: 0.9914 - val_loss: 0.2768 - val_accuracy: 0.9227
Epoch 45/1000
44/44 [==============================] - 18s 400ms/step - loss: 0.0278 - accuracy: 0.9914 - val_loss: 0.1977 - val_accuracy: 0.9313
Epoch 46/1000
44/44 [==============================] - 20s 456ms/step - loss: 0.0279 - accuracy: 0.9943 - val_loss: 0.3983 - val_accuracy: 0.9013
Epoch 47/1000
44/44 [==============================] - 20s 464ms/step - loss: 0.0347 - accuracy: 0.9914 - val_loss: 0.3160 - val_accuracy: 0.9142
Epoch 48/1000
44/44 [==============================] - 20s 447ms/step - loss: 0.0437 - accuracy: 0.9871 - val_loss: 0.2124 - val_accuracy: 0.9442
Epoch 49/1000
44/44 [==============================] - 18s 403ms/step - loss: 0.0286 - accuracy: 0.9900 - val_loss: 0.3201 - val_accuracy: 0.9356
Epoch 50/1000
44/44 [==============================] - 20s 448ms/step - loss: 0.0141 - accuracy: 0.9943 - val_loss: 0.2216 - val_accuracy: 0.9528
Epoch 51/1000
44/44 [==============================] - 19s 442ms/step - loss: 0.0323 - accuracy: 0.9886 - val_loss: 0.2520 - val_accuracy: 0.9485
Epoch 52/1000
44/44 [==============================] - 20s 448ms/step - loss: 0.0215 - accuracy: 0.9886 - val_loss: 0.1760 - val_accuracy: 0.9485
Epoch 53/1000
44/44 [==============================] - 19s 440ms/step - loss: 0.0303 - accuracy: 0.9943 - val_loss: 0.3124 - val_accuracy: 0.9270
Epoch 54/1000
44/44 [==============================] - 20s 446ms/step - loss: 0.0300 - accuracy: 0.9886 - val_loss: 0.2771 - val_accuracy: 0.9356
Epoch 55/1000
44/44 [==============================] - 19s 428ms/step - loss: 0.0173 - accuracy: 0.9928 - val_loss: 0.2744 - val_accuracy: 0.9356
Epoch 56/1000
44/44 [==============================] - 20s 446ms/step - loss: 0.0251 - accuracy: 0.9886 - val_loss: 0.2540 - val_accuracy: 0.9442
Epoch 57/1000
44/44 [==============================] - 18s 410ms/step - loss: 0.0201 - accuracy: 0.9886 - val_loss: 0.2950 - val_accuracy: 0.9356
Epoch 58/1000
44/44 [==============================] - 20s 459ms/step - loss: 0.0238 - accuracy: 0.9914 - val_loss: 0.2186 - val_accuracy: 0.9442
Epoch 59/1000
43/44 [============================>.] - ETA: 0s - loss: 0.0167 - accuracy: 0.9898
Epoch 00059: ReduceLROnPlateau reducing learning rate to 1.9999999494757505e-06.
44/44 [==============================] - 20s 452ms/step - loss: 0.0167 - accuracy: 0.9900 - val_loss: 0.1988 - val_accuracy: 0.9528
Epoch 00059: early stopping
15/15 [==============================] - 5s 319ms/step - loss: 0.2250 - accuracy: 0.9528

Process finished with exit code 0

6.2. Final summary

七. Supplementary knowledge: tf.where

tf.where(
    condition,
    x=None,
    y=None,
    name=None
)

Return the elements, either from x or y, depending on the condition.

Intuition: where simply picks out the elements you want according to a condition.

condition: a boolean tensor.

x: the data to take where the condition is True.

y: data with the same shape as x, taken where the condition is False.

Returns: where the condition is True, the corresponding element of x is taken; where it is False, the corresponding element of y.
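In TF 2.x, tf.where follows NumPy's broadcasting rule (as the deprecation warning in the logs above notes), so x and y do not need to match the condition's shape; scalars and smaller arrays broadcast. A minimal sketch using np.where, whose behavior tf.where mirrors:

```python
import numpy as np

cond = np.array([[True, False], [False, True]])

# scalar x and y broadcast against the 2x2 condition
out = np.where(cond, 1.0, 0.0)
print(out)    # [[1. 0.]
              #  [0. 1.]]

# a shape-(2,) vector also broadcasts across both rows of the condition
row = np.array([10, 20])
out2 = np.where(cond, row, 0)
print(out2)   # [[10  0]
              #  [ 0 20]]
```

Passing the same arguments (as tensors) to tf.where yields the same results.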
  • Worked example
import tensorflow as tf
import numpy as np

# define a condition tensor filled with random values
condition = tf.convert_to_tensor(np.random.random([2, 3]), dtype=tf.float32)
print(condition)

# define the two source tensors
a = tf.ones(shape=[2, 3], name='a')
print(a)
b = tf.zeros(shape=[2, 3], name='b')
print(b)

# where condition > 0.5, take the value from a; otherwise take it from b
result = tf.where(condition > 0.5, a, b)

print(result)
  • Run result
ssh://zhangkf@192.168.136.64:22/home/zhangkf/anaconda3/envs/tf2c/bin/python -u /home/zhangkf/johnCodes/TF1/test.py
tf.Tensor(
[[0.04526703 0.08822254 0.6437674 ]
 [0.3951503  0.39249578 0.51326084]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
 [1. 1. 1.]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[0. 0. 0.]
 [0. 0. 0.]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[0. 0. 1.]
 [0. 0. 1.]], shape=(2, 3), dtype=float32)

Process finished with exit code 0
import tensorflow as tf
import numpy as np

# define the condition tensor
condition = tf.constant([1, 2, 2, 4])
print(condition)

# define the source tensor
a = tf.constant([[1, 2, 2, 4], [3, 4, 5, 6], [7, 8, 9, 10], [2, 3, 3, 4]])
print(a)

# find the indices (rows) where condition == 2, then gather those rows of a
result_index = tf.where(condition == 2)
result = tf.gather_nd(a, result_index) # returns the 2nd and 3rd rows of a
print(result)

  • Run result
tf.Tensor([1 2 2 4], shape=(4,), dtype=int32)
tf.Tensor(
[[ 1  2  2  4]
 [ 3  4  5  6]
 [ 7  8  9 10]
 [ 2  3  3  4]], shape=(4, 4), dtype=int32)
tf.Tensor(
[[1]
 [2]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[ 3  4  5  6]
 [ 7  8  9 10]], shape=(2, 4), dtype=int32)

Process finished with exit code 0
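The tf.where + tf.gather_nd pair above is one way to filter rows by a condition; boolean masking is a more direct equivalent. A NumPy sketch of the same selection (on tensors, tf.boolean_mask plays this role):

```python
import numpy as np

condition = np.array([1, 2, 2, 4])
a = np.array([[1, 2, 2, 4], [3, 4, 5, 6], [7, 8, 9, 10], [2, 3, 3, 4]])

# boolean indexing keeps the rows where the mask is True,
# the same rows that where + gather_nd selected above
rows = a[condition == 2]
print(rows)   # [[ 3  4  5  6]
              #  [ 7  8  9 10]]
```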
  • axis = 0/1/-1
  • axis=0: operate along the first dimension
  • axis=1: operate along the second dimension
  • axis=-1: operate along the last dimension
  • Take np.argmax() as an example:
>>> a = np.arange(24).reshape(2,3,4)
>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> np.argmax(a,axis = 0)  #shape(3,4)
array([[1, 1, 1, 1],
       [1, 1, 1, 1],
       [1, 1, 1, 1]])
>>> np.argmax(a,axis = 1)  #shape(2,4)
array([[2, 2, 2, 2],
       [2, 2, 2, 2]])
>>> np.argmax(a,axis = -1) #shape(2,3) 
array([[3, 3, 3],
       [3, 3, 3]])

八. For the full set of course videos + PPT + code, feel free to message me!
