參考pytorch官網教程(https://pytorch.org/tutorials/beginner/data_loading_tutorial.html),自己針對要用的數據寫了數據讀取代碼。
由於我的圖像帶有時間順序,後續要用RNN訓練,本來的數據文件存儲爲"標籤/樣本標號/按順序存放的n張圖像"。
因此,爲了方便寫Dataset類,我先把每張圖片按順序把路徑和對應的標籤存到了csv文件中,再在dataset類裏面讀取。
下面是這個過程中遇到的一些函數使用問題記錄。
路徑拼接:
os.path.join(base_path,i):該函數使用時,輸入爲字符串,無需中間加'/'。
創建多維列表:
img_all=[[]for m in range(2)]:該語句在列表中預留兩行空間,通過索引訪問,後續用append會從第3行開始加入
csv文件的讀寫存儲:
https://blog.csdn.net/guoziqing506/article/details/52014506
使用pandas庫讀csv文件:
import pandas as pd
class myDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
self.paths = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
image_name = self.paths.iloc[idx, 0]
label = self.paths.iloc[idx, 1]
.iloc[]的索引方式,和matlab裏的訪問很像,比如:.iloc[:3]表示訪問pandas數據裏面第0,1,2行的數據。注意這裏的0,1,2不是索引號,而是按照順序數的0,1,2。(詳情參考了這篇文:https://zhuanlan.zhihu.com/p/35012884)
文件名稱排序問題:
自然取出來的文件夾名稱list,排序方式是按照字符大小排的,這對要保持圖片順序的情況是不行的,因此用到了sort的關鍵字處理。
list3 = os.listdir(im_dir)
list3.sort(key=lambda x: int(x[:-4]))
list3=os.listdir(im_dir)得到list3結果爲:
<class 'list'>: ['10.png', '11.png', '12.png', '13.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png', '9.png']
用sort加關鍵字處理後得到
<class 'list'>: ['3.png', '4.png', '5.png', '6.png', '7.png', '8.png', '9.png', '10.png', '11.png', '12.png', '13.png']
這裏 lambda代表根據數字排列,x代表list的各行,int(x[:-4])表示將x的倒數第四個元素以前的轉化爲數字參與排序。參考https://docs.python.org/zh-cn/3/howto/sorting.html
灰度圖像顯示:
如果直接把單通道圖像用語句plt.imshow()顯示,得到的是一個僞彩色圖像,實現灰度圖像,可以用
plt.imshow(image,cmap='gray')
還有其他方法,可看這篇文章:https://blog.csdn.net/blythe0107/article/details/72818518
-------------------------------------------------------------------------更新線-----------------------------------------------
前面說”由於我的圖像帶有時間順序,後續要用RNN”,所以我希望我的dataloader讀入的一個樣本是一個圖像序列,同時也要便於前期CNN對每張圖片提取特徵。
一開始我只是返回了序列圖像的3維list,天真地以爲就OK了,然而發現這種做法到transform那裏是不可行的,包括ToTensor環節Normalize環節(爬了好久才爬出來),正確的處理方式應該是:
- 在dataload.py中,單張圖像imread()讀入後,先用np.array()轉換成ndarray的數據形式。【transforms.ToTensor()只允許輸入PILImage和ndarray格式的圖像,且會將n*n的矩陣轉化成1*n*n的tensor】
- 然後直接上transform條件判斷部分,對單張圖像轉化。【如果是單通道圖像,那麼正則化部分應該是transforms.Normalize((0.5, ), (0.5, ))的形式】
- 再把處理好的圖像逐個放入list裏面,最後返回圖像序列和對應的標籤。
seq=[]
for i in xxx:
...
img = np.array(image)
if self.transform:
img = self.transform(img)
seq.append(img)
return seq, label