ROC曲線繪製

1. 引入相關包

使用matplotlib包作爲繪圖庫,故要引入相關的包

爲了使畫出的圖更爲符合期刊要求,這裏引入SciencePlots。

它是一個基於Matplotlib的補充包,裏面主要包含了一些以.mplstyle爲後綴的圖表樣式的配置文件。這樣,你畫圖的時候只需要通過調用這些配置文件,就能畫出比較好看的數據可視化圖表,也避免了你每次畫圖時都要從頭開始手動配置圖表的格式。

pip install SciencePlots

還要引入numpy對數據進行處理

要計算AUC,還應該引入sklearn中計算相關值的包。

import matplotlib.pyplot as plt
plt.style.use(['science'])
import numpy as np
from sklearn.metrics import roc_curve, auc

然後導入相關數據

# 真實值
y = np.load('npy\\y_test.npy')
# 各種預測值 並非0或1 而是概率
yp_ann = np.load('npy\\ann.npy')
yp_lstm = np.load('npy\\lstm.npy')
yp_lr = np.load('npy\\lr.npy')
yp_rf = np.load('npy\\rf.npy')
yp_xgb = np.load('npy\\xgb.npy')
yp_lgbm = np.load('npy\\lgbm.npy')
yp_catb = np.load('npy\\catb.npy')

2. 計算AUC值

AUC,即AUROC,指的是由TPRFPR圍成的ROC曲線下的面積

將分類任務的實際值和預測值作爲參數輸入給roc_curve()方法可以得到FPR、TPR和對應的閾值。

auc()方法可以計算曲線下的面積,將FPR和TPR作爲參數輸入,即可獲得AUC值。

fpr_1, tpr_1, threshold_1 = roc_curve(y, yp_ann)  # 計算FPR和TPR
auc_1 = auc(fpr_1, tpr_1)  # 計算AUC值

fpr_2, tpr_2, threshold_2 = roc_curve(y, yp_lstm)
auc_2 = auc(fpr_2, tpr_2)

fpr_3, tpr_3, threshold_3 = roc_curve(y, yp_lr)
auc_3 = auc(fpr_3, tpr_3)

fpr_4, tpr_4, threshold_4 = roc_curve(y, yp_rf)
auc_4 = auc(fpr_4, tpr_4)

fpr_5, tpr_5, threshold_5 = roc_curve(y, yp_xgb)
auc_5 = auc(fpr_5, tpr_5)

fpr_6, tpr_6, threshold_6 = roc_curve(y, yp_lgbm)
auc_6 = auc(fpr_6, tpr_6)

fpr_7, tpr_7, threshold_7 = roc_curve(y, yp_catb)
auc_7 = auc(fpr_7, tpr_7)

3. 繪製曲線

首先定義曲線的寬度和圖的大小,如下所示。

line_width = 1  # 曲線的寬度
plt.figure(figsize=(16, 10))  # 圖的大小

