NUS-WIDE Dataset Split

NUS-WIDE [1] is a multi-label dataset. Several papers I have read use a split similar to [2]: randomly pick 100 samples per class to form the query set. This seemed somewhat mysterious to me, so I asked the DCMH author; see [3].
My current strategy: draw samples class by class, so that every class meets its quota, and draw without replacement, so that no sample is selected twice. (Perhaps that was the intended meaning all along?)
I also process the classes in ascending order of sample count, i.e. rarest first, though in hindsight this seems unnecessary.
The data I use is the version provided by the DCMH author; see the repo that hosts [3].
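As a toy illustration of the rarest-first ordering (the matrix below is made up, not NUS-WIDE):

import numpy as np

# made-up 4x3 multi-label matrix: rows = samples, columns = classes
toy = np.array([[1, 0, 1],
                [1, 1, 1],
                [0, 1, 1],
                [0, 0, 1]])
order = np.argsort(toy.sum(axis=0))  # class ids sorted by frequency, rarest first
print(order)  # [0 1 2]: classes 0 and 1 have 2 samples each, class 2 has 4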

Code

For the semi-supervised setting: training set = the labeled part, and labeled + unlabeled = the retrieval set.
"test set" and "query set" are synonymous below.
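To make these definitions concrete, a minimal sketch with toy indices (the variable names here are mine, not from DCMH):

# toy indices illustrating the split relations described above
test_ids = {0, 1}                            # query set == test set
labeled_ids = {2, 3}                         # supervised training part
unlabeled_ids = {4, 5, 6}                    # remaining database samples
retrieval_ids = labeled_ids | unlabeled_ids  # retrieval set = labeled + unlabeled
assert not (test_ids & retrieval_ids)        # queries stay out of the database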

import numpy as np
import scipy.io as sio
from os.path import join
import time

np.random.seed(int(time.time()))
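# note: seeding with the current time gives a different split on every run;
# fix the seed instead if a reproducible split is wanted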

# load the label matrix (the file provided by the DCMH author, see [3])
NUSWIDE = "/usr/local/dataset/nuswide.DCMH/"
labels = sio.loadmat(join(NUSWIDE, "nus-wide-tc21-lall.mat"))["LAll"]
print(labels.shape)  # (195834, 21)


N_CLASS = labels.shape[1]
N_SAMPLE = labels.shape[0]
TEST_PER = 100  # at least 100 samples per class in the test set
TRAIN_PER = 500  # at least 500 samples per class in the training set
N_TEST = TEST_PER * N_CLASS
N_TRAIN = TRAIN_PER * N_CLASS


"""1. 先保證 test set 的每類至少 100"""
indices = list(range(N_SAMPLE))  # 全部索引
np.random.shuffle(indices)

cls_sum = np.sum(labels[indices], axis=0)  # per-class sample counts
#print(cls_sum)
classes = np.argsort(cls_sum)  # class ids from rarest to most common
#print(classes)

id_test = []
cnt = np.zeros_like(labels[0], dtype=np.int32)  # labels are int8 by default, which would overflow
for cls in classes:
    print("--- {} ---".format(cls))
    for i in indices:
        if cnt[cls] >= TEST_PER:  # this class already has enough samples
            break
        if labels[i][cls] == 1:
            id_test.append(i)
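            # add the whole label row: a multi-label sample counts
            # toward every class it carries, not just cls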
            cnt += labels[i]
    #print(cnt)
    assert cnt[cls] >= TEST_PER  # one full pass should always collect enough
    indices = list(set(indices) - set(id_test))  # drop the ids already selected
    np.random.shuffle(indices)
    #print("left:", len(indices))

assert len(set(id_test)) == len(id_test)  # verify there are no duplicates
#print("cnt:", cnt)
print("#test:", len(id_test))


"""2. 類似地,保證 training set 的每類至少 500"""
indices = list(set(indices) - set(id_test))  # 去掉剛纔選過的那些 test id
np.random.shuffle(indices)
print(len(indices))

cls_sum = np.sum(labels[indices], axis=0)
#print(cls_sum)
classes = np.argsort(cls_sum)
#print(classes)

id_train = []
cnt = np.zeros_like(labels[0], dtype=np.int32)
for cls in classes:
    print("--- {} ---".format(cls))
    for i in indices:
        if cnt[cls] >= TRAIN_PER:
            break
        if labels[i][cls] == 1:
            id_train.append(i)
            cnt += labels[i]
    #print(cnt)
    assert cnt[cls] >= TRAIN_PER
    indices = list(set(indices) - set(id_train))
    np.random.shuffle(indices)
    #print("left:", len(indices))

assert len(set(id_train)) == len(id_train)
#print("cnt:", cnt)
print("#train:", len(id_train))


"""3. 補足 test 和 training set 剩餘的部分"""
indices = list(set(indices) - set(id_train))  # 再去掉剛纔選過的 train id
np.random.shuffle(indices)
#print(len(indices))

lack_test = N_TEST - len(id_test)
lack_train = N_TRAIN - len(id_train)
print("lack:", lack_test, ",", lack_train)

id_test.extend(indices[:lack_test])
id_train.extend(indices[lack_test: lack_test + lack_train])

print("#total test:", len(id_test))
print("#total train:", len(id_train))


"""4. unlabeled 部分"""
# unlabeled = all - labeled(training) - query(test)
id_unlabeled = list(set(indices) - set(id_train) - set(id_test))
print("#unlabeled:", len(id_unlabeled))


"""5. retrieval set"""
id_ret = id_train + id_unlabeled
print("#retrieval:", len(id_ret))


"""保存"""
SAV_P = "/home/tom/codes/reimpl.NINH/split/nuswide/"

test_id = np.asarray(id_test)
labeled_id = np.asarray(id_train)
unlabeled_id = np.asarray(id_unlabeled)
ret_id = np.asarray(id_ret)

np.save(join(SAV_P, "idx_test.npy"), test_id)
np.save(join(SAV_P, "idx_labeled.npy"), labeled_id)
np.save(join(SAV_P, "idx_unlabeled.npy"), unlabeled_id)
np.save(join(SAV_P, "idx_ret.npy"), ret_id)

References

  1. NUS-WIDE
  2. Lai et al. Simultaneous Feature Learning and Hash Coding with Deep Neural Networks. CVPR 2015.
  3. details of partition of NUS-WIDE #8 (GitHub issue on the DCMH repo)