TF.keras + tfrecord
In production, models often have to be trained on datasets too large to load into memory at once, so data must be read from the dataset and processed in a continuous stream. On big data this part of the pipeline takes a considerable amount of time, so tfrecord can be used to preprocess the data in advance and save reading and processing time.
Using tfrecord raises a few questions:
1. How to convert images to the tfrecord format.
2. How to read a tfrecord file for training.
3. How to read multiple tfrecord files for training.
Converting images to tfrecord
Note that because the dataset is large, the data cannot be fully shuffled at read time: a read-time shuffle works over a limited buffer and cannot thoroughly mix all the data. Therefore the data should be shuffled when the tfrecords are saved. It is recommended to shuffle the list of image names first, then save the images into multiple tfrecord files in that shuffled order.
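The save-time shuffling advice above can be sketched in plain Python (the helper name `shard_filenames` and the file names are illustrative, not from the original repo): shuffle the image-name list once up front, then split it into as many shards as there will be tfrecord files, so each file holds a random mix of the data.

```python
import random

def shard_filenames(filenames, num_shards, seed=0):
    """Shuffle the full filename list once, then split it into num_shards shards."""
    names = list(filenames)
    random.Random(seed).shuffle(names)            # shuffle once, before any writing
    return [names[i::num_shards] for i in range(num_shards)]

# Example: 10 images spread across 3 tfrecord shards.
shards = shard_filenames([f"img_{i}.jpg" for i in range(10)], num_shards=3)
print([len(s) for s in shards])  # [4, 3, 3]
```

Each shard can then be written to its own tfrecord file with the writer shown below.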
1. First, open a tfrecord file for writing:
tf.python_io.TFRecordWriter( path, options=None)
In TensorFlow 1.14 and later you can also use:
tf.io.TFRecordWriter(
    path, options=None
)
Args:
    path: The path to the TFRecords file.
    options: (optional) String specifying compression type, TFRecordCompressionType, or TFRecordOptions object.
It has the following methods:
    close()
        Close the file.
    flush()
        Flush the file.
    write(record)
        Write a string record to the file.
tf.train.Example
https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en
tf.train.Features
https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en
tf.train.Feature
Attributes:
    bytes_list: BytesList bytes_list
    float_list: FloatList float_list
    int64_list: Int64List int64_list
https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en
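As a quick sanity check, the three Feature types above can be exercised with a serialize/parse round trip (a minimal sketch assuming TensorFlow is installed; the feature names and values here are illustrative):

```python
import tensorflow as tf

# One Feature of each type: bytes_list, float_list, int64_list.
example = tf.train.Example(features=tf.train.Features(feature={
    "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"\x00\x01"])),
    "score":   tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])),
    "label":   tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
}))

serialized = example.SerializeToString()        # bytes, ready for writer.write()
restored = tf.train.Example.FromString(serialized)
print(restored.features.feature["label"].int64_list.value[0])  # 1
```

The serialized string is exactly what gets handed to `writer.write()` in the code below.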
writer = tf.python_io.TFRecordWriter(os.path.join(tfrecord_save_path, ftrecordfilename))
img = Image.open(image_path, 'r')
img = img.resize((224, 224))
size = img.size
img_raw = img.tobytes()  # convert the image to raw bytes
example = tf.train.Example(
    features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        'img_width': tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),
        'img_height': tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))
    }))
writer.write(example.SerializeToString())  # serialize the Example to a byte string and write it
writer.close()
Reading multiple tfrecord files for training
Load the tfrecord filenames and pass them to
tf.data.TFRecordDataset
Args:
    filenames: A tf.string tensor or tf.data.Dataset containing one or more filenames.
    compression_type: (Optional.) A tf.string scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP".
    buffer_size: (Optional.) A tf.int64 scalar representing the number of bytes in the read buffer. If your input pipeline is I/O bottlenecked, consider setting this parameter to a value 1-100 MBs. If None, a sensible default for both local and remote file systems is used.
    num_parallel_reads: (Optional.) A tf.int64 scalar representing the number of files to read in parallel. If greater than one, the records of files read in parallel are outputted in an interleaved order. If your input pipeline is I/O bottlenecked, consider setting this parameter to a value greater than one to parallelize the I/O. If None, files will be read sequentially.
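The optional arguments above can be sketched as follows (a minimal example assuming TensorFlow 2.x eager mode; the file name is illustrative): write a few records into a GZIP-compressed tfrecord, then read them back with compression_type and num_parallel_reads set.

```python
import os
import tempfile
import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord.gz")

# Write three tiny records with GZIP compression.
opts = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(path, opts) as w:
    for i in range(3):
        w.write(tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))
        })).SerializeToString())

# Read back; compression_type must match, and num_parallel_reads > 1
# interleaves records when several files are given.
ds = tf.data.TFRecordDataset([path], compression_type="GZIP",
                             num_parallel_reads=2)
labels = [tf.io.parse_single_example(
              r, {"label": tf.io.FixedLenFeature([], tf.int64)})["label"].numpy()
          for r in ds]
print(sorted(labels))  # [0, 1, 2]
```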
def get_dataset_batch(data_files):
    dataset = tf.data.TFRecordDataset(data_files)
    dataset = dataset.repeat()  # repeat the dataset indefinitely
    dataset = dataset.map(read_and_decode)  # parse each serialized example
    dataset = dataset.shuffle(buffer_size=100)  # shuffle within a 100-element buffer
    batch = dataset.batch(batch_size=4)  # group every 4 examples into a batch, yielding a new Dataset
    return batch
where the map callback read_and_decode is:
def read_and_decode(example_string):
    '''
    Parse a single serialized example read from a TFRecord file.
    '''
    features = tf.parse_single_example(example_string,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                           'img_width': tf.FixedLenFeature([], tf.int64),
                                           'img_height': tf.FixedLenFeature([], tf.int64)
                                       })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  # scale pixels to [-0.5, 0.5]
    label = tf.cast(features['label'], tf.int64)
    label = tf.one_hot(label, 2)  # two classes in this demo
    return img, label
Finally, train the model:
def train(model, batch):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    model.fit(batch, epochs=1, steps_per_epoch=10)
    return model
def get_model():
    model = tf.keras.applications.MobileNetV2(include_top=False, weights=None)
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    x = model(inputs)  # x is the feature map output by MobileNetV2 with its top layers removed
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(2, activation='softmax',
                                    use_bias=True, name='Logits')(x)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.summary()
    return model
GitHub code: https://github.com/18150167970/TFrecord_tf_keras_demo
if useful:
start work
have fun (lol)