def data_iter(batch_size, features, labels):
    """Yield shuffled mini-batches of ``(features, labels)``.

    The example order is reshuffled with ``random.shuffle`` on every call.
    Every example index appears in exactly one batch. The final batch is
    smaller than ``batch_size`` when ``len(features)`` is not a multiple of it.

    Args:
        batch_size: Number of examples per batch. Expected to be positive;
            a value of 0 makes ``range`` raise ``ValueError``.
        features: Object supporting ``len()`` and ``take(indices)``.
            Presumably an MXNet ``NDArray`` indexed along axis 0; confirm
            against callers.
        labels: Same kind of object as ``features``. Assumed to have the
            same length as ``features`` (not checked here).

    Yields:
        ``(features.take(j), labels.take(j))`` for each batch of indices ``j``.
    """
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # Examples are read in a random order.
    for start in range(0, num_examples, batch_size):
        # List slicing clamps at the end, so the last partial batch needs no
        # explicit min(). An integer dtype is used because nd.array defaults
        # to float32, which cannot represent every index above 2**24 exactly.
        j = nd.array(indices[start:start + batch_size], dtype='int32')
        # take() returns the elements at the given indices.
        yield features.take(j), labels.take(j)
# Usage:
# Sanity check: draw one mini-batch from data_iter and print it.
# NOTE(review): assumes `features` and `labels` are defined earlier in the
# file (not visible in this chunk) — confirm.
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, y)
    break  # Only the first batch is inspected.