TF.keras + tfrecord

In production, models often need to be trained on datasets too large to load into memory at once, so data must be read from the dataset and processed continuously during training. For large datasets this part of the pipeline can be quite time-consuming, so tfrecord can be used to preprocess the data ahead of time and save reading and processing time.

Using tfrecord raises a few questions:

1. How to convert images to the tfrecord format.

2. How to read a tfrecord file for training.

3. How to read multiple tfrecord files for training.

Converting images to tfrecord

One thing to note: because the dataset is too large, it cannot be fully shuffled at tfrecord read time; a shuffle buffer at read time will not mix all of the data thoroughly. The data should therefore be shuffled when the tfrecords are saved. The recommended approach is to shuffle the list of image filenames first, and then write the images, in that shuffled order, into multiple tfrecord files.
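The shuffle-then-shard step can be sketched in plain Python. `shard_filenames`, the seed, and the shard count here are hypothetical names for illustration, not part of any TensorFlow API:

```python
import random

def shard_filenames(filenames, num_shards, seed=0):
    # Shuffle the full filename list once, so records end up well
    # mixed across shards, then split it into num_shards chunks;
    # each chunk is later written to its own tfrecord file.
    names = list(filenames)
    random.Random(seed).shuffle(names)
    return [names[i::num_shards] for i in range(num_shards)]
```

Each returned chunk would then be passed to one TFRecordWriter, so every shard contains an already-shuffled slice of the whole dataset.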

1. First, open a tfrecord file:

tf.python_io.TFRecordWriter( path, options=None)

In TensorFlow 1.14 and above you can also use:

tf.io.TFRecordWriter

tf.io.TFRecordWriter(
    path, options=None
)
Args:
  • path: The path to the TFRecords file.
  • options: (optional) String specifying compression type, TFRecordCompressionType, or TFRecordOptions object.

The writer has the following methods:

close
close()

Close the file.

flush
flush()

Flush the file.

write
write(
    record
)

Write a string record to the file.

tf.train.Example

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

tf.train.Features

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

tf.train.Feature

Attributes:
  • bytes_list: a BytesList
  • float_list: a FloatList
  • int64_list: an Int64List

https://tensorflow.google.cn/api_docs/python/tf/train/Example?hl=en

import os
import tensorflow as tf
from PIL import Image

writer = tf.python_io.TFRecordWriter(os.path.join(tfrecord_save_path, ftrecordfilename))
img = Image.open(image_path, 'r')
img = img.resize((224, 224))
size = img.size

img_raw = img.tobytes()  # convert the image to raw bytes
example = tf.train.Example(
    features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        'img_width': tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),
        'img_height': tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))
    }))
writer.write(example.SerializeToString())  # serialize to a string and write it
writer.close()
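As a sanity check of the schema above, one Example can be serialized and parsed back in memory, using the tf.io names available from TF 1.14 on (assuming eager execution); the dummy 2x2 image bytes and label value below are made up for illustration:

```python
import tensorflow as tf

# Build one Example with the same four features as above.
example = tf.train.Example(features=tf.train.Features(feature={
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'\x00' * 12])),
    'img_width': tf.train.Feature(int64_list=tf.train.Int64List(value=[2])),
    'img_height': tf.train.Feature(int64_list=tf.train.Int64List(value=[2])),
}))
serialized = example.SerializeToString()

# Parse it back with the matching feature spec.
parsed = tf.io.parse_single_example(serialized, {
    'label': tf.io.FixedLenFeature([], tf.int64),
    'img_raw': tf.io.FixedLenFeature([], tf.string),
    'img_width': tf.io.FixedLenFeature([], tf.int64),
    'img_height': tf.io.FixedLenFeature([], tf.int64),
})
```

The parse spec must match the keys and types used at write time exactly, which is why it is worth verifying once before writing thousands of records.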

Reading multiple tfrecord files for training

Load the tfrecord filenames and pass them to

tf.data.TFRecordDataset

Args:
  • filenames: A tf.string tensor or tf.data.Dataset containing one or more filenames.
  • compression_type: (Optional.) A tf.string scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP".
  • buffer_size: (Optional.) A tf.int64 scalar representing the number of bytes in the read buffer. If your input pipeline is I/O bottlenecked, consider setting this parameter to a value 1-100 MBs. If None, a sensible default for both local and remote file systems is used.
  • num_parallel_reads: (Optional.) A tf.int64 scalar representing the number of files to read in parallel. If greater than one, the records of files read in parallel are outputted in an interleaved order. If your input pipeline is I/O bottlenecked, consider setting this parameter to a value greater than one to parallelize the I/O. If None, files will be read sequentially.
def get_dataset_batch(data_files):
    dataset = tf.data.TFRecordDataset(data_files)
    dataset = dataset.repeat()  # repeat the dataset indefinitely
    dataset = dataset.map(read_and_decode)  # parse each serialized example
    dataset = dataset.shuffle(buffer_size=100)  # shuffle within a 100-element buffer
    batch = dataset.batch(batch_size=4)  # group every 4 examples into a batch, yielding a new Dataset
    return batch
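A minimal, self-contained sketch of the multi-file case: write two tiny shards to a temporary directory, then read them both back through one TFRecordDataset (this assumes TF 2.x eager execution; the shard contents here are arbitrary byte strings rather than real serialized examples):

```python
import os
import tempfile
import tensorflow as tf

# Write two tiny tfrecord shards.
tmpdir = tempfile.mkdtemp()
paths = []
for shard in range(2):
    path = os.path.join(tmpdir, "shard_%d.tfrecords" % shard)
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(3):
            writer.write(("record_%d_%d" % (shard, i)).encode())
    paths.append(path)

# Load both shards; with num_parallel_reads > 1 the records of the
# two files come out in interleaved order.
dataset = tf.data.TFRecordDataset(paths, num_parallel_reads=2)
records = [r.numpy() for r in dataset]
```

All six records are read regardless of how many files they were split across, which is what makes sharding at write time transparent to the training loop.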

The map call above uses the parsing function read_and_decode:

def read_and_decode(example_string):
    '''
    Parse a single serialized example read from a tfrecord file.
    '''
    features = tf.parse_single_example(example_string,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                           'img_width': tf.FixedLenFeature([], tf.int64),
                                           'img_height': tf.FixedLenFeature([], tf.int64)
                                       })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int64)
    label = tf.one_hot(label, 2)
    return img, label
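The decode arithmetic in read_and_decode can be checked on synthetic bytes (tf.io.decode_raw is the TF 2.x name for the tf.decode_raw used above; the all-zero image is made up for illustration):

```python
import numpy as np
import tensorflow as tf

# Raw uint8 bytes become a normalized float image, mirroring
# the decode path in read_and_decode.
raw = np.zeros((224, 224, 3), dtype=np.uint8).tobytes()
img = tf.io.decode_raw(raw, tf.uint8)
img = tf.reshape(img, [224, 224, 3])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

# An integer class id becomes a one-hot vector over 2 classes.
label = tf.one_hot(1, 2)
```

A pixel value of 0 maps to -0.5 and 255 maps to 0.5, so the network sees inputs roughly centered on zero.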

Finally, train the model:

def train(model, batch):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    model.fit(batch, epochs=1, steps_per_epoch=10)
    return model
def get_model():
    model = tf.keras.applications.MobileNetV2(include_top=False, weights=None)
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    x = model(inputs)  # x is the feature map output by MobileNetV2 with its top layers removed
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(2, activation='softmax',
                                    use_bias=True, name='Logits')(x)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.summary()
    return model
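A quick smoke test of the head defined in get_model: rebuilding the same architecture and running one forward pass on a random input should give a softmax output of shape (1, 2) that sums to 1 (weights=None, so no pretrained download is needed; assumes TF 2.x eager execution):

```python
import tensorflow as tf

# Rebuild the same architecture as get_model above.
backbone = tf.keras.applications.MobileNetV2(include_top=False, weights=None)
inputs = tf.keras.layers.Input(shape=(224, 224, 3))
x = backbone(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(2, activation='softmax')(x)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

# One forward pass on a random image-shaped tensor.
out = model(tf.random.uniform([1, 224, 224, 3]))
```

The 2-unit softmax head matches the tf.one_hot(label, 2) targets produced by read_and_decode, which is what lets categorical_crossentropy be used directly in train.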

GitHub code: https://github.com/18150167970/TFrecord_tf_keras_demo

if useful:
	start work

have fun (laughs)
