TF.keras + tfrecord

TF.keras + tfrecord

在工程中,模型常常需要訓練大數據,而大數據的讀取通常不能一次性讀取進內存中,因此需要不斷從數據集中讀取數據並進行處理。在大數據中,這部分的耗時相當可觀,因此可以利用tfrecord進行預先處理數據,節省讀取和處理的時間。

使用tfrecord有幾個問題:

1.如何將圖像轉爲tfrecord格式。

2.如何讀取tfrecord文件進行訓練。

3.如何讀取多個tfrecord文件進行訓練。

圖像轉爲tfrecord

這裏需要注意的是,由於數據過大,不能在讀取tfrecord的時候打亂數據,這樣打亂數據不能充分打亂所有數據,因此,在保存tfrecord的時候就應該打亂數據,建議將圖像名列表打亂後在按照圖像名列表順序保存進多個tfrecord中。

1.首先打開tfrecord文件:

tf.python_io.TFRecordWriter( path, options=None)

在tensorflow1.14版本以上也可以使用:

tf.io.TFRecordWriter

tf.io.TFRecordWriter(
    path, options=None
)
Args:
  • path: The path to the TFRecords file.
  • options: (optional) String specifying compression type, TFRecordCompressionType, or TFRecordOptions object.

具有如下屬性:

close
close()

Close the file.

flush
flush()

Flush the file.

write
write(
    record
)

Write a string record to the file.

tf.train.Example

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

tf.train.Features

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

tf.train.Feature

Attributes:
  • bytes_list: BytesList bytes_list
  • float_list: FloatList float_list
  • int64_list: Int64List int64_list

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

writer = tf.python_io.TFRecordWriter(os.path.join(tfrecord_save_path, ftrecordfilename))
img = Image.open(image_path, 'r')
img = img.resize((224, 224))
size = img.size

img_raw = img.tobytes()  # 將圖片轉化爲二進制格式
example = tf.train.Example(
features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'img_width': tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),
'img_height': tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))
}))
writer.write(example.SerializeToString())  # 序列化爲字符串
writer.close()

讀取多個tfrecord文件進行訓練

通過加載tfrecord文件的文件名,傳入

tf.data.TFRecordDataset

Args:
  • filenames: A tf.string tensor or tf.data.Dataset containing one or more filenames.
  • compression_type: (Optional.) A tf.string scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP".
  • buffer_size: (Optional.) A tf.int64 scalar representing the number of bytes in the read buffer. If your input pipeline is I/O bottlenecked, consider setting this parameter to a value 1-100 MBs. If None, a sensible default for both local and remote file systems is used.
  • num_parallel_reads: (Optional.) A tf.int64 scalar representing the number of files to read in parallel. If greater than one, the records of files read in parallel are outputted in an interleaved order. If your input pipeline is I/O bottlenecked, consider setting this parameter to a value greater than one to parallelize the I/O. If None, files will be read sequentially.
def get_dataset_batch(data_files):
    dataset = tf.data.TFRecordDataset(data_files)
    dataset = dataset.repeat()  # 重複數據集
    dataset = dataset.map(read_and_decode)  # 解析數據
    dataset = dataset.shuffle(buffer_size=100)  # 在緩衝區中隨機打亂數據
    batch = dataset.batch(batch_size=4)  # 每10條數據爲一個batch,生成一個新的Datasets
    return batch

其中調用了回調函數:read_and_decode

def read_and_decode(example_string):
    '''
    從TFrecord格式文件中讀取數據
    '''
    features = tf.parse_single_example(example_string,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                           'img_width': tf.FixedLenFeature([], tf.int64),
                                           'img_height': tf.FixedLenFeature([], tf.int64)
                                       })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int64)
    label = tf.one_hot(label, 2)
    return img, label

最後進行訓練

def train(model, batch):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    model.fit(batch, epochs=1, steps_per_epoch=10)
    return model
def get_model():
    model = tf.keras.applications.MobileNetV2(include_top=False, weights=None)
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    x = model(inputs)  # 此處x爲MobileNetV2模型去處頂層時輸出的特徵相應圖。
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(2, activation='softmax',
                                    use_bias=True, name='Logits')(x)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.summary()
    return model

github代碼: https://github.com/18150167970/TFrecord_tf_keras_demo

if useful:
	start work

have fun(笑)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章