今天終於弄明白了python迭代器是什麼(含paddlepaddle部分源碼解析)

前言

最近在學習paddlepaddle深度學習框架,在實戰波士頓房價預測例子的時候,yield函數困擾到了我。索性,我就將python迭代器琢磨了一遍,終於弄清楚了!。

迭代器概念

一句話:迭代器是可以記住遍歷位置的函數。(有C基礎的可對比指針)

迭代器從會從集合的第一個數據開始遍歷,直到結束。

我們拿代碼舉例子吧:

listdata=[1,2,3,4]#集合
listiter = iter(listdata)#將列表轉化爲一個迭代器

for data in listdata:
    print(data)
print("listdata類型:",type(listdata))
print("---")
print("listiter類型:",type(listiter))
for data1 in listiter:
    print(data1)

我們來看一下輸出結果:
在這裏插入圖片描述
從結果中,我們可以看到,listdata和listiter數據類型不一樣,前者是列表,後者是迭代器,但是輸出結果是一樣的。

對於listdata我們也可以這樣遍歷:

for data in [1,2,3,4]:
    print(data)

對,看着很簡單是吧,但是,listiter不能這樣寫。
listiter是通過next()方法依次遍歷數據的。當取完1時,標記位置指向2,再將2取出後,標記位置指向3,因爲3是最後一個,迭代器函數停止執行。
看一下下面這個圖:
在這裏插入圖片描述
總結:

  • 除了列表可以轉化爲迭代器之外,元組和字符串也可以進行轉化。
  • 上述例子,儘管輸出結果一樣,但是兩個所佔用的內存是不一樣的,迭代器經常用於深度學習框架。原因是其相對於列表這些來講,大大的減少了內存佔用空間。
  • 迭代器使用next()方法進行數據依次獲取

如何自己構造一個迭代器

爲什麼列表可以轉化爲一個迭代器,是因爲它本身帶有iter()和next().我們看一下如何自己手寫一個迭代器

class writeiter:

    def __iter__(self):
        self.start=1
        self.end=4
        return self
    def __next__(self):
        if self.start>self.end:
            return
        else:
            x = self.start
            self.start+=1
            return x
testclass = writeiter()
testiter = iter(testclass)
print(next(testiter))
print(next(testiter))
print(next(testiter))
print(next(testiter))

這裏我們實現了輸出從1,4連續的數字。
在這裏插入圖片描述
在我們這個類中,需要加上這兩個魔法方法。也就是iter()和next().
方法前後有雙下劃線的,在python表示魔法方法。具體什麼是魔法方法,大家可以自行百度。

yield函數是什麼

yield函數是一個生成器,凡是使用yield函數的函數(方法)都是生成器函數(方法)。

我們來看一個例子

def test():
    for i in range(5):
        yield i

print(type(test()))
print(next(test()))

輸出:
在這裏插入圖片描述
test()方法中因爲使用了yield函數,它的類型也變成了生成器類型。此時它也就變成了一個迭代器,我們可以使用next()依次訪問從0到5.我們可以把這個yield函數當成存錢罐,不斷往裏面存儲後面的變量。但是取得時候,只能一個一個按順序從頭開始取。

其上述代碼效果等價於:

listdata=[0,1,2,3,4]#集合
listiter = iter(listdata)
print(next(listiter))

yield函數在paddlepaddle中是如何被使用的

好了,到大頭了。
接下來呢,我們就探究paddlepaddle源碼,看看yield函數在這個框架中是怎麼使用的。
我選取的例子是官方給出的預測波士頓房價。

#每次緩存500個數據
BUF_SIZE=500
#每20個組成一個訓練批次
BATCH_SIZE=20
#一次緩存500個數據,其中每次讀取20個
#讀取數據到緩存中,並且打亂數據。
train_reader = paddle.batch(
    paddle.reader.shuffle(paddle.dataset.uci_housing.train(),buf_size=BUF_SIZE),
    batch_size=BATCH_SIZE
    )



上述代碼是將paddle讀取波士頓房價信息的訓練集數據讀取到緩存中。
咱們着重看這一句:

paddle.dataset.uci_housing.train()

我們再進入到train()這個方法中(pycharm中按住control+左鍵)

def train():
    """
    UCI_HOUSING training set creator.

    It returns a reader creator, each sample in the reader is features after
    normalization and price number.

    :return: Training reader creator
    :rtype: callable
    """
    global UCI_TRAIN_DATA
    load_data(paddle.dataset.common.download(URL, 'uci_housing', MD5))

    def reader():
        for d in UCI_TRAIN_DATA:
            yield d[:-1], d[-1:]

    return reader

上述就是train()的源代碼。
咱們可以看到reader()方法中使用了yield函數,所以reader()這個是一個生成器。train()方法將它作爲返回值,實際上就是返回了一個迭代器。

 train_data = next(train_reader())

這個用來獲取train_reader()的下一個值,原理和我們之前講的一樣。
官方也給出了一個例子,咱們也拿出來看一下:

def reader():
    for i in range(10):
        yield i
batch_reader = fluid.io.batch(reader, batch_size=2)
for data in batch_reader():
    print(data)

這個例子就是咱們上面例子的簡化版,batch_reader()在這裏作爲一個迭代器使用。

使用迭代器最大的好處就是節約內存。特別是在我們訓練模型的時候,需要使用大量的數據集。

如果大家看完還有不懂的地方,請留言一起討論。如果有幫助到你,請點個👍

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章