序
- 學習這個是因爲搞tensorflow肯定跳不過這個坑,所以還不如靜下心來好好梳理一下。
- 本文學完理論會優化自己以前的一個分類代碼,從原來最古老的placeholder版本做一下優化——啓發是來自transformer的源碼,它的做法讓我覺得我有必要體會一下。
TFrecord
- 注意,這裏他只是一種文件存儲格式的改變,前文那些隊列的思想是沒變的!!!
簡單介紹
-
TFRecords其實是一種二進制文件,雖然它不如其他格式好理解,但是它能更好的利用內存,更方便複製和移動,並且不需要單獨的標籤文件。總而言之,這樣的文件格式好處多多。
-
TFRecords文件包含了tf.train.Example 協議內存塊(protocol buffer)(協議內存塊包含了字段 Features)。我們可以寫一段代碼獲取你的數據, 將數據填入到Example協議內存塊(protocol buffer),將協議內存塊序列化爲一個字符串, 並且通過tf.python_io.TFRecordWriter 寫入到TFRecords文件。
-
從TFRecords文件中讀取數據, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。這個操作可以將Example協議內存塊(protocol buffer)解析爲張量。
-
其實我是這樣理解的,我們可以把存入文件和讀取文件看作一種”通信協議“,我首先指定一下我們交互信息的協議,然後我存的時候這麼存進去,讀的時候也這麼讀出來,僅此而已!
開篇
-
基本代碼爲,目的:把你的數據轉化成tf_record文件
def to_tfrecord(file_name,train_data,train_label):
# 這裏準備一個樣本一個樣本的寫入TFRecord file中
# 先把每個樣本中所有feature的信息和值存到字典中,key爲feature名,value爲feature值。
# feature值需要轉變成tensorflow指定的feature類型中的一個。
# tensorflow feature類型只接受list數據
writer = tf.python_io.TFRecordWriter('%s.tfrecord' %file_name)
for i in range(len(train_data)):
# 寫入字典
features = {}
# 寫入向量,類型float,本身就是list,所以"value=vectors[i]"沒有中括號
features['data'] = tf.train.Feature(float_list=tf.train.FloatList(value=train_data[i]))
features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=train_label[i]))
# 轉化爲tf_features
tf_features = tf.train.Features(feature=features)
# 再將其變成一個樣本example
tf_example = tf.train.Example(features=tf_features)
# 序列化該樣本
tf_serialized = tf_example.SerializeToString()
# 寫入一個序列化的樣本
writer.write(tf_serialized)
writer.close()
讀取(我感覺我碰到了最玄學的問題)
正常
# 使用TF_record導入數據
# 使用TF_record導入數據
filenames = "test.tfrecord"
filename_queue = tf.train.string_input_producer([filenames], num_epochs=None,
shuffle=True)
# **2.創建一個讀取器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# **3.根據你寫入的格式對應說明讀取的格式
features = tf.parse_single_example(serialized_example,
features={
'data': tf.FixedLenFeature(shape=[100], dtype=tf.float32),
'label': tf.FixedLenFeature(shape=[2], dtype=tf.float32)} # 而標量就不用說明
)
X_out = features['data']
y_out = features['label']
X_batch, y_batch = tf.train.shuffle_batch([X_out, y_out], batch_size=2,
capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
# **5.啓動隊列進行數據讀取
# 下面的 coord 是個線程協調器,把啓動隊列的時候加上線程協調器。
# 這樣,在數據讀取完畢以後,調用協調器把線程全部都關了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in range(5):
_X_batch, _y_batch = sess.run([X_batch, y_batch])
print('** batch %d' % i)
print('_X_batch:', _X_batch)
print('_y_batch:', _y_batch)
y_outputs.extend(_y_batch.tolist())
print(y_outputs)
# **6.最後記得把隊列關掉
coord.request_stop()
coord.join(threads)
報錯代碼:
- 使用datasets
參考最最上面那個正常的代碼
def parse_function(example_proto):
# 只接受一個輸入:example_proto,也就是序列化後的樣本tf_serialized
# 解析規則
# 也可以把形狀信息存入example_proto裏,然後在下面用
dics = {
'data': tf.FixedLenFeature(shape=[100], dtype=tf.float32, default_value=0.0),
'label': tf.FixedLenFeature(shape=[2], dtype=tf.float32)
}
# 解析樣本
parsed_example = tf.parse_single_example(example_proto,dics)
# parsed_example['data'] = tf.reshape(parsed_example['data'], (1,100))
#
# # 轉變tensor形狀
# parsed_example['label'] = tf.reshape(parsed_example['label'], (1,2))
# 轉變特徵
return parsed_example
# 使用TF_record導入數據
filenames = "test.tfrecord"
dataset = tf.data.TFRecordDataset(filenames)
'''由於從tfrecord文件中導入的樣本是剛纔寫入的tf_serialized序列化樣本,
所以我們需要對每一個樣本進行解析。這裏就用dataset.map(parse_function)來對dataset裏的每個樣本進行相同的解析操作。'''
new_dataset = dataset.map(parse_function)
# 創建迭代器
iterator = new_dataset.make_one_shot_iterator()
# 獲取樣本
next_element = iterator.get_next()
sess = tf.Session()
sess.run(next_element['data'])
END
- 這個報錯挖個坑,下篇填。