從零基礎入門Tensorflow2.0 ----六、32cifar10數據訓練

every blog every motto:

0. 前言

在 CIFAR-10 數據集上訓練卷積神經網絡(CNN)

1. 代碼部分

1. 導入模塊

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import os,sys
import tensorflow as tf
import time
from tensorflow import keras
# Expose only the first GPU to CUDA. CUDA_VISIBLE_DEVICES takes device
# indices (or UUIDs), not TensorFlow device strings such as '/gpu:0'; CUDA
# stops at the first invalid entry, so '/gpu:0' would hide every GPU and
# TensorFlow would silently fall back to the CPU.
# NOTE: the variable only takes effect if set before TensorFlow initializes
# its GPU devices (before the first op runs); setting it before
# `import tensorflow` is the safest choice.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Log interpreter and library versions so runs can be reproduced.
print(tf.__version__)
print(sys.version_info)
for module in mpl, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

2. 讀取數據

# The ten CIFAR-10 category names. They are passed to the generators as
# `classes=`, and predicted indices are mapped back via class_names[index].
class_names = (
    'airplane automobile bird cat deer dog frog horse ship truck'.split()
)

# Dataset locations, relative to the current working directory.
train_labels_file, test_csv_file = (
    './cifar10/trainLabels.csv',
    './cifar10/sampleSubmission.csv',
)
train_folder, test_folder = './cifar10/train', './cifar10/test'

def parse_csv_file(filepath, folder):
    """Parse a labels CSV into a list of ``(image_path, label)`` tuples.

    The first row is treated as a header and skipped. Every other non-blank
    row must hold exactly two columns, ``image_id,label``; the image path is
    built as ``os.path.join(folder, image_id + '.png')``.

    Args:
        filepath: Path to the CSV file (e.g. ``trainLabels.csv``).
        folder: Directory that holds the ``.png`` images.

    Returns:
        A list of ``(image_full_path, label_str)`` tuples in file order.

    Raises:
        ValueError: If a non-blank row does not have exactly two columns.
    """
    import csv  # local import keeps this snippet self-contained

    results = []
    # newline='' is what the csv module expects; it then handles both
    # '\n' and '\r\n' line endings correctly.
    with open(filepath, 'r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            # Tolerate blank lines, e.g. a trailing newline at end of file,
            # which previously crashed the tuple unpacking.
            if not any(field.strip() for field in row):
                continue
            if len(row) != 2:
                raise ValueError(
                    f'{filepath}: expected 2 columns (id,label), got {row!r}')
            image_id, label_str = (field.strip() for field in row)
            image_full_path = os.path.join(folder, image_id + '.png')
            results.append((image_full_path, label_str))
    return results


train_labels_info = parse_csv_file(train_labels_file, train_folder)
test_csv_info = parse_csv_file(test_csv_file, test_folder)

# Spot-check the first five parsed records of each split, then their sizes.
import pprint
for parsed in (train_labels_info, test_csv_info):
    pprint.pprint(parsed[:5])
print(len(train_labels_info), len(test_csv_info))

2.2 劃分數據:前 45000 條作為訓練集,其餘作為驗證集

# The first 45000 parsed training records are used for training and the
# remainder is held out for validation; the test records form their own frame.
train_df, valid_df, test_df = (
    pd.DataFrame(records)
    for records in (train_labels_info[0:45000],
                    train_labels_info[45000:],
                    test_csv_info)
)

# Rename the columns to the names used as x_col / y_col by the generators.
for frame in (train_df, valid_df, test_df):
    frame.columns = ['filepath', 'class']

for frame in (train_df, valid_df, test_df):
    print(frame.head())

3. 讀取圖片

# Load images: geometry and batching shared by the generators and the model.
height, width, channels = 32, 32, 3
batch_size = 32
num_classes = 10

# Options common to every flow_from_dataframe call below; only the
# dataframe and the shuffle flag differ between train and validation.
_flow_args = dict(
    directory='./',
    x_col='filepath',
    y_col='class',
    classes=class_names,
    target_size=(height, width),
    batch_size=batch_size,
    seed=7,
    class_mode='sparse',
)

# Training images: rescaled to [0, 1] and randomly augmented.
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
train_generator = train_datagen.flow_from_dataframe(
    train_df, shuffle=True, **_flow_args)

# Validation images: rescaled only, never augmented, kept in order.
valid_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
valid_generator = valid_datagen.flow_from_dataframe(
    valid_df, shuffle=False, **_flow_args)

train_num, valid_num = train_generator.samples, valid_generator.samples
print(train_num, valid_num)

# Peek at two training batches to sanity-check shapes and sparse labels.
for _ in range(2):
    x, y = train_generator.next()
    print(x.shape, y.shape)
    print(y)

4. 模型搭建

# CNN: three stages, each of two 3x3 same-padded Conv2D+BatchNormalization
# layers followed by 2x2 max-pooling (filters 128 -> 256 -> 512), then a
# Flatten -> Dense(512) -> softmax classifier head.
model = keras.models.Sequential([
    # Keras (default channels_last format) expects input_shape as
    # (rows, cols, channels) == (height, width, channels), matching
    # target_size=(height, width). The previous [width, height, channels]
    # only worked because the images are square.
    keras.layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu',
                        input_shape=[height, width, channels]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=2),

    keras.layers.Conv2D(filters=256, kernel_size=3, padding='same', activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=256, kernel_size=3, padding='same', activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=2),

    keras.layers.Conv2D(filters=512, kernel_size=3, padding='same', activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=512, kernel_size=3, padding='same', activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=2),

    # Flatten the feature maps for the fully connected head.
    keras.layers.Flatten(),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dense(num_classes, activation='softmax'),
])

