Mxnet中Metric API中的Evaluation Metric API

1. mxnet.metric.check_lable_shapes(labels, preds, wrap=False, shape=False)

labels: data's labels, ndarray

preds:  predicted values, ndarray

wrap : boolean, if True, 如果 labels/preds 是 single NDarray的話就把它們打包成 list.

shape : boolean, if True的話,就check labels and preds's shape, 否則就僅僅check它們的長度。


2. mxnet.metric.EvalMetric(name, output_names=None, label_names=None, **kwargs)

這是一個類!並且是所有評價度量的基類。這個類中提供了通常的評價度量的接口, 不應該直接使用這個類,而應該創建一個子類來繼承它。

這裏的參數,name是要創建的metric實例的名字,output_names是predictions的名字, 應該在更新時使用

label_names是labels's 名字,應該在更新時使用。

然後這個類中有很多方法

2.1    ___init__(self, name, output_names=None, label_names=None, **kwargs)

這個是用來初始化的,不用說了。

2.2  __str__(self)

這個裏面寫的是:

def __str__(self):  

     return "EvalMetric: {}".format(dict(self.get_name_value()))

所以我們去看一下get_name_value()

2.3

def get_name_value(self)

name, value = self.get()  

if not isinstance(name, list):  

name = [name]  

if not isinstance(value, list):  

value = [value]  

return list(zip(name, value))

所以我們又要先看一下get函數

2. 4

def get(self):

   if self.num_inst == 0:  

return (self.name, float('nan'))  

else:  

return (self.name, self.sum_metric / self.num_inst)

然後接下來看看

2.5 接下來的這兩個函數都用到了update

def get_config(self):

這個直接就是獲得更新後的,metric, name, output_names, label_name

2.6

def update_dict(self, label, pred):

這個就是更新了,pred和label分別根據output_names和label_names

然後再調用

update(self, labels, preds)

2.7

def update(self, labels, preds):

注意這個是一定要實現的,不然直接會報錯。

2.8

def create(metric, *args,**args),這個在上一個博文中已經提到了。





發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章