def data_iter(batch_size, features, labels):
    """Yield shuffled minibatches of ``(features, labels)``.

    The example order is reshuffled on every call. The shuffled indices are
    then walked in chunks of ``batch_size``. The final chunk is smaller when
    ``len(features)`` is not a multiple of ``batch_size``.

    Args:
        batch_size: Number of examples per minibatch.
        features: Container supporting ``len`` and ``.take(indices)``;
            presumably an MXNet NDArray whose first axis indexes examples
            (TODO confirm against caller).
        labels: Container aligned with ``features`` that also supports
            ``.take(indices)``.

    Yields:
        ``(features.take(j), labels.take(j))`` for each batch index array ``j``.
    """
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # Read the examples in a random order.
    for i in range(0, num_examples, batch_size):
        j = nd.array(indices[i: min(i + batch_size, num_examples)])
        # `take` returns the elements at the given indices.
        yield features.take(j), labels.take(j)
# 使用 (usage):
batch_size = 10
# Print only the first minibatch to inspect its contents, then stop.
for X, y in data_iter(batch_size, features, labels):
    print(X, y)
    break