# Sparse categorical cross-entropy matches the integer labels produced by
# the generators' class_mode='sparse'.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

5. 訓練

epochs = 20
# Model.fit accepts Keras data generators directly; fit_generator is
# deprecated (TF >= 2.1) and removed in later releases.
# len(generator) is ceil(samples / batch_size), so every image, including
# the final partial batch, is used once per epoch. Floor division
# (samples // batch_size) silently dropped the remainder, e.g. only 4992 of
# 5000 validation images were evaluated.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    validation_data=valid_generator,
    validation_steps=len(valid_generator),
)

6. 學習曲線

# 學習曲線
def plot_learning_curves(hsitory,label,epochs,min_value,max_value):
    data = {}
    data[label] = history.history[label]
    data['val_' + label] = hsitory.history['val_' + label]
    pd.DataFrame(data).plot(figsize=(8,5))
    plt.grid(True)
    plt.axis([0,epochs,min_value,max_value])
    plt.show()

plot_learning_curves(history, 'accuracy', epochs, 0, 1)
# A fixed 1.5-2.5 window clips every loss value below 1.5, so size the
# y axis from the recorded train/validation losses instead.
loss_values = history.history['loss'] + history.history['val_loss']
plot_learning_curves(history, 'loss', epochs, 0, max(loss_values) * 1.05)

7. 在測試集上預測並生成提交文件

# Test images receive the same rescaling as validation (no augmentation).
test_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
# shuffle=False keeps predictions in test_df row order. The 'class' column
# is only needed to satisfy y_col; predict() does not use the labels.
test_generator = test_datagen.flow_from_dataframe(test_df, directory='./', x_col='filepath', y_col='class',
                                                  classes=class_names, target_size=(height, width),
                                                  batch_size=batch_size, seed=7, shuffle=False,
                                                  class_mode="sparse")
test_num = test_generator.samples
print(test_num)
# Model.predict accepts generators directly (predict_generator is deprecated).
# use_multiprocessing=False: the 10 workers are threads; True would use processes.
test_predict = model.predict(test_generator, workers=10, use_multiprocessing=False)
print(test_predict.shape)
print(test_predict[0:5])
# Index of the most probable class for each image, then its name.
test_predict_class_indices = np.argmax(test_predict, axis=1)
print(test_predict_class_indices[0:5])
test_predict_class = [class_names[index] for index in test_predict_class_indices]
print(test_predict_class[0:5])
def generate_submissions(fielname, predict_class):
    """Write predictions to an ``id,label`` CSV file.

    Ids are 1-based and follow the order of *predict_class*.

    Args:
        fielname: Output file path. (The misspelled parameter name is kept
            so keyword callers still work.)
        predict_class: Sequence of predicted class names, one per test image.
    """
    # The old body opened an undefined ``filename`` and called the
    # nonexistent ``rangelen``, so it always raised NameError.
    with open(fielname, 'w', encoding='utf-8') as f:
        f.write('id,label\n')
        f.writelines(
            f'{image_id},{label}\n'
            for image_id, label in enumerate(predict_class, start=1)
        )

# Persist the predicted class names as the submission CSV.
output_file = './cifar10/submission.csv'
generate_submissions(output_file, predict_class=test_predict_class)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章