python 繪製聲紋識別DET曲線

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from scipy.stats import norm
import numpy as np

def plot_DET_curve():
    # 設置刻度範圍
    pmiss_min = 0.001

    pmiss_max = 0.6  

    pfa_min = 0.001

    pfa_max = 0.6

    # 刻度設置
    pticks = [0.00001, 0.00002, 0.00005, 0.0001, 0.0002, 0.0005,
            0.001, 0.002, 0.005, 0.01, 0.02, 0.05,
            0.1, 0.2, 0.4, 0.6, 0.8, 0.9,
            0.95, 0.98, 0.99, 0.995, 0.998, 0.999,
            0.9995, 0.9998, 0.9999, 0.99995, 0.99998, 0.99999]

    # 刻度*100
    xlabels = [' 0.001', ' 0.002', ' 0.005', ' 0.01 ', ' 0.02 ', ' 0.05 ',
            '  0.1 ', '  0.2 ', ' 0.5  ', '  1   ', '  2   ', '  5   ',
            '  10  ', '  20  ', '  40  ', '  60  ', '  80  ', '  90  ',
            '  95  ', '  98  ', '  99  ', ' 99.5 ', ' 99.8 ', ' 99.9 ',
            ' 99.95', ' 99.98', ' 99.99', '99.995', '99.998', '99.999']

    ylabels = xlabels

    # 確定刻度範圍
    n = len(pticks)
    # 倒敘 
    for k, v in enumerate(pticks[::-1]):
        if pmiss_min <= v:
            tmin_miss = n - k - 1   # 移動最小值索引位置
        if pfa_min <= v:
            tmin_fa = n - k - 1   # 移動最小值索引位置
    # 正序
    for k, v in enumerate(pticks):
        if pmiss_max >= v:   
            tmax_miss = k+1         # 移動最大值索引位置
        if pfa_max >= v:            
            tmax_fa = k+1            # 移動最大值索引位置

    # FRR
    plt.figure()
    plt.xlim(norm.ppf(pfa_min), norm.ppf(pfa_max))

    plt.xticks(norm.ppf(pticks[tmin_fa:tmax_fa]), xlabels[tmin_fa:tmax_fa])
    plt.xlabel('False Alarm probability (in %)')

    # FAR
    plt.ylim(norm.ppf(pmiss_min), norm.ppf(pmiss_max))
    plt.yticks(norm.ppf(pticks[tmin_miss:tmax_miss]), ylabels[tmin_miss:tmax_miss])
    plt.ylabel('Miss probability (in %)')

    return plt

# 計算EER
def compute_EER(frr,far):
    threshold_index = np.argmin(abs(frr - far))  # 平衡點
    eer = (frr[threshold_index]+far[threshold_index])/2
    print("eer=",eer)
    return eer

# 計算minDCF P_miss = frr  P_fa = far
def compute_minDCF2(P_miss,P_fa):
    C_miss = C_fa = 1
    P_true = 0.01
    P_false = 1-P_true

    npts = len(P_miss)
    if npts != len(P_fa):
        print("error,size of Pmiss is not euqal to pfa")
    
    DCF = C_miss * P_miss * P_true + C_fa * P_fa*P_false

    min_DCF = min(DCF)

    print("min_DCF_2=",min_DCF)

    return min_DCF


# 計算minDCF P_miss = frr  P_fa = far
def compute_minDCF3(P_miss,P_fa,min_DCF_2):
    C_miss = C_fa = 1
    P_true = 0.001
    P_false = 1-P_true

    npts = len(P_miss)
    if npts != len(P_fa):
        print("error,size of Pmiss is not euqal to pfa")
    
    DCF = C_miss * P_miss * P_true + C_fa * P_fa*P_false

    # 該操作是我自己加的,因爲論文中的DCF10-3指標均大於DCF10-2且高於0.1以上,所以通過這個來過濾一下,錯誤請指正
    min_DCF = 1
    for dcf in DCF:
        if dcf > min_DCF_2+0.1 and dcf < min_DCF:
            min_DCF = dcf

    print("min_DCF_3=",min_DCF)
    return min_DCF


if __name__ == "__main__":
    # 讀文件獲取y_true和y_score
    y_true = np.load('./dataset/y_true.npy')
    y_score = np.load('./dataset/y_pre.npy')

    # 計算FAR和FRR
    fpr, tpr, thres = roc_curve(y_true, y_score)
    frr = 1 - tpr
    far = fpr
    frr[frr <= 0] = 1e-5
    far[far <= 0] = 1e-5
    frr[frr >= 1] = 1-1e-5
    far[far >= 1] = 1-1e-5


    # 畫圖
    plt = plot_DET_curve()
    x, y = norm.ppf(frr), norm.ppf(far)
    plt.plot(x, y)
    plt.plot([-40, 1], [-40, 1])
    # plt.plot(np.arange(0,40,1),np.arange(0,40,1))
    plt.show()

    
    eer = compute_EER(frr,far)

    min_DCF_2 = compute_minDCF2(frr*100,far*100)

    min_DCF_3 = compute_minDCF3(frr*100,far*100,min_DCF_2)

效果圖:

發佈了9 篇原創文章 · 獲贊 0 · 訪問量 8884
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章