【深度学习 走进tensorflow2.0】TensorFlow 2.0 常用模块:tf.data 数据流加速

无意中发现了一个巨牛的人工智能教程,忍不住分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。点这里可以跳转到教程。人工智能教程

上一篇文章TensorFlow 2.0 常用模块tf.data
介绍了基本的使用读取数据方法,下面我们介绍如何通过 prefetch 和 map 的并行化参数,让 tf.data 的性能得到明显提升。

当训练模型时,我们希望充分利用计算资源,减少 CPU/GPU 的空载时间。然而有时,数据集的准备处理非常耗时,使得我们在每进行一次训练前都需要花费大量的时间准备待训练的数据,而此时 GPU 只能空载而等待数据,造成了计算资源的浪费,如下图:
在这里插入图片描述
此时, tf.data 的数据集对象为我们提供了 Dataset.prefetch() 方法,使得我们可以让数据集对象 Dataset 在训练时预取出若干个元素,使得在 GPU 训练的同时 CPU 可以准备数据,从而提升训练流程的效率,如下图所示:

在这里插入图片描述

Dataset.prefetch() 的使用方法和前节的 Dataset.batch() 、 Dataset.shuffle() 等非常类似。继续以前节的 MNIST 数据集为例,若希望开启预加载数据,使用如下代码即可:

mnist_dataset = mnist_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

此处参数 buffer_size 既可手工设置,也可设置为 tf.data.experimental.AUTOTUNE 从而由 TensorFlow 自动选择合适的数值。

与此类似,Dataset.map() 也可以利用多 GPU 资源,并行化地对数据项进行变换,从而提高效率。以前节的 MNIST 数据集为例,假设用于训练的计算机具有 2 核的 CPU,我们希望充分利用多核心的优势对数据进行并行化变换(比如 前节 的旋转 90 度函数 rot90 ),可以使用以下代码:

mnist_dataset = mnist_dataset.map(map_func=rot90, num_parallel_calls=2)

在这里插入图片描述

当然,这里同样可以将 num_parallel_calls 设置为 tf.data.experimental.AUTOTUNE 以让 TensorFlow 自动选择合适的数值。

发布了653 篇原创文章 · 获赞 795 · 访问量 188万+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章