按batch_size讀取數據

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # 樣本的讀取順序是隨機的。
    for i in range(0, num_examples, batch_size):
        j = nd.array(indices[i: min(i + batch_size, num_examples)])
        yield features.take(j), labels.take(j)  # take 函數根據索引返回對應元素。

使用:

batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    print(X, y)
    break
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章