【Matplotlib】在Jupyter交互頁面中繪製折線圖對比(自用函數)

0x00 前言

最近數據對比的任務比較常見,比如好些模型的橫向對比,
對於 Loss、PRF、Hits 之類的數據,有時需要作log,有時需要去除前面幾個值,
還要考慮數據不對齊、記錄文件格式不一致等諸多問題,總之主需求是魯棒性,
稍微寫了個畫表格的函數,暫時還比較亂,後續有時間再作優化好了,
現在暫時寫在這作爲記錄,便於易於在不同機器上獲取使用以及方便後續優化更新。

0x01 用法

# single
draw(y, y_label='Data', title='undefined title', semilog=False)

# multiple
drawx(y_list, label_list=None, title='undefined title', 
          start=0, end=None, x_list=None, x_label=None, semilog=False, axh=0)

0x02 Source Code


import matplotlib.pyplot as plt
from pylab import *  # 支持中文
import warnings
warnings.filterwarnings('ignore')
mpl.rcParams['font.sans-serif'] = ['SimHei']

#plt.plot(x, y, 'ro-')
#plt.plot(x, y1, 'bo-')
#pl.xlim(-1, 11)  # 限定橫軸的範圍
#pl.ylim(-1, 110)  # 限定縱軸的範圍

def focus(filename, label='[Valid_F1_G]', pos=1, delim='\t'):
    # used for logs like '[Valid_F1_G]\t0.667\t0.667\t0.667'
    target = filter(
        lambda x: x.startswith(label),
        [line for line in open(filename, 'r')]
    )
    return map(
        lambda x: x.split(delim)[pos].strip(),
        target
    )


def draw(y, y_label='Data', title='undefined title', semilog=False):
    x = range(1, y.__len__()+1)
    plt.plot(x, y, marker='o', mec='r', mfc='w', label=y_label)
    plt.legend()  # 讓圖例生效
    #plt.xticks(x, x, rotation=0)
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch") # X軸標籤
    plt.title("Brief Figure for {}".format(title)) #標題
    if semilog: plt.semilogy()
    plt.show()

    
def draw2(y1, y2, 
          y1_label='Data1', y2_label='Data2', 
          semilog=False, title='undefined title', start=0, end=9999, axh=0):
    length = min(end, max(y1.__len__(), y2.__len__()))
    y1_len = min(length, y1.__len__())
    y2_len = min(length, y2.__len__())
    x = range(1, length+1)
    plt.figure(figsize=(15, 5))
    plt.plot(x[start:y1_len], y1[start:y1_len], marker='o', mec='r', mfc='w', label=y1_label)
    plt.plot(x[start:y2_len], y2[start:y2_len], marker='X', mfc='w', ms=8, label=y2_label)
    plt.legend()  # 讓圖例生效
    # plt.xticks(x1, x1, rotation=0)
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch") #X軸標籤
    plt.title("Brief Figure for {}".format(title)) #標題
    if axh: plt.axhline(axh)
    if semilog: plt.semilogy()
    plt.show()

    
def draw4(y1, y2, y3, y4,
          y1_label, y2_label, y3_label, y4_label,
          title = 'title', start=0, semilog=False, axh=0):
    length = min(y1.__len__(), y2.__len__(), y3.__len__(), y4.__len__())
    x = range(1, length+1)
    plt.figure(figsize=(15, 5))
    plt.plot(x[start:], y1[start:length], marker='o', mec='r', mfc='w', label=y1_label)
    plt.plot(x[start:], y2[start:length], marker='X', mfc='w', ms=8, label=y2_label)
    plt.plot(x[start:], y3[start:length], marker='*', mfc='w', mec='b', label=y3_label)
    plt.plot(x[start:], y4[start:length], marker='.', mfc='w', label=y4_label)
    plt.legend()  # 讓圖例生效
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch") #X軸標籤
    plt.title("Brief Figure for {}".format(title)) #標題
    if axh: plt.axhline(axh)
    if semilog: plt.semilogy()
    plt.grid(False)
    plt.show()

    
def drawx(y_list, label_list=None, title='undefined title', 
          start=0, end=None, x_list=None, x_label=None, semilog=False, axh=0):
    length = min(end, max(map(len, y_list))) if end else max(map(len, y_list))
    y_length = map(lambda x: min(length, x.__len__()), y_list)
    x_list = range(1, length+1) if x_list is None else x_list[:length]
    plt.figure(figsize=(15, 5))
    if label_list is None:
        label_list = map(
            lambda x: u"Data_{}".format(x), 
            range(y_list.__len__()))
    for idx, label in enumerate(label_list):
        plt.plot(
            x_list[start: y_length[idx]], 
            y_list[idx][start: y_length[idx]], 
            marker='o',
            label=label)
    plt.legend()  # let legends work
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch" if x_label is None else x_label)
    plt.title(u"Brief Figure for {}".format(title)) #標題
    if axh:  # Add a horizontal line across the axis.
        plt.axhline(axh)
    if semilog: 
        plt.semilogy()
        
    plt.grid(False)
    plt.show()
    ```
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章