Keras learning notes (2): Implementing f1_score (multi-class and binary classification)

First, the two approaches that are easy to find via Google:

1. Building a custom metric

This approach works for binary classification and can be passed as a metric during model training. It uses a fixed threshold of 0.5.

from keras import backend as K
def f1(y_true, y_pred):
    def recall(y_true, y_pred):
        """Recall metric.

        Only computes a batch-wise average of recall.

        Computes the recall, a metric for multi-label classification of
        how many relevant items are selected.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = true_positives / (possible_positives + K.epsilon())
        return recall

    def precision(y_true, y_pred):
        """Precision metric.

        Only computes a batch-wise average of precision.

        Computes the precision, a metric for multi-label classification of
        how many selected items are relevant.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision
    precision = precision(y_true, y_pred)
    recall = recall(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

model.compile(loss='binary_crossentropy',
          optimizer= "adam",
          metrics=[f1])
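As a quick sanity check of the fixed 0.5 threshold, here is a small sketch (the toy arrays below are made up for illustration, not from the original code) that evaluates the f1 metric above on a single batch and compares it with sklearn's f1_score:

import numpy as np
from keras import backend as K
from sklearn.metrics import f1_score

y_true = np.array([1, 0, 1, 1, 0])
y_prob = np.array([0.9, 0.4, 0.35, 0.8, 0.1], dtype='float32')

# K.round inside the metric is what applies the fixed 0.5 threshold
keras_f1 = K.eval(f1(K.variable(y_true.astype('float32')), K.variable(y_prob)))
sklearn_f1 = f1_score(y_true, (y_prob > 0.5).astype(int))
print(keras_f1, sklearn_f1)  # both should be 0.8 (up to K.epsilon())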

2. Callbacks

This approach is not recommended; it produced errors when I used it (see the comment in the code below).

import numpy as np
from keras.callbacks import Callback
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

class Metrics(Callback):
    def on_train_begin(self, logs={}):
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs={}):
        # Note: self.model.validation_data is not an actual model attribute; this is
        # the source of the errors. The version in section 3 uses the callback's own
        # self.validation_data instead.
        val_predict = (np.asarray(self.model.predict(self.model.validation_data[0]))).round()
        val_targ = self.model.validation_data[1]
        _val_f1 = f1_score(val_targ, val_predict)
        _val_recall = recall_score(val_targ, val_predict)
        _val_precision = precision_score(val_targ, val_predict)
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        print(' — val_f1: %f — val_precision: %f — val_recall: %f' % (_val_f1, _val_precision, _val_recall))
        return

metrics = Metrics()

model.fit(training_data, training_target,
          validation_data=(validation_data, validation_target),
          epochs=10,
          batch_size=64,
          callbacks=[metrics])

3. Final version:

Multi-class:

class Metrics(Callback):
    def on_train_begin(self, logs={}):
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs={}):
#         val_predict = (np.asarray(self.model.predict(self.validation_data[0]))).round()
        val_predict = np.argmax(np.asarray(self.model.predict(self.validation_data[0])), axis=1)
#         val_targ = self.validation_data[1]
        val_targ = np.argmax(self.validation_data[1], axis=1)
        _val_f1 = f1_score(val_targ, val_predict, average='macro')
#         _val_recall = recall_score(val_targ, val_predict)
#         _val_precision = precision_score(val_targ, val_predict)
        self.val_f1s.append(_val_f1)
#         self.val_recalls.append(_val_recall)
#         self.val_precisions.append(_val_precision)
#         print('— val_f1: %f — val_precision: %f — val_recall %f' %(_val_f1, _val_precision, _val_recall))
        print(' — val_f1:', _val_f1)
        return

# Other metrics can be added in the same way
metrics = Metrics()
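A usage sketch for the multi-class callback (x_train, y_train, x_val, y_val and model are placeholder names, not from the original code): since on_epoch_end takes argmax over self.validation_data[1], the labels passed to fit should be one-hot encoded, e.g. via keras.utils.to_categorical.

from keras.utils import to_categorical

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, to_categorical(y_train),
          validation_data=(x_val, to_categorical(y_val)),
          epochs=10,
          batch_size=64,
          callbacks=[metrics])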

Binary classification:

class Metrics(Callback):
    def on_train_begin(self, logs={}):
        self.val_f1s = []
        # self.val_recalls = []
        # self.val_precisions = []

    def on_epoch_end(self, epoch, logs={}):
        val_targ = self.validation_data[1]
        val_predict = self.model.predict(self.validation_data[0])

        # Search for the probability threshold that maximises F1 on the validation set
        best_threshold = 0
        best_f1 = 0
        for threshold in [i * 0.01 for i in range(25, 45)]:
            y_pred = (val_predict > threshold).astype(int)
            # val_recall = recall_score(val_targ, y_pred)
            # val_precision = precision_score(val_targ, y_pred)
            val_f1 = f1_score(val_targ, y_pred)
            if val_f1 > best_f1:
                best_threshold = threshold
                best_f1 = val_f1

        self.val_f1s.append(best_f1)
        # self.val_recalls.append(val_recall)
        # self.val_precisions.append(val_precision)
        print(' — val_f1: %f (best threshold: %.2f)' % (best_f1, best_threshold))
        return
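
A usage sketch for the binary case (again with placeholder names; it assumes the model ends in a single sigmoid unit so that model.predict returns probabilities, and that the labels are 0/1):

metrics = Metrics()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[f1])  # the f1 metric from section 1 is optional here
model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=10,
          batch_size=64,
          callbacks=[metrics])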