圖片分類之繪製混淆矩陣+PCA降維繪圖

繪製混淆矩陣代碼

# 繪製混淆矩陣
def plotCM(matrix, classes):
    
    def plot_confusion_matrix(cm, labels,title='Confusion Matrix', cmap = plt.cm.Blues):
        # plt.figure(figsize=(950,950))
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        xlocations = np.array(range(len(labels)))
        plt.xticks(xlocations, labels, rotation=90)
        plt.yticks(xlocations, labels)
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.savefig('./HAR_cm.png')
        plt.show()

    """classes: a list of class names"""
    # Normalize by row
   
    cm_normalized = matrix.astype('float')/matrix.sum(axis=1)[:, np.newaxis]
    print(cm_normalized.shape)

    # plot
    fig = plt.figure()
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    print(x.shape)


    tick_marks = np.array(range(len(classes))) + 0.5

    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm_normalized[y_val][x_val]
        if (c > 0.01):
	        plt.text(x_val, y_val, "%0.2f" %(c,), color='red', fontsize=7, va='center', ha='center')

    #offset the tick
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)

    plot_confusion_matrix(cm_normalized, classes,title='Normalized confusion matrix')

調用代碼

  • classes代表類別標籤列表,注意要按照y_true中的順序排列。
  • label_id_name_dict 是一個字典,key是索引,value是標籤名字
# 繪製混淆矩陣
    # classes = list(infer.label_id_name_dict.values())
    # classes = list(range(infer.num_classes))
    classes = list(set(y_true))
    classes.sort()
    classes_name = [infer.label_id_name_dict[c] for c in classes]
    cf_matrix = confusion_matrix(y_true,y_pred)
    plotCM(cf_matrix,classes_name)

效果圖

在這裏插入圖片描述

PCA降維繪圖

降低到3維
  • colors是顏色列表 x_train是特徵向量(降維後)、y_train是特徵標籤(降維後)、class_names是標籤列表(按照y_train的順序排序)
  • 降低到2維就把Axes3D換成Axes2D即可

def plot_pca_scatter(x_train,y_train,class_names):
    print(x_train.shape)
    print(y_train.shape)

    colors = ['r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r']

    ax = Axes3D(plt.figure())
    # for c, i, target_name in zip(colors,list(range(len(class_names))), class_names):
    #     plt.scatter(x_train[y_train==i, 0], x_train[y_train==i, 1], c=c, label=target_name)
    
    for c, i, target_name in zip(colors,list(range(len(class_names))), class_names):
        ax.scatter(x_train[y_train==i, 0], x_train[y_train==i, 1],x_train[y_train==i, 2], c=c, label=target_name)

    #設置每個座標的取值範圍
    # plt.axis([-20,20,-20,20])
    # plt.xlabel('Dimension1')
    # plt.ylabel('Dimension2')
    plt.title('data distribution')
    plt.legend()
    plt.show()

效果圖

在這裏插入圖片描述

發佈了9 篇原創文章 · 獲贊 0 · 訪問量 8887
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章