三、生成heatmap(二)基於patch畫熱力圖

太長不看版

1、將PIL.Image轉換成批訓練的DataLoader

  • 爲什麼一批一批進去處理

2、載入網絡( torch.load('Resnet.pkl') ),並將數據放入網絡,通過 outputs = model(images) 得到預測值,放在對應的對象中

3、按對象的數字順序排序,生成熱力圖保存

 

耐心看完版

事先準備:

test_path = r'C:\Users\BME419\Desktop\resnet\slide\patch'
background_path = r'E:\WSI\CAMELYON16\Processed\patch-based-classification\raw-data\test\tumor091heatmaps  patches\none'

pre_savename = r'C:\Users\BME419\Desktop\resnet\slide\heatmap'
savename = os.path.join(pre_savename, 'heatmap')   #保存名字爲'heatmap'

batch_size = 64
classes = ['negative','positive']
global positive_prob
positive_prob = []                      #positive_prob類型爲list
def reload_net(model_name):    #可選擇四種網絡
    if  model_name == "VGG":
        trainednet = torch.load('VGGnet.pkl')
    elif model_name == "Google":
        trainednet = torch.load('Google.pkl')
    elif model_name == "Res":
        trainednet = torch.load('Resnet.pkl')
    elif model_name == "Alex":
        trainednet = torch.load('Alexnet.pkl')
    return trainednet

 主函數:

test_data("Res", 224)     #內涵調用heatmap_gen()

具體函數 ↓ ↓ 

def test_data(model_name, input_size):       # input_size = 224
    # 先轉換成 torch 能識別的 Dataset
    testset = torchvision.datasets.ImageFolder(test_path,
                            transform = transforms.Compose([
                            transforms.Resize((input_size, input_size)),
                            # 將圖片縮放到指定大小(h,w)或者保持長寬比並縮放最短的邊到int大小
                            transforms.ToTensor(),
                            ]))      
    # 把 dataset 放入 DataLoader                                
    testloader = torch.utils.data.DataLoader(testset, batch_size = batch_size,            
                                             shuffle = False, num_workers = 0)    # shuffle = False(不打亂),按順序取patch,否則 = True,隨機取
    model = reload_net(model_name)   # load模型,函數具體定義見“事先準備”   
    model.eval()         # 把BN和Dropout固定住,不會取平均,而是用訓練好的值
loader
產生的Dataset和DataLoader
    #將testset.imgs從tuple變爲list(用於append網絡產生的概率(outputs))
    for j in range(len(testset.imgs)):      
        testset.imgs[j] = list(testset.imgs[j])
    #利用訓練好的網絡預測patch概率
    for i, data in enumerate(testloader, 0):
        images, labels = data
        print(labels)
        images = Variable(images, requires_grad=True)   # 轉換數據格式用Variable
        if torch.cuda.is_available():
            images = images.cuda()       # 轉換數據格式用Variable
            model.cuda()
        with torch.no_grad():
            outputs = model(images)
            outputs = outputs.cpu().numpy()        # 將outputs由GPU轉化爲numpy
            positive_prob.extend(outputs[:, 1])    # positive_prob(list):各個patch腫瘤positive的概率
            if i <= (len(testset.imgs) / 64-1):    # 將testset.imgs與各自概率一一對應
                for j in range(64):
                    testset.imgs[j + 64 * i].append(outputs[j, 1])
            else:
                for j in range(len(testset.imgs) - 64 * i):
                    testset.imgs[j + 64 * i].append(outputs[j, 1])

    probmin = np.min(positive_prob)  # 用於背景顯示
    heatmap_gen(positive_prob, testset,input_size,probmin) #調用heatmap_gen(),具體見↓
outputs
0表示'negative', 1表示'positive'

def heatmap_gen(positive_prob,testset,input_size,probmin):
    fig = plt.figure(figsize=(172, 153))   #figsize以英寸爲單位 width=172,height=153
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)

    # 讀背景patch
    testset2 = torchvision.datasets.ImageFolder(background_path,
                            transform=transforms.Compose([
                            transforms.Resize((input_size, input_size)),
                          # 將圖片縮放到指定大小(h,w)或者保持長寬比並縮放最短的邊到int大小
                            transforms.ToTensor(),   # 把一個取值範圍是[0,255]的PIL.Image 轉換成 Tensor
                            ]))
    for n in range(len(testset2.imgs)):   # 將背景patch概率減10,並append到testset
        testset2.imgs[n] = list(testset2.imgs[n])
        testset2.imgs[n].append(probmin-10)     
        testset.imgs.append(testset2.imgs[n])   
 
    #使路徑減成數字,排序
    for n in range(len(testset.imgs)):
        testset.imgs[n][0] = os.path.basename(testset.imgs[n][0])
    testset.imgs.sort(key=lambda x: int(x[0][:-6]))

    positive_prob = [None] * len(testset.imgs)
    positive_prob = np.array(positive_prob)

    # 按照片名從小到大生成positive_prob
    for n in range(len(testset.imgs)):
        positive_prob[n] = testset.imgs[n][2]














    probmin = np.min(positive_prob)
    probmax = np.max(positive_prob)
    heatmap = ((positive_prob - probmin) / (probmax - probmin + 0.000001)) * 255  # float在[0,1]之間,轉換成0-255
    heatmap = heatmap.astype(np.uint8)  # 轉成unit8
    heatmap = heatmap.reshape(153, 172)      # y爲行,x爲列
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 生成heat map
    heatmap = heatmap[:, :, ::-1]  # 注意cv2(BGR)和matplotlib(RGB)通道是相反的
    plt.imshow(heatmap)
    fig.savefig(savename, dpi=10)  #dpi指每英寸有多少個像素,save路徑見最上面

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章