MATLAB 繪製帶星號標記的 ROC 曲線

function [AUC betterthreshold MCCMAXThreshold Confuse_Matrix MCCs FPRs TPRs F1S] = ROC_2_Class_OneFile_WithStar(begin_threshold, end_threshold, step, ismaxmcc,classificationFileName, star_num, star_style)
% ROC_2_Class_OneFile_WithStar  Threshold sweep, ROC plot and confusion
% matrix for a two-class classifier whose scores are stored in one file.
%
%   The file at path 'classificationFileName' is read with LOAD and holds
%   one row per sample:
%       truelabel  probabilityOfFirstClass   probabilityOfSecondClass
%           1           0.78                        0.22
%           1           0.68                        0.32
%           2           0.56                        0.44
%           1           0.89                        0.11
%           2           0.22                        0.78
%           ....        ....                        ....
%   Only columns 1 and 2 are used. Label 1 is the positive class; any
%   other label counts as negative. A sample is predicted positive when
%   its column-2 score is >= the threshold.
%
%   Inputs:
%     begin_threshold, end_threshold, step
%                - thresholds swept are begin_threshold:step:end_threshold
%                  (requires begin_threshold <= end_threshold, step > 0).
%     ismaxmcc   - which threshold is used for Confuse_Matrix:
%                    1      swept threshold with the largest MCC
%                    0      "balance" threshold (TPR + FPR closest to 1)
%                    5..10  last swept threshold whose FPR lies in
%                           [k/100, (k+0.5)/100], e.g. 5 -> [0.05, 0.055];
%                           0 is used if no swept threshold falls in it
%     classificationFileName - path of the score file described above.
%     star_num   - approximate number of points drawn on the ROC curve
%                  (default 50).
%     star_style - LineSpec passed to PLOT (default 'b--').
%
%   Outputs:
%     AUC             - area under the ROC curve (computed by roc_curve).
%     betterthreshold - the balance threshold (see ismaxmcc == 0);
%                       begin_threshold if no sweep point qualifies.
%     MCCMAXThreshold - swept threshold with the largest MCC (first one on
%                       ties); 0.5 if every MCC is NaN.
%     Confuse_Matrix  - [TP FN; FP TN] at the selected threshold.
%     MCCs, FPRs, TPRs, F1S - 1-by-K per-threshold MCC, FPR, TPR and F1.
%
%  e.g., [AUC betterthreshold MCCMAXThreshold Confuse_Matrix MCCs FPRs TPRs F1S] = ROC_2_Class_OneFile_WithStar(0, 1, 0.01, 1, 'classificationFileName', 50, 'b--')
%

if nargin < 5, error('NNET:Arguments','Not enough input arguments.'); end
if nargin < 6 || isempty(star_num), star_num = 50; end
if nargin < 7 || isempty(star_style), star_style = 'b--'; end
% A zero step would make the sweep empty, so it is rejected as well.
if ((begin_threshold>end_threshold) || step <= 0.0)
    error('NNET:Arguments','The pattern of a parameter is not right.');
end

% FPR windows selectable through ismaxmcc = 5..10. k/100 and (k+0.5)/100
% are correctly rounded divisions, so they equal the literals 0.05, 0.055,
% ..., 0.10, 0.105 exactly.
fprWindows = 5:10;
if ~isscalar(ismaxmcc) || ~any(ismaxmcc == [0 1 fprWindows])
    error('NNET:Arguments','ismaxmcc must be 0, 1 or an integer from 5 to 10.');
end

A = load(classificationFileName);
labels = A(:, 1);
scores = A(:, 2);

% The exact threshold values visited by the sweep. Indexing this vector
% avoids re-deriving a threshold from begin_threshold and step.
thresholds = begin_threshold:step:end_threshold;
nT = numel(thresholds);
MCCs = zeros(1, nT);
FPRs = zeros(1, nT);
TPRs = zeros(1, nT);
F1S  = zeros(1, nT);

for t = 1:nT
    [tp, fn, fp, tn] = count_confusion(labels, scores, thresholds(t));
    % These are 0/0 = NaN when the corresponding class is empty.
    TPRs(t) = tp / (tp + fn);
    FPRs(t) = fp / (fp + tn);
    % F1 = 2*precision*recall/(precision+recall) = 2TP/(2TP+FP+FN).
    % The count form gives 0 rather than NaN when TP = 0.
    F1S(t) = 2*tp / (2*tp + fp + fn);
    MCCs(t) = (tp*tn - fp*fn) / sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn));
end

% Threshold with the largest MCC. MAX ignores NaN, and FIND(...,1) keeps
% the first threshold on ties.
MCCMAXThreshold = 0.5;
if any(~isnan(MCCs))
    iMax = find(MCCs == max(MCCs), 1);
    MCCMAXThreshold = thresholds(iMax);
end

% "Balance" threshold: TPR + FPR closest to 1, i.e. TPR closest to the
% specificity 1 - FPR. Degenerate points with a rate of exactly 0 or 1 are
% skipped, as are points whose distance is not below 1 (this includes NaN).
dist = abs(TPRs + FPRs - 1);
candidates = find(TPRs ~= 0 & TPRs ~= 1 & FPRs ~= 0 & FPRs ~= 1 & dist < 1);
if isempty(candidates)
    betterthreshold = begin_threshold;
else
    k = find(dist(candidates) == min(dist(candidates)), 1);
    betterthreshold = thresholds(candidates(k));
end

% For each FPR window, keep the last swept threshold whose FPR falls in it.
th_fpr = zeros(1, numel(fprWindows));
for w = 1:numel(fprWindows)
    k = fprWindows(w);
    hit = find(FPRs >= k/100 & FPRs <= (k + 0.5)/100, 1, 'last');
    if ~isempty(hit)
        th_fpr(w) = thresholds(hit);
    end
end

AUC = roc_curve(scores, labels, star_num, star_style);

if ismaxmcc == 1
    threshold = MCCMAXThreshold;
elseif ismaxmcc == 0
    threshold = betterthreshold;
else
    threshold = th_fpr(fprWindows == ismaxmcc);
end
threshold   % echo the selected threshold to the Command Window

[tp, fn, fp, tn] = count_confusion(labels, scores, threshold);
Confuse_Matrix = [tp fn; fp tn];

end

function [tp, fn, fp, tn] = count_confusion(labels, scores, threshold)
% count_confusion  Confusion-matrix counts at a single threshold.
%   Label 1 is positive and any other label is negative. A sample is
%   predicted positive when its score >= threshold (NaN scores are
%   therefore predicted negative).
isPos   = (labels == 1);
predPos = (scores >= threshold);
tp = sum( predPos &  isPos);
fn = sum(~predPos &  isPos);
fp = sum( predPos & ~isPos);
tn = sum(~predPos & ~isPos);
end

function auc = roc_curve(deci,label_y, star_num, star_style)
% roc_curve  Compute the AUC and plot a subsampled ROC curve.
%   deci       - column vector of scores (higher = more likely class 1).
%   label_y    - column vector of labels; 1 is positive, anything else
%                is treated as negative.
%   star_num   - approximate number of points to plot along the curve.
%   star_style - LineSpec passed to PLOT (e.g. 'b*-' for star markers).
%   auc        - area under the ROC curve. It is a right-endpoint sum
%                over the samples sorted by descending score; tied scores
%                are not averaged.
%
%   Plots into the current axes and sets its font, axis labels and title.

    label_y(label_y ~= 1) = -1;
    [sortedDeci, ind] = sort(deci, 'descend');
    roc_y = label_y(ind);
    % Cumulative FPR (x) and TPR (y) after each sample in score order.
    stack_x = cumsum(roc_y == -1) / sum(roc_y == -1);
    stack_y = cumsum(roc_y == 1) / sum(roc_y == 1);
    n = length(roc_y);
    auc = sum((stack_x(2:n,1) - stack_x(1:n-1,1)) .* stack_y(2:n,1));

    % Plot roughly star_num evenly spaced points. The stride is at least 1
    % so that star_num > n still plots every point rather than none. The
    % final point of the curve is always included.
    stride = max(1, floor(n / star_num));
    select_indexes = 1:stride:n;
    if n > 0 && select_indexes(end) ~= n
        select_indexes = [select_indexes, n];
    end

    plot(stack_x(select_indexes, :), stack_y(select_indexes, :), star_style);

    set(gca, 'FontName', 'Times New Roman', 'FontSize', 12) % axis tick font name and size
    xlabel('False Positive Rate', 'FontName', 'Times New Roman', 'FontSize', 20);
    ylabel('True Positive Rate', 'FontName', 'Times New Roman', 'FontSize', 20);
    title('ROC Curve', 'FontName', 'Times New Roman', 'FontSize', 20);
end

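This MATLAB code draws an ROC curve from a single classification-result file. It plots about `star_num` evenly spaced points of the curve using the line style `star_style`; use a marker style such as `'b*-'` to draw the points as stars. It also returns the AUC, the per-threshold TPR/FPR/MCC/F1 values, and the confusion matrix at a selected threshold.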

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章