1. mxnet.metric.check_lable_shapes(labels, preds, wrap=False, shape=False)
labels: data's labels, ndarray
preds: predicted values, ndarray
wrap : boolean, if True, 如果 labels/preds 是 single NDarray的話就把它們打包成 list.
shape : boolean, if True的話,就check labels and preds's shape, 否則就僅僅check它們的長度。
2. mxnet.metric.EvalMetric(name, output_names=None, label_names=None, **kwargs)
這是一個類!並且是所有評價度量的基類。這個類中提供了通常的評價度量的接口, 不應該直接使用這個類,而應該創建一個子類來繼承它。
這裏的參數,name是要創建的metric實例的名字,output_names是predictions的名字, 應該在更新時使用
label_names是labels's 名字,應該在更新時使用。
然後這個類中有很多方法
2.1 ___init__(self, name, output_names=None, label_names=None, **kwargs)
這個是用來初始化的,不用說了。
2.2 __str__(self)
這個裏面寫的是:
def __str__(self):
return "EvalMetric: {}".format(dict(self.get_name_value()))
所以我們去看一下get_name_value()
2.3
def get_name_value(self)
name, value = self.get()
if not isinstance(name, list):
name = [name]
if not isinstance(value, list):
value = [value]
return list(zip(name, value))
所以我們又要先看一下get函數
2. 4
def get(self):
if self.num_inst == 0:
return (self.name, float('nan'))
else:
return (self.name, self.sum_metric / self.num_inst)
然後接下來看看
2.5 接下來的這兩個函數都用到了update
def get_config(self):
這個直接就是獲得更新後的,metric, name, output_names, label_name
和
2.6
def update_dict(self, label, pred):
這個就是更新了,pred和label分別根據output_names和label_names
然後再調用
update(self, labels, preds)
2.7
def update(self, labels, preds):
注意這個是一定要實現的,不然直接會報錯。
2.8
def create(metric, *args,**args),這個在上一個博文中已經提到了。