前言
最近在學習paddlepaddle深度學習框架,在實戰波士頓房價預測例子的時候,yield函數困擾到了我。索性,我就將python迭代器琢磨了一遍,終於弄清楚了!。
迭代器概念
一句話:迭代器是可以記住遍歷位置的函數。(有C基礎的可對比指針)
迭代器從會從集合的第一個數據開始遍歷,直到結束。
我們拿代碼舉例子吧:
listdata=[1,2,3,4]#集合
listiter = iter(listdata)#將列表轉化爲一個迭代器
for data in listdata:
print(data)
print("listdata類型:",type(listdata))
print("---")
print("listiter類型:",type(listiter))
for data1 in listiter:
print(data1)
我們來看一下輸出結果:
從結果中,我們可以看到,listdata和listiter數據類型不一樣,前者是列表,後者是迭代器,但是輸出結果是一樣的。
對於listdata我們也可以這樣遍歷:
for data in [1,2,3,4]:
print(data)
對,看着很簡單是吧,但是,listiter不能這樣寫。
listiter是通過next()方法依次遍歷數據的。當取完1時,標記位置指向2,再將2取出後,標記位置指向3,因爲3是最後一個,迭代器函數停止執行。
看一下下面這個圖:
總結:
- 除了列表可以轉化爲迭代器之外,元組和字符串也可以進行轉化。
- 上述例子,儘管輸出結果一樣,但是兩個所佔用的內存是不一樣的,迭代器經常用於深度學習框架。原因是其相對於列表這些來講,大大的減少了內存佔用空間。
- 迭代器使用next()方法進行數據依次獲取
如何自己構造一個迭代器
爲什麼列表可以轉化爲一個迭代器,是因爲它本身帶有iter()和next().我們看一下如何自己手寫一個迭代器
class writeiter:
def __iter__(self):
self.start=1
self.end=4
return self
def __next__(self):
if self.start>self.end:
return
else:
x = self.start
self.start+=1
return x
testclass = writeiter()
testiter = iter(testclass)
print(next(testiter))
print(next(testiter))
print(next(testiter))
print(next(testiter))
這裏我們實現了輸出從1,4連續的數字。
在我們這個類中,需要加上這兩個魔法方法。也就是iter()和next().
方法前後有雙下劃線的,在python表示魔法方法。具體什麼是魔法方法,大家可以自行百度。
yield函數是什麼
yield函數是一個生成器,凡是使用yield函數的函數(方法)都是生成器函數(方法)。
我們來看一個例子
def test():
for i in range(5):
yield i
print(type(test()))
print(next(test()))
輸出:
test()方法中因爲使用了yield函數,它的類型也變成了生成器類型。此時它也就變成了一個迭代器,我們可以使用next()依次訪問從0到5.我們可以把這個yield函數當成存錢罐,不斷往裏面存儲後面的變量。但是取得時候,只能一個一個按順序從頭開始取。
其上述代碼效果等價於:
listdata=[0,1,2,3,4]#集合
listiter = iter(listdata)
print(next(listiter))
yield函數在paddlepaddle中是如何被使用的
好了,到大頭了。
接下來呢,我們就探究paddlepaddle源碼,看看yield函數在這個框架中是怎麼使用的。
我選取的例子是官方給出的預測波士頓房價。
#每次緩存500個數據
BUF_SIZE=500
#每20個組成一個訓練批次
BATCH_SIZE=20
#一次緩存500個數據,其中每次讀取20個
#讀取數據到緩存中,並且打亂數據。
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.uci_housing.train(),buf_size=BUF_SIZE),
batch_size=BATCH_SIZE
)
上述代碼是將paddle讀取波士頓房價信息的訓練集數據讀取到緩存中。
咱們着重看這一句:
paddle.dataset.uci_housing.train()
我們再進入到train()這個方法中(pycharm中按住control+左鍵)
def train():
"""
UCI_HOUSING training set creator.
It returns a reader creator, each sample in the reader is features after
normalization and price number.
:return: Training reader creator
:rtype: callable
"""
global UCI_TRAIN_DATA
load_data(paddle.dataset.common.download(URL, 'uci_housing', MD5))
def reader():
for d in UCI_TRAIN_DATA:
yield d[:-1], d[-1:]
return reader
上述就是train()的源代碼。
咱們可以看到reader()方法中使用了yield函數,所以reader()這個是一個生成器。train()方法將它作爲返回值,實際上就是返回了一個迭代器。
train_data = next(train_reader())
這個用來獲取train_reader()的下一個值,原理和我們之前講的一樣。
官方也給出了一個例子,咱們也拿出來看一下:
def reader():
for i in range(10):
yield i
batch_reader = fluid.io.batch(reader, batch_size=2)
for data in batch_reader():
print(data)
這個例子就是咱們上面例子的簡化版,batch_reader()在這裏作爲一個迭代器使用。
使用迭代器最大的好處就是節約內存。特別是在我們訓練模型的時候,需要使用大量的數據集。
如果大家看完還有不懂的地方,請留言一起討論。如果有幫助到你,請點個👍