爲了驗證小樣本的思想,先從數據量的更改出發,調整數據集的大小,驗證數據量改變對性能改變到底有怎樣的影響,這篇博客記錄下調整數據量的方法。
PyTorch中單獨提供了一個sampler
模塊,用來對數據進行採樣。RandomSampler
,當dataloader的shuffle
參數爲True時,系統會自動調用這個採樣器,實現打亂數據。默認的是採用SequentialSampler
,它會按順序一個一個進行採樣。
WeightedRandomSampler
,它會根據每個樣本的權重選取數據,在樣本比例不均衡的問題中,可用它來進行重採樣。
WeightedRandomSampler
需提供兩個參數:每個樣本的權重weights
、樣本總數num_samples
,以及一個可選參數replacement
。權重越大的樣本被選中的概率越大,待選取的樣本數目一般小於全部的樣本數目。replacement
用於指定是否可以重複選取某一個樣本,默認爲True,即允許在一個epoch中重複採樣某一個數據。如果設爲False,則當某一類的樣本被全部選取完,但其樣本數目仍未達到num_samples時,sampler將不會再從該類中選擇數據,此時可能導致weights
參數失效。下面舉例說明。
```python
dataset = DogCat('data/dogcat/', transforms=transform)
# 狗的圖片被取出的概率是貓的概率的兩倍
# 兩類圖片被取出的概率與weights的絕對大小無關,只和比值有關
weights = [2 if label == 1 else 1 for data, label in dataset]
weights
from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())