TensorFlow2.0 數組排序 正確率計算(top_k)

import tensorflow as tf
import numpy as np
print(tf.__version__)

1.數組排序

tf.sort tf.argsort

# 創建 0-4的數組  隨機打散
arry = tf.random.shuffle(tf.range(0, 5))
print(ary)

# 降序排序 tf.sort
arry_sort = tf.sort(arry, direction="DESCENDING")
print(arry_sort)

# 降序序號 tf.argsort
arry_argsort = tf.argsort(arry, direction="DESCENDING")
print(arry_argsort)

# 使用降序序號 tf.gather
print(tf.gather(arry, arry_argsort))

在這裏插入圖片描述

# 創建 3*3 的隨機二維數組 
arry = tf.random.uniform((3, 3), maxval= 10, dtype=tf.int32)
print(arry)

# 對二維數組 升序 排序(默認最後一個維度) axis=-1
arry_sort = tf.sort(arry)
print("二維數組排序 升序:\n", arry_sort.numpy())

# 對二維數組 降序排序
arry_sort_descending = tf.sort(arry, direction="DESCENDING")
print("二維數組排序 降序:\n", arry_sort_descending.numpy())

# 獲取二維數組 升序排序  索引
arry_argsort = tf.argsort(arry)
print("二維數組排序 升序索引:\n", arry_argsort)

在這裏插入圖片描述

2.正確率計算(top_k)

tf.math.top_k

# 創建 隨機的 3*3 二維數組
arry = tf.random.uniform((3, 3), maxval= 10, dtype=tf.int32)
print(arry)

# tf.math.top_k 獲取 二維數組 前兩個
arry_topk = tf.math.top_k(arry, 2)

# 獲取結果的 索引
print(arry_topk.indices)

# 獲取結果的 排序後的值
print(arry_topk.values)

在這裏插入圖片描述

3.計算Top5的正確率

# top_k函數
def accuracy(y_pred, y_ture, k=(1,)):
    # 獲取 k的最大值
    max_k = max(k)
    
    # 獲取數據數量
    data_len = y_ture.shape[0]
    
    # 輸入數據 top_k 的索引值
    y_pred = tf.math.top_k(y_pred, max_k).indices
    # 轉置 數組 
    y_pred = tf.transpose(y_pred, perm=[1, 0])
    
    # 目標 broadcast
    y_ture_b = tf.broadcast_to(y_ture, y_pred.shape)
    
    # 對比 正確與否
    correct = tf.equal(y_pred, y_ture_b)
    res = []
    # 循環統計 
    for i in k:
        correct_k = tf.cast(tf.reshape(correct[:i], [-1]), dtype=tf.float32)
        correct_k = tf.reduce_sum(correct_k)
        accy = correct_k * (100 / data_len)        
        res.append(accy.numpy())
    return res
# top_k 計算 正確率

# 創建 10 * 6 的隨機二維數組
y_pred = tf.random.normal((10, 6))
# tf.math.softmax  總和歸一化
y_pred = tf.math.softmax(y_pred, axis=-1)
print(y_pred)

# 隨機 生成 10個 目標結果
y_true = tf.random.uniform((10,), maxval=6, dtype=tf.int32)
print(y_true)

# 獲取最大值索引 預測值 
y_pred_max = tf.argmax(y_pred, axis=1)
print(y_pred_max)


# 調用函數 計算 正確率
acc = accuracy(y_pred, y_true, (1, 2, 3, 4, 5))
print(acc)

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章