PyTorch Learning (1): Loading Data

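Loading Data in PyTorch

PyTorch offers a broad set of building blocks for neural networks and a simple, intuitive, stable API. It also includes packages for preparing and loading common datasets for your models.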

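Introduction

At the core of data loading in PyTorch is the torch.utils.data.DataLoader class, which represents a Python iterable over a dataset. The PyTorch libraries also provide high-quality built-in datasets that can be used as a torch.utils.data.Dataset. These datasets are available from torchvision, torchaudio, and torchtext.

Here we use the YESNO dataset from torchaudio.datasets and demonstrate how to load data efficiently from a PyTorch Dataset into a PyTorch DataLoader.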
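To make the Dataset and DataLoader relationship concrete before turning to YESNO, here is a minimal sketch using a made-up in-memory dataset. ToySineDataset and its contents are hypothetical and are not part of torchaudio; the only assumption is that a map-style dataset defines `__len__` and `__getitem__`, which is what DataLoader needs in order to build batches.

```python
import math

import torch
from torch.utils.data import Dataset, DataLoader


class ToySineDataset(Dataset):
    """Hypothetical map-style dataset: item i is a sine wave and its frequency."""

    def __init__(self, num_items=10, num_samples=100):
        self.num_items = num_items
        self.num_samples = num_samples

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        freq = float(idx + 1)
        t = torch.linspace(0.0, 1.0, self.num_samples)
        waveform = torch.sin(2 * math.pi * freq * t)
        return waveform, freq


dataset = ToySineDataset()
# All items have the same length, so the default batching can stack them.
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batch_shapes = [tuple(waveforms.shape) for waveforms, freqs in loader]
print(batch_shapes)
```

With 10 items and a batch size of 4, the loader yields two full batches and one final batch of 2 items.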

Setup

pip install torchaudio

Steps

1. Import the libraries needed to load our data

2. Access the data in the dataset

3. Load the data

4. Iterate over the data

5. Visualize the data (optional)

1. Import necessary libraries for loading our data

import torch
import torchaudio

2. Access the data in the dataset

torchaudio.datasets.YESNO(
  root,
  url='http://www.openslr.org/resources/1/waves_yesno.tar.gz',
  folder_in_archive='waves_yesno',
  download=False,
  transform=None,
  target_transform=None)

# * download: If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
# * transform: Using transforms on your data allows you to take it from its source state and transform it into data that’s joined together, de-normalized, and ready for training. Each library in PyTorch supports a growing list of transformations.
# * target_transform: A function/transform that takes in the target and transforms it.
#
# Let’s access our YESNO data:

# A data point in Yesno is a tuple (waveform, sample_rate, labels) where labels
# is a list of integers with 1 for yes and 0 for no.
yesno_data = torchaudio.datasets.YESNO('./', download=True)

# Pick data point number 3 to see an example of the yesno_data:
n = 3
waveform, sample_rate, labels = yesno_data[n]
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(waveform, sample_rate, labels))
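The printed labels are a list of integers, which is harder to read than the spoken words they stand for. A small helper can map them back to text. This is only a sketch: decode_labels is a hypothetical name, and the label list below is invented for illustration rather than taken from the dataset.

```python
def decode_labels(labels):
    """Map YESNO-style integer labels to words (1 -> 'yes', 0 -> 'no')."""
    words = {0: "no", 1: "yes"}
    # int() also accepts single-element tensors, as produced by a DataLoader.
    return [words[int(label)] for label in labels]


example_labels = [0, 0, 1, 0, 1, 0, 1, 1]  # made-up label list
decoded = decode_labels(example_labels)
print(decoded)
```

Applied to a real item, the call would be decode_labels(labels) with the labels unpacked from yesno_data[n].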

3. Loading the data

data_loader = torch.utils.data.DataLoader(yesno_data,
                                          batch_size=1,
                                          shuffle=True)
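The loader above uses batch_size=1. With a larger batch size, the default batching tries to stack the waveforms into one tensor, which fails when recordings differ in length. One common workaround is a custom collate_fn that pads each batch to its longest waveform. The sketch below assumes items shaped like the (waveform, sample_rate, labels) tuples described earlier, with waveform of shape (channels, time); pad_collate is a hypothetical name, and the stand-in data is invented so that the example runs without downloading anything.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch):
    """Zero-pad the waveforms in a batch to a common length."""
    waveforms, sample_rates, labels = zip(*batch)
    # Keep the true lengths so padding can be masked out later.
    lengths = torch.tensor([w.shape[-1] for w in waveforms])
    # pad_sequence pads along the first dimension, so move time to the front.
    padded = pad_sequence([w.t() for w in waveforms], batch_first=True)
    padded = padded.transpose(1, 2)  # back to (batch, channels, time)
    return padded, lengths, torch.tensor(sample_rates), torch.tensor(labels)


# Stand-in data with a YESNO-like structure; lengths and values are invented.
fake_items = [
    (torch.randn(1, 300), 8000, [0, 0, 1, 0, 1, 0, 1, 1]),
    (torch.randn(1, 500), 8000, [1, 1, 0, 0, 1, 0, 0, 1]),
    (torch.randn(1, 400), 8000, [0, 1, 0, 1, 0, 1, 0, 1]),
]

# A plain list works here because it supports len() and integer indexing.
padded_loader = DataLoader(fake_items, batch_size=3, shuffle=False,
                           collate_fn=pad_collate)
padded, lengths, sample_rates, labels = next(iter(padded_loader))
print(padded.shape, lengths.tolist(), labels.shape)
```

Returning the original lengths alongside the padded tensor is a design choice: it lets downstream code ignore the padded region instead of treating the zeros as audio.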

4. Iterate over the data

for data in data_loader:
  print("Data: ", data)
  print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2]))
  break

5. [Optional] Visualize the data

import matplotlib.pyplot as plt

print(data[0][0].numpy())

plt.figure()
plt.plot(waveform.t().numpy())
plt.show()  # needed when running as a script; notebooks display the figure automatically
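The plot above uses the sample index as the x-axis. Dividing the index by the sample rate gives an axis in seconds, which is often easier to read. The following sketch uses a synthetic signal and an assumed sample rate of 8000 Hz purely for illustration; with real data, waveform and sample_rate would come from the dataset item.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch

sample_rate = 8000  # assumed value for illustration
# Fake signal with shape (channels, time), standing in for a real recording.
waveform = torch.sin(torch.linspace(0.0, 50.0, 16000)).unsqueeze(0)

# Convert sample indices to seconds.
num_samples = waveform.shape[-1]
time_axis = torch.arange(num_samples) / sample_rate

fig, ax = plt.subplots()
ax.plot(time_axis.numpy(), waveform[0].numpy())
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude")
ax.set_title("Waveform")
```

In an interactive session, drop the matplotlib.use("Agg") line and call plt.show(), or call fig.savefig with a file name to write the figure to disk.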

 
