一文總結Keras的loss函數和metrics函數

Loss函數

定義:

keras.losses.mean_squared_error(y_true, y_pred)

用法很簡單,就是計算均方誤差平均值,例如

loss_fn = keras.losses.mean_squared_error
a1 = tf.constant([1,1,1,1])
a2 = tf.constant([2,2,2,2])
loss_fn(a1,a2)
<tf.Tensor: id=718367, shape=(), dtype=int32, numpy=1>

Metrics函數

Metrics函數也用於計算誤差,但是功能比Loss函數要複雜。
定義

tf.keras.metrics.Mean(
    name='mean', dtype=None
)

這個定義過於簡單,舉例說明

mean_loss([1, 3, 5, 7])
mean_loss([1, 3, 5, 7])
mean_loss([1, 1, 1, 1])
mean_loss([2,2])

輸出結果

<tf.Tensor: id=718929, shape=(), dtype=float32, numpy=2.857143>

這個結果等價於

np.mean([1, 3, 5, 7, 1, 3, 5, 7, 1, 1, 1, 1, 2, 2])

這是因爲Metrics函數是狀態函數,在神經網絡訓練過程中會持續不斷地更新狀態,是有記憶的。因爲Metrics函數還帶有下面幾個Methods

reset_states()

Resets all of the metric state variables.
This function is called between epochs/steps, when a metric is evaluated during training.

result()

Computes and returns the metric value tensor.
Result computation is an idempotent operation that simply calculates the metric value using the state variables

update_state(
    values, sample_weight=None
)

Accumulates statistics for computing the reduction metric.

另外注意,Loss函數和Metrics函數的調用形式,

loss_fn = keras.losses.mean_squared_error
mean_loss = keras.metrics.Mean()

mean_loss(1)等價於keras.metrics.Mean()(1),而不是keras.metrics.Mean(1),這個從keras.metrics.Mean函數的定義可以看出。
但是必須先令生成一個實例mean_loss=keras.metrics.Mean(),而不能直接使用keras.metrics.Mean()本身。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章