TensorFlow2.0 数组排序 正确率计算(top_k)

import tensorflow as tf
import numpy as np
print(tf.__version__)

1.数组排序

tf.sort tf.argsort

# 创建 0-4的数组  随机打散
arry = tf.random.shuffle(tf.range(0, 5))
print(ary)

# 降序排序 tf.sort
arry_sort = tf.sort(arry, direction="DESCENDING")
print(arry_sort)

# 降序序号 tf.argsort
arry_argsort = tf.argsort(arry, direction="DESCENDING")
print(arry_argsort)

# 使用降序序号 tf.gather
print(tf.gather(arry, arry_argsort))

在这里插入图片描述

# 创建 3*3 的随机二维数组 
arry = tf.random.uniform((3, 3), maxval= 10, dtype=tf.int32)
print(arry)

# 对二维数组 升序 排序(默认最后一个维度) axis=-1
arry_sort = tf.sort(arry)
print("二维数组排序 升序:\n", arry_sort.numpy())

# 对二维数组 降序排序
arry_sort_descending = tf.sort(arry, direction="DESCENDING")
print("二维数组排序 降序:\n", arry_sort_descending.numpy())

# 获取二维数组 升序排序  索引
arry_argsort = tf.argsort(arry)
print("二维数组排序 升序索引:\n", arry_argsort)

在这里插入图片描述

2.正确率计算(top_k)

tf.math.top_k

# 创建 随机的 3*3 二维数组
arry = tf.random.uniform((3, 3), maxval= 10, dtype=tf.int32)
print(arry)

# tf.math.top_k 获取 二维数组 前两个
arry_topk = tf.math.top_k(arry, 2)

# 获取结果的 索引
print(arry_topk.indices)

# 获取结果的 排序后的值
print(arry_topk.values)

在这里插入图片描述

3.计算Top5的正确率

# top_k函数
def accuracy(y_pred, y_ture, k=(1,)):
    # 获取 k的最大值
    max_k = max(k)
    
    # 获取数据数量
    data_len = y_ture.shape[0]
    
    # 输入数据 top_k 的索引值
    y_pred = tf.math.top_k(y_pred, max_k).indices
    # 转置 数组 
    y_pred = tf.transpose(y_pred, perm=[1, 0])
    
    # 目标 broadcast
    y_ture_b = tf.broadcast_to(y_ture, y_pred.shape)
    
    # 对比 正确与否
    correct = tf.equal(y_pred, y_ture_b)
    res = []
    # 循环统计 
    for i in k:
        correct_k = tf.cast(tf.reshape(correct[:i], [-1]), dtype=tf.float32)
        correct_k = tf.reduce_sum(correct_k)
        accy = correct_k * (100 / data_len)        
        res.append(accy.numpy())
    return res
# top_k 计算 正确率

# 创建 10 * 6 的随机二维数组
y_pred = tf.random.normal((10, 6))
# tf.math.softmax  总和归一化
y_pred = tf.math.softmax(y_pred, axis=-1)
print(y_pred)

# 随机 生成 10个 目标结果
y_true = tf.random.uniform((10,), maxval=6, dtype=tf.int32)
print(y_true)

# 获取最大值索引 预测值 
y_pred_max = tf.argmax(y_pred, axis=1)
print(y_pred_max)


# 调用函数 计算 正确率
acc = accuracy(y_pred, y_true, (1, 2, 3, 4, 5))
print(acc)

在这里插入图片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章