在PAI上, 使用TensorFlow讀取OSS文件
作者: 萬千鈞
轉載的出處
本文適合有一定TensorFlow基礎, 且準備使用PAI的同學閱讀
目錄
1. 如何PAI上讀取數據
2. 如何減少讀取的費用開支
3. 使用OSS需要注意的問題
1. 在PAI上讀取數據
Python不支持讀取oss的數據, 故所有調用 python Open(), os.path.exist() 等文件, 文件夾操作的函數的代碼都無法執行.
如Scipy.misc.imread(), numpy.load() 等
那如何在PAI讀取數據呢, 通常我們採用兩種辦法.
如果只是簡單的讀取一張圖片, 或者一個文本等, 可以使用tf.gfile下的函數, 具體成員函數如下
tf.gfile.Copy(oldpath, newpath, overwrite=False) # 拷貝文件
tf.gfile.DeleteRecursively(dirname) # 遞歸刪除目錄下所有文件
tf.gfile.Exists(filename) # 文件是否存在
tf.gfile.FastGFile(name, mode='r') # 無阻塞讀寫文件
tf.gfile.GFile(name, mode='r') # 讀寫文件
tf.gfile.Glob(filename) # 列出文件夾下所有文件, 支持pattern
tf.gfile.IsDirectory(dirname) # 返回dirname是否爲一個目錄
tf.gfile.ListDirectory(dirname) # 列出dirname下所有文件
tf.gfile.MakeDirs(dirname) # 在dirname下創建一個文件夾, 如果父目錄不存在, 會自動創建父目錄. 如果
文件夾已經存在, 且文件夾可寫, 會返回成功
tf.gfile.MkDir(dirname) # 在dirname處創建一個文件夾
tf.gfile.Remove(filename) # 刪除filename
tf.gfile.Rename(oldname, newname, overwrite=False) # 重命名
tf.gfile.Stat(dirname) # 返回目錄的統計數據
tf.gfile.Walk(top, inOrder=True) # 返回目錄的文件樹
具體的文檔可以參照這裏(可能需要翻牆)
如果是一批一批的讀取文件, 一般會採用tf.WholeFileReader() 和 tf.train.batch() / tf.train.shuffer_batch()
接下來會重點介紹常用的 tf.gfile.Glob, tf.gfile.FastGFile, tf.WholeFileReader() 和 tf.train.shuffer_batch()
讀取文件一般有兩步
1. 獲取文件列表
2. 讀取文件
如果是批量讀取, 還有第三步
3. 創建batch
從代碼上手: 在使用PAI的時候, 通常需要在右側設置讀取目錄, 代碼文件等參數, 這些參數都會通過--XXX的形式傳入
tf.flags可以提供了這個功能
import tensorflow as tf
import os
FLAGS = tf.flags.FLAGS
# 前面的buckets, checkpointDir都是固定的, 不建議更改
tf.flags.DEFINE_string('buckets', 'oss://XXX', '訓練圖片所在文件夾')
tf.flags.DEFINE_string('batch_size', '15', 'batch大小')
# 獲取文件列表
files = tf.gfile.Glob(os.path.join(FLAGS.buckets,'*.jpg')) # 如我想列出buckets下所有jpg文件路徑
接下來就分兩種情況了
1. (小規模讀取時建議) tf.gfile.FastGfile()
for path in files:
file_content = tf.gfile.FastGFile(path, 'rb').read() # 一定記得使用rb讀取, 不然很多情況下都會報錯
image = tf.image.decode_jpeg(file_content, channels=3) # 本教程以JPG圖片爲例
2. (大批量讀取時建議) tf.WholeFileReader()
reader = tf.WholeFileReader() # 實例化一個reader
fileQueue = tf.train.string_input_producer(files) # 創建一個供reader讀取的隊列
file_name, file_content = reader.read(fileQueue) # 使reader從隊列中讀取一個文件
image = tf.image.decode_jpeg(file_content, channels=3) # 講讀取結果解碼爲圖片
label = XXX # 這裏省略處理label的過程
batch = tf.train.shuffle_batch([label, image], batch_size=FLAGS.batch_size, num_threads=4,
capacity=1000 + 3 * FLAGS.batch_size, min_after_dequeue=1000)
sess = tf.Session() # 創建Session
tf.train.start_queue_runners(sess=sess) # 重要!!! 這個函數是啓動隊列, 不加這句線程會一直阻塞
labels, images = sess.run(batch) # 獲取結果
解釋下其中重要的部分tf.train.string_input_producer, 這個是把files轉換成一個隊列, 並且需要 tf.train.start_queue_runners 來啓動隊列
tf.train.shuffle_batch 參數解釋
batch_size 批大小, 每次運行這個batch, 返回多少個數據
num_threads 運行線程數, 在PAI上4個就好
capacity 隨機取文件範圍, 比如你的數據集有10000個數據, 你想從5000個數據中隨機取, capacity就設置成5000.
min_after_dequeue 維持隊列的最小長度, 這裏只要注意不要大於capacity即可
2.費用開支
這裏只討論讀取文件所需要的費用開支
原則上來說, PAI不跨區域讀取OSS是不收費的, 但是OSS的API是收費的. PAI在使用 tf.gile.Glob 的時候 會產生GET請求, 在寫入 tensorboard 的時候, 也會產生PUT請求. 這兩種請求都是按次收費的, 具體價格如下
標準型單價: 0.01元/萬次
低頻訪問型單價: 0.1元/萬次
歸檔型單價: 0.1元/萬次
當數據集有幾十萬圖片, 通過 tf.gile.Glob 一次就需要幾毛錢. 所以減少費用開支的方法就是減少GET請求次數
這裏給出幾種解決思路
1. 最好的解決思路, 把所有會使用到的數據, 一併上傳傳到OSS, 然後使用tensorflow拷貝到運行時目錄, 最後通過tensorflow讀取, 這樣是最節省開支的.
2. 通過tfrecords, 在本地, 提前把幾十上百張圖片通過tfrecords存下來, 這樣讀取的時候可以減少GET請求
3. 把訓練使用的圖片隨着代碼的壓縮包一起傳上去, 不走OSS讀取
三種方法都可以顯著的減少開支.
3.使用中需要注意的
事實上, 每次讀取傳過來的地址就是 oss://你的buckets名字/XXX, 本以爲不需要在PAI界面上 設置, 直接讀取這個目錄就好, 事實上並不如此.
PAI沒有權限讀取不在數據源目錄和輸出目錄下的文件, 所以在使用路徑前, 確保他們已經在控制檯右側設置過.
另外如果需要寫入文件到OSS, 可以使用 tf.gfile.fastGfile('OSS路徑', 'wb').write('內容')
OSS路徑推薦使用
FLAGS.checkpointDir
FLAGS.summaryDIr
這樣的形式傳入, 經過測試好像也只有這兩個目錄下有寫權限, FLAGS.buckets有讀權限