使用plt的plot()方法可以繪製曲線,通常可以傳入的參數有以下幾種:

  • x軸的數據
  • y軸的數據
  • lw:線條的寬度
  • label:曲線的標籤(曲線標籤甚至支持LaTex公式,例如$K_{d,1}$
  • color:曲線的顏色(如果不指定,plt會自動選擇)
  • linestyle:線型,包括“-”代表實線,“--”代表虛線,“-.”代表中間有點的虛線,“:”點型虛線
plt.plot(fpr_1, tpr_1, lw=line_width, label='Ann (AUC = %0.4f)' % auc_1,)
plt.plot(fpr_2, tpr_2, lw=line_width, label='Lstm (AUC = %0.4f)' % auc_2,)
plt.plot(fpr_3, tpr_3, lw=line_width, label='LogisticRegression (AUC = %0.4f)' % auc_3,)
plt.plot(fpr_4, tpr_4, lw=line_width, label='RandomForest (AUC = %0.4f)' % auc_4,)
plt.plot(fpr_5, tpr_5, lw=line_width, label='XGboost (AUC = %0.4f)' % auc_5,)
plt.plot(fpr_6, tpr_6, lw=line_width, label='LightGBM (AUC = %0.4f)' % auc_6,)
plt.plot(fpr_7, tpr_7, lw=line_width, label='Catboost (AUC = %0.4f)' % auc_7,)

4. 座標軸範圍和標題

限定x軸和y軸的範圍,如下所示。

plt.xlim([0.0, 1.0])  # 限定x軸的範圍
plt.ylim([0.0, 1.0])  # 限定y軸的範圍

也可以通過xticks()和yticks()直接調整座標軸的刻度,如下所示。

# plt.xticks(range(0, 10, 1)) # 修改x軸的刻度
# plt.yticks(range(0, 10, 1)) # 修改y軸的刻度

指定座標軸的標題,如下所示。

plt.xlabel('False Positive Rate')  # x座標軸標題
plt.ylabel('True Positive Rate')  # y座標軸標題

使用grid()方法在圖中添加網格,如下所示。

plt.grid()  # 在圖中添加網格

顯示圖例並指定圖例位置,常見位置包括{upper,center,lower} {left,center,right},如下所示。

plt.legend(loc="lower right")  # 顯示圖例並指定圖例位置

5. 中文處理問題

如果在座標軸、標題等地方出現了中文,plt會顯示亂碼,添加以下兩條語句可以解決中文處理問題。

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

6. 展示圖片和保存

TIFF格式(Tag Image File Format,TIFF)是常見的論文圖片投稿格式,TIFF格式能夠製作質量非常高的圖像,多數出版社(如Springer、Elsevier)都接受並推薦使用dpi=300的TIFF格式的插圖。

plt.savefig('AUC.tif', dpi=300)

使用plt的show方法展示曲線,如下所示。

plt.show()

7. 示例代碼

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve

y = np.load('npy\\y_test.npy')
yp_ann = np.load('npy\\ann.npy')
yp_lstm = np.load('npy\\lstm.npy')
yp_lr = np.load('npy\\lr.npy')
yp_rf = np.load('npy\\rf.npy')
yp_xgb = np.load('npy\\xgb.npy')
yp_lgbm = np.load('npy\\lgbm.npy')
yp_catb = np.load('npy\\catb.npy')

fpr_1, tpr_1, threshold_1 = roc_curve(y, yp_ann)  # 計算FPR和TPR
auc_1 = auc(fpr_1, tpr_1)  # 計算AUC值

fpr_2, tpr_2, threshold_2 = roc_curve(y, yp_lstm)
auc_2 = auc(fpr_2, tpr_2)

fpr_3, tpr_3, threshold_3 = roc_curve(y, yp_lr)
auc_3 = auc(fpr_3, tpr_3)

fpr_4, tpr_4, threshold_4 = roc_curve(y, yp_rf)
auc_4 = auc(fpr_4, tpr_4)

fpr_5, tpr_5, threshold_5 = roc_curve(y, yp_xgb)
auc_5 = auc(fpr_5, tpr_5)

fpr_6, tpr_6, threshold_6 = roc_curve(y, yp_lgbm)
auc_6 = auc(fpr_6, tpr_6)

fpr_7, tpr_7, threshold_7 = roc_curve(y, yp_catb)
auc_7 = auc(fpr_7, tpr_7)

plt.style.use(['science'])
line_width = 2  # 曲線的寬度
plt.figure(figsize=(8, 5))  # 圖的大小

plt.plot(fpr_1, tpr_1, lw=line_width, label='Ann (AUC = %0.4f)' % auc_1,)
plt.plot(fpr_2, tpr_2, lw=line_width, label='Lstm (AUC = %0.4f)' % auc_2,)
plt.plot(fpr_3, tpr_3, lw=line_width, label='LogisticRegression (AUC = %0.4f)' % auc_3,)
plt.plot(fpr_4, tpr_4, lw=line_width, label='RandomForest (AUC = %0.4f)' % auc_4,)
plt.plot(fpr_5, tpr_5, lw=line_width, label='XGboost (AUC = %0.4f)' % auc_5,)
plt.plot(fpr_6, tpr_6, lw=line_width, label='LightGBM (AUC = %0.4f)' % auc_6,)
plt.plot(fpr_7, tpr_7, lw=line_width, label='Catboost (AUC = %0.4f)' % auc_7,)


plt.xlim([0.0, 1.0])  # 限定x軸的範圍
plt.ylim([0.0, 1.0])  # 限定y軸的範圍
plt.xlabel('False Positive Rate', fontsize=16)  # x座標軸標題
plt.ylabel('True Positive Rate', fontsize=16)  # y座標軸標題
plt.title('ROC', fontsize=16)  # 標題
plt.grid()  # 在圖中添加網格
plt.legend(loc="lower right", fontsize=16)  # 顯示圖例並指定圖例位置

plt.savefig('ROC.tif', dpi=300)
plt.show()

image-20221026202136784

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章