tensorflow筆記 string_input_producer, slice_input_producer

tensorflow將讀取數據分爲了兩個步驟,先讀入文件名隊列,再讀入內存隊列進行運算。爲了減少GPU的等待時間,提高計算速度,tensorflow使用兩個線程來分別處理這兩個步驟。tf有三個函數string_input_producer, slice_input_producer, input_producer用於建立文件名隊列。

函數參數如下所示,除了tensor list是必須外,其餘都可以省略。input_producer輸入爲一個tensor,每行是一個數據,slice_input_producer輸入爲一個tensor列表,string_input_producer輸入爲一個string類型的tensor列表。

tf.train.input_producer(
    input_tensor,
    element_shape=None,
    num_epochs=None, #文件名隊列中數據重複n次,保證每個文件都會被訪問n次
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    summary_name=None,
    name=None,
    cancel_op=None
)

tf.train.string_input_producer(
    string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None
)

tf.train.slice_input_producer(
    tensor_list,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None
)

文件名隊列並不會直接讀入,需要用tf.train.start_queue_runners()啓動

threads = tf.train.start_queue_runners(sess,coord)

官方建議:
THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs). If shuffle=False, omit the .shuffle(...)

現在一般使用更簡潔的data api來讀取數據。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章