在圖像分類或是識別任務中,一般要求計算top-1,top-2,tor-5等準確率,下面是用Tensorflow2實現這一功能的基本代碼,可以根據要求改代碼分別計算:
def accuracy(output,target,topk(1,)):
maxk=max(topk)
batch_size=target.shape[0]
pred=tf.math.top_k(output,maxk).indices
pred=tf.transpose(pred,perm=[1,0])
target_=tf.broadcast_to(target,pred.shape)
correct=tf.equal(target_,pred)
res=[]
for k in topk:
correct_k=tf.cast(tf.reshape(correct[:k],[-1]),dtype=tf.float32)
correct_k=tf.reduce_sum(correct_k)
acc=float(correct_k/batch_size)
res.append(acc)
return res