every blog every motto:
0. 前言
使用 Keras 的 ImageDataGenerator 讀取 Kaggle「10 Monkey Species」數據集
1. 代碼部分
1. 導入模塊
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as lpt
import numpy as np
import pandas as pd
import sklearn
import os,sys
import tensorflow as tf
import time
from tensorflow import keras
# Report the interpreter and library versions in use so that results can be
# reproduced against the same environment.
print(tf.__version__)
print(sys.version_info)
for module in mpl, pd, sklearn, tf, keras:
    # The loop body must be indented; each library exposes __name__/__version__.
    print(module.__name__, module.__version__)
2. 文件路徑
# Dataset locations, relative to the notebook's working directory.
train_dir = "../input/10-monkey-species/training/training"
valid_dir = "../input/10-monkey-species/validation/validation"
label_file = "../input/10-monkey-species/monkey_labels.txt"

# Print whether each of the three paths exists.
for path in (train_dir, valid_dir, label_file):
    print(os.path.exists(path))

# Print the entries found directly under each image directory.
for directory in (train_dir, valid_dir):
    print(os.listdir(directory))
3. 讀取數據
3.1 讀取標籤
# Load the label table; the first row of the file (header=0) supplies the
# column names. The resulting DataFrame is printed for inspection.
labels = pd.read_csv(filepath_or_buffer=label_file, header=0)
print(labels)
3.2 讀取圖片
# Image-loading configuration: every image is resized to (height, width)
# and delivered in batches of `batch_size`.
height, width = 128, 128
channels = 3
batch_size = 64
num_classes = 10

# Training pipeline: pixel values are multiplied by 1/255 and the images are
# randomly augmented (rotation, shifts, shear, zoom, horizontal flip).
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Validation pipeline: only the 1/255 rescaling, no augmentation.
valid_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
)

# Training batches are shuffled (seeded for repeatability).
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(height, width),
    batch_size=batch_size,
    seed=7,
    shuffle=True,
    class_mode="categorical",
)

# Validation batches keep the directory order (shuffle=False).
valid_generator = valid_datagen.flow_from_directory(
    valid_dir,
    target_size=(height, width),
    batch_size=batch_size,
    seed=7,
    shuffle=False,
    class_mode="categorical",
)

# Number of images each generator found.
# NOTE(review): `trian_num` is a misspelling of "train_num"; the name is kept
# unchanged because code outside this section may reference it.
trian_num = train_generator.samples
valid_num = valid_generator.samples
print(trian_num, valid_num)
# Pull two batches from the training generator and print the image/label
# array shapes followed by the label array itself.
for _ in range(2):
    # Use the built-in next(): the iterator protocol is the portable way to
    # fetch a batch, whereas the `.next()` method is not present on the
    # generator in every Keras release.
    x, y = next(train_generator)
    print(x.shape, y.shape)
    print(y)