使用PyTorch進行批量預測
%matplotlib inline
此示例遵循Torch的遷移學習教程。我們會
-
對特定任務(螞蟻與蜜蜂)微調預訓練的卷積神經網絡。
-
使用Dask羣集對該模型進行批量預測。
主要重點是使用Dask羣集進行批次預測。
請注意,examples.dask.org Binder上的基本環境不包括PyTorch或torchvision。要運行此示例,您需要運行
!conda install -y pytorch-cpu torchvision
這將需要一些時間才能運行。
下載資料
PyTorch文檔包含一小組數據。我們將在本地下載並解壓縮。
import urllib.request import zipfile
filename, _ = urllib.request.urlretrieve("https://download.pytorch.org/tutorial/hymenoptera_data.zip", "data.zip") zipfile.ZipFile(filename).extractall()
目錄看起來像
hymenoptera_data/ train/ ants/ 0013035.jpg ... 1030023514_aad5c608f9.jpg bees/ 1092977343_cb42b38d62.jpg ... 2486729079_62df0920be.jpg train/ ants/ 0013025.jpg ... 1030023514_aad5c606d9.jpg bees/ 1092977343_cb42b38e62.jpg ... 2486729079_62df0921be.jpg
在學習完本教程之後,我們將對模型進行微調。
import torchvision from tutorial_helper import (imshow, train_model, visualize_model, dataloaders, class_names, finetune_model)
微調模型
我們的基本模型是resnet18。它可以預測1,000種類別,而我們只預測2種(螞蟻或蜜蜂)。爲了使該模型在examples.dask.org上快速培訓,我們僅使用幾個紀元。
import dask
%%time model = finetune_model()
時代0/1 ---------- 火車損失:0.6196累積:0.6844 val損失:0.2042 Acc:0.9281 時代1/1 ---------- 火車損失:0.4517累積:0.7787 val損失:0.1458 Acc:0.9477 訓練在0m 4s內完成 最佳增值值:0.947712 CPU時間:用戶3.92 s,系統:2.03 s,總計:5.95 s 掛牆時間:6.33 s
在一些隨機圖像上,事情似乎還可以:
visualize_model(model)
使用Dask進行批量預測
現在是主要主題:在Dask集羣上使用預訓練的模型進行批量預測。有兩個主要的複雜性,它們都涉及最小化移動的數據量:
-
將數據加載到工作程序上。。我們將用於
dask.delayed
將數據加載到工作程序上,而不是將數據加載到客戶端上並將其發送給工作程序。 -
PyTorch神經網絡很大。我們不希望它們出現在Dask任務圖中,我們只希望將它們移動一次。
from distributed import Client client = Client(n_workers=2, threads_per_worker=2) client
客戶
|
簇
|
將數據加載到工作人員上
首先,我們將定義幾個助手來加載數據並對神經網絡進行預處理。我們將dask.delayed
在這裏使用它,以便執行是懶惰的,並且在集羣上進行。有關使用的更多信息,請參見延遲的示例dask.delayed
。
import glob import toolz import dask import dask.array as da import torch from torchvision import transforms from PIL import Image @dask.delayed def load(path, fs=__builtins__): with fs.open(path, 'rb') as f: img = Image.open(f).convert("RGB") return img @dask.delayed def transform(img): trn = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) return trn(img)
objs = [load(x) for x in glob.glob("hymenoptera_data/val/*/*.jpg")]
要從雲存儲(例如Amazon S3)加載數據,您將使用
import s3fs fs = s3fs.S3FileSystem(...) objs = [load(x, fs=fs) for x in fs.glob(...)]
PyTorch模型期望特定形狀的張量,因此讓我們對其進行轉換。
tensors = [transform(x) for x in objs]
而且該模型需要成批的輸入,因此讓我們將其堆疊在一起。
batches = [dask.delayed(torch.stack)(batch) for batch in toolz.partition_all(10, tensors)] batches[:5]
[已延遲('stack-da59d324-464a-4dce-adfa-0dc99dc53299', 延遲的('stack-939f881b-58ba-4bb5-b4eb-1df6ccfa850f'), 延遲('stack-e3809d5d-84f2-4279-a1a6-71131f4d2c53'), 延遲的('stack-a172c545-7cdd-467f-a2bc-e5c5ae611d50'), 延遲('stack-8698c88b-6e05-442d-8346-8af67d0992ae')]
最後,我們將編寫一個小的predict
幫助程序來預測輸出類(0或1)。
@dask.delayed def predict(batch, model): with torch.no_grad(): out = model(batch) _, predicted = torch.max(out, 1) predicted = predicted.numpy() return predicted
移動模型
PyTorch神經網絡很大,因此我們不想在任務圖中重複很多次(每批一次)。
import pickle dask.utils.format_bytes(len(pickle.dumps(model)))
'44 .80 MB'
相反,我們還將模型本身包裝在中dask.delayed
。這意味着該模型在Dask圖中僅顯示一次。
此外,由於我們在上面進行了微調(如果可以在GPU上運行,則可以在GPU上運行),因此我們應該將模型移回CPU。
dmodel = dask.delayed(model.cpu()) # ensuring model is on the CPU
現在,我們將使用(延遲)predict
方法來獲得我們的預測。
predictions = [predict(batch, dmodel) for batch in batches] dask.visualize(predictions[:2])
可視化有些混亂,但是大型的PyTorch模型是這兩個predict
任務的始祖。
現在,我們可以使用Dask集羣執行所有工作了。由於我們正在使用的數據集很小,因此僅dask.compute
將結果帶回本地客戶端是安全的。對於較大的數據集,您將要寫入磁盤或雲存儲,或者繼續處理集羣上的預測。
predictions = dask.compute(*predictions) predictions
(數組([1,1,1,0,1,0,1,1,1,1]), 數組([1,1,1,1,1,1,1,1,1,1,1]), 數組([1,1,1,1,1,1,1,1,1,1,1]), 數組([1,1,1,1,1,1,1,1,1,1,1]), 數組([1,1,1,1,1,1,1,1,1,1,1]), 數組([1,1,1,1,1,1,1,0,1,0]), 數組([1,1,1,1,1,1,1,1,1,1,1]), 數組([1,1,1,1,1,1,1,1,1,1,1]), array([1,1,1,0,0,0,0,0,0,0]), array([0,0,0,1,0,0,0,0,0,0]), 數組([1,0,0,0,0,0,0,0,0,0]), array([0,0,0,1,0,0,0,0,0,0]), array([0,0,0,1,0,0,0,0,0,0]), 數組([0,0,0,0,0,0,0,0,0,0]), 數組([0,0,0,0,0,0,0,0,0,0]), 數組([0,0,0]))
概要
本示例說明了如何使用PyTorch和Dask對一組圖像進行批量預測。我們非常小心地將數據遠程加載到羣集上,並且只對大型神經網絡進行一次序列化。