tensorflow(2)——讀取數據TFrecord

  • 學習這個是因爲搞tensorflow肯定跳不過這個坑,所以還不如靜下心來好好梳理一下。
  • 本文學完理論會優化自己以前的一個分類代碼,從原來最古老的placeholder版本做一下優化——啓發是來自transformer的源碼,它的做法讓我覺得我有必要體會一下。

TFrecord

  • 注意,這裏他只是一種文件存儲格式的改變,前文那些隊列的思想是沒變的!!!

簡單介紹

  • TFRecords其實是一種二進制文件,雖然它不如其他格式好理解,但是它能更好的利用內存,更方便複製和移動,並且不需要單獨的標籤文件。總而言之,這樣的文件格式好處多多。

  • TFRecords文件包含了tf.train.Example 協議內存塊(protocol buffer)(協議內存塊包含了字段 Features)。我們可以寫一段代碼獲取你的數據, 將數據填入到Example協議內存塊(protocol buffer),將協議內存塊序列化爲一個字符串, 並且通過tf.python_io.TFRecordWriter 寫入到TFRecords文件。

  • 從TFRecords文件中讀取數據, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。這個操作可以將Example協議內存塊(protocol buffer)解析爲張量。

  • 其實我是這樣理解的,我們可以把存入文件和讀取文件看作一種”通信協議“,我首先指定一下我們交互信息的協議,然後我存的時候這麼存進去,讀的時候也這麼讀出來,僅此而已!

開篇

def to_tfrecord(file_name,train_data,train_label):

    # 這裏準備一個樣本一個樣本的寫入TFRecord file中
    # 先把每個樣本中所有feature的信息和值存到字典中,key爲feature名,value爲feature值。
    # feature值需要轉變成tensorflow指定的feature類型中的一個。
    # tensorflow feature類型只接受list數據

    writer = tf.python_io.TFRecordWriter('%s.tfrecord' %file_name)

    for i in range(len(train_data)):

        # 寫入字典
        features = {}

        # 寫入向量,類型float,本身就是list,所以"value=vectors[i]"沒有中括號
        features['data'] = tf.train.Feature(float_list=tf.train.FloatList(value=train_data[i]))
        features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=train_label[i]))

        # 轉化爲tf_features
        tf_features = tf.train.Features(feature=features)

        # 再將其變成一個樣本example
        tf_example = tf.train.Example(features=tf_features)

        # 序列化該樣本
        tf_serialized = tf_example.SerializeToString()

        # 寫入一個序列化的樣本
        writer.write(tf_serialized)

    writer.close()

讀取(我感覺我碰到了最玄學的問題)

正常

# 使用TF_record導入數據

# 使用TF_record導入數據

filenames = "test.tfrecord"
filename_queue = tf.train.string_input_producer([filenames], num_epochs=None,
                                                shuffle=True)
# **2.創建一個讀取器
reader = tf.TFRecordReader()

_, serialized_example = reader.read(filename_queue)

# **3.根據你寫入的格式對應說明讀取的格式
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'data': tf.FixedLenFeature(shape=[100], dtype=tf.float32),
                                        'label': tf.FixedLenFeature(shape=[2], dtype=tf.float32)}     # 而標量就不用說明
                                   )
X_out = features['data']
y_out = features['label']

X_batch, y_batch = tf.train.shuffle_batch([X_out, y_out], batch_size=2,
                                          capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# **5.啓動隊列進行數據讀取
# 下面的 coord 是個線程協調器,把啓動隊列的時候加上線程協調器。
# 這樣,在數據讀取完畢以後,調用協調器把線程全部都關了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in range(5):
    _X_batch, _y_batch = sess.run([X_batch, y_batch])
    print('** batch %d' % i)
    print('_X_batch:', _X_batch)
    print('_y_batch:', _y_batch)
    y_outputs.extend(_y_batch.tolist())
print(y_outputs)

# **6.最後記得把隊列關掉
coord.request_stop()
coord.join(threads)

在這裏插入圖片描述

報錯代碼:

def parse_function(example_proto):
    # 只接受一個輸入:example_proto,也就是序列化後的樣本tf_serialized

    # 解析規則
    # 也可以把形狀信息存入example_proto裏,然後在下面用
    dics = {
        'data': tf.FixedLenFeature(shape=[100], dtype=tf.float32, default_value=0.0),
        'label': tf.FixedLenFeature(shape=[2], dtype=tf.float32)
    }

    # 解析樣本
    parsed_example = tf.parse_single_example(example_proto,dics)

    # parsed_example['data'] = tf.reshape(parsed_example['data'], (1,100))
    #
    # # 轉變tensor形狀
    # parsed_example['label'] = tf.reshape(parsed_example['label'], (1,2))

    # 轉變特徵
    return parsed_example




# 使用TF_record導入數據

filenames = "test.tfrecord"
dataset = tf.data.TFRecordDataset(filenames)

'''由於從tfrecord文件中導入的樣本是剛纔寫入的tf_serialized序列化樣本,
所以我們需要對每一個樣本進行解析。這裏就用dataset.map(parse_function)來對dataset裏的每個樣本進行相同的解析操作。'''

new_dataset = dataset.map(parse_function)

# 創建迭代器
iterator = new_dataset.make_one_shot_iterator()

# 獲取樣本
next_element = iterator.get_next()

sess = tf.Session()

sess.run(next_element['data'])

在這裏插入圖片描述

END

  • 這個報錯挖個坑,下篇填。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章