【深度學習 走進tensorflow2.0】TensorFlow 2.0 常用模塊:tf.data 數據流加速

無意中發現了一個巨牛的人工智能教程,忍不住分享一下給大家。教程不僅是零基礎,通俗易懂,而且非常風趣幽默,像看小說一樣!覺得太牛了,所以分享給大家。點這裏可以跳轉到教程。人工智能教程

上一篇文章TensorFlow 2.0 常用模塊tf.data
介紹了基本的使用讀取數據方法,下面我們介紹如何通過 prefetch 和 map 的並行化參數,讓 tf.data 的性能得到明顯提升。

當訓練模型時,我們希望充分利用計算資源,減少 CPU/GPU 的空載時間。然而有時,數據集的準備處理非常耗時,使得我們在每進行一次訓練前都需要花費大量的時間準備待訓練的數據,而此時 GPU 只能空載而等待數據,造成了計算資源的浪費,如下圖:
在這裏插入圖片描述
此時, tf.data 的數據集對象爲我們提供了 Dataset.prefetch() 方法,使得我們可以讓數據集對象 Dataset 在訓練時預取出若干個元素,使得在 GPU 訓練的同時 CPU 可以準備數據,從而提升訓練流程的效率,如下圖所示:

在這裏插入圖片描述

Dataset.prefetch() 的使用方法和前節的 Dataset.batch() 、 Dataset.shuffle() 等非常類似。繼續以前節的 MNIST 數據集爲例,若希望開啓預加載數據,使用如下代碼即可:

mnist_dataset = mnist_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

此處參數 buffer_size 既可手工設置,也可設置爲 tf.data.experimental.AUTOTUNE 從而由 TensorFlow 自動選擇合適的數值。

與此類似,Dataset.map() 也可以利用多 GPU 資源,並行化地對數據項進行變換,從而提高效率。以前節的 MNIST 數據集爲例,假設用於訓練的計算機具有 2 核的 CPU,我們希望充分利用多核心的優勢對數據進行並行化變換(比如 前節 的旋轉 90 度函數 rot90 ),可以使用以下代碼:

mnist_dataset = mnist_dataset.map(map_func=rot90, num_parallel_calls=2)

在這裏插入圖片描述

當然,這裏同樣可以將 num_parallel_calls 設置爲 tf.data.experimental.AUTOTUNE 以讓 TensorFlow 自動選擇合適的數值。

發佈了653 篇原創文章 · 獲贊 795 · 訪問量 188萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章