NUS-WIDE [1] is a multi-label dataset. Several papers I have seen split it the same way as [1]: randomly pick 100 samples per class to form the query set. This seemed ambiguous to me, so I asked the DCMH author; see [3].
My current strategy: sample class by class, so that every class is guaranteed its quota, and sample without replacement, so that no sample is picked twice. (Perhaps that is what the original description meant all along?)
I also process the classes in ascending order of their sample counts, from rarest to most frequent, although in hindsight this seems unnecessary.
The data itself is the version provided by the DCMH author; see the repo that hosts [3].
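Before the full script, here is a minimal sketch of the idea on a made-up 6x2 label matrix (toy data, not NUS-WIDE; PER and the matrix are invented for illustration): each class greedily grabs ids from a shrinking pool until its quota is met, without replacement, and a multi-label pick counts toward every class it carries.

import numpy as np

labels = np.array([[1, 0],   # toy multi-label matrix: 6 samples x 2 classes
                   [1, 1],
                   [0, 1],
                   [1, 0],
                   [0, 1],
                   [1, 1]])
PER = 2                                # quota: >= 2 picks per class
pool = list(range(len(labels)))        # candidate ids, shrinks as we pick
picked = []
cnt = np.zeros(labels.shape[1], int)   # per-class count of picked samples
for cls in np.argsort(labels.sum(0)):  # rarest class first
    for i in list(pool):
        if cnt[cls] >= PER:
            break
        if labels[i][cls] == 1:
            picked.append(i)
            pool.remove(i)             # without replacement
            cnt += labels[i]           # counts toward every class of the pick
print(picked, cnt)  # [0, 1, 2] [2 2]: both quotas met, no id repeated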
Code
For the semi-supervised setting: training set = labeled part, and labeled + unlabeled = retrieval set. (The sanity check after the script below verifies these relations.)
Here, test set and query set are synonyms.
import time
from os.path import join

import numpy as np
import scipy.io as sio

np.random.seed(int(time.time()))
# load the label matrix
NUSWIDE = "/usr/local/dataset/nuswide.DCMH/"
labels = sio.loadmat(join(NUSWIDE, "nus-wide-tc21-lall.mat"))["LAll"]
print(labels.shape)  # (195834, 21)
N_CLASS = labels.shape[1]
N_SAMPLE = labels.shape[0]
TEST_PER = 100   # query (test) samples per class
TRAIN_PER = 500  # training samples per class
N_TEST = TEST_PER * N_CLASS
N_TRAIN = TRAIN_PER * N_CLASS
"""1. 先保證 test set 的每類至少 100"""
indices = list(range(N_SAMPLE)) # 全部索引
np.random.shuffle(indices)
cls_sum = np.sum(labels[indices], axis=0) # 統計每個類樣本數
#print(cls_sum)
classes = np.argsort(cls_sum) # 從少到多
#print(classes)
id_test = []
cnt = np.zeros_like(labels[0], dtype=np.int32) # 默認 int8,會爆
for cls in classes:
print("--- {} ---".format(cls))
for i in indices:
if cnt[cls] >= TEST_PER: # 此類已抽夠
break
if labels[i][cls] == 1:
id_test.append(i)
cnt += labels[i]
#print(cnt)
assert cnt[cls] >= TEST_PER # 講道理一趟下來是肯定夠的
indices = list(set(indices) - set(id_test)) # 去掉已抽部分的 id
np.random.shuffle(indices)
#print("left:", len(indices))
assert len(set(id_test)) == len(id_test) # 驗證沒有重複
#print("cnt:", cnt)
print("#test:", len(id_test))
"""2. 類似地,保證 training set 的每類至少 500"""
indices = list(set(indices) - set(id_test)) # 去掉剛纔選過的那些 test id
np.random.shuffle(indices)
print(len(indices))
cls_sum = np.sum(labels[indices], axis=0)
#print(cls_sum)
classes = np.argsort(cls_sum)
#print(classes)
id_train = []
cnt = np.zeros_like(labels[0], dtype=np.int32)
for cls in classes:
print("--- {} ---".format(cls))
for i in indices:
if cnt[cls] >= TRAIN_PER:
break
if labels[i][cls] == 1:
id_train.append(i)
cnt += labels[i]
#print(cnt)
assert cnt[cls] >= TRAIN_PER
indices = list(set(indices) - set(id_train))
np.random.shuffle(indices)
#print("left:", len(indices))
assert len(set(id_train)) == len(id_train)
#print("cnt:", cnt)
print("#train:", len(id_train))
"""3. 補足 test 和 training set 剩餘的部分"""
indices = list(set(indices) - set(id_train)) # 再去掉剛纔選過的 train id
np.random.shuffle(indices)
#print(len(indices))
lack_test = N_TEST - len(id_test)
lack_train = N_TRAIN - len(id_train)
print("lack:", lack_test, ",", lack_train)
id_test.extend(indices[:lack_test])
id_train.extend(indices[lack_test: lack_test + lack_train])
print("#total test:", len(id_test))
print("#total train:", len(id_train))
"""4. unlabeled 部分"""
# unlabeled = all - labeled(training) - query(test)
id_unlabeled = list(set(indices) - set(id_train) - set(id_test))
print("#unlabeled:", len(id_unlabeled))
"""5. retrieval set"""
id_ret = id_train + id_unlabeled
print("#retrieval:", len(id_ret))
"""保存"""
SAV_P = "/home/tom/codes/reimpl.NINH/split/nuswide/"
test_id = np.asarray(id_test)
labeled_id = np.asarray(id_train)
unlabeled_id = np.asarray(id_unlabeled)
ret_id = np.asarray(id_ret)
np.save(join(SAV_P, "idx_test.npy"), test_id)
np.save(join(SAV_P, "idx_labeled.npy"), labeled_id)
np.save(join(SAV_P, "idx_unlabeled.npy"), unlabeled_id)
np.save(join(SAV_P, "idx_ret.npy"), ret_id)