windows下tensorflow的CNN框架

環境:windows+tensorflow-GPU-1.8+python3.6
代碼地址:https://download.csdn.net/download/lsjweiyi/10466889
我在cifar10代碼的基礎上將代碼修改,排版,使功能模塊更清晰,也更容易更換網絡模型,代碼里加上大量中文註釋。有點基礎的應該都能看懂了,主要目的是爲了方便像我這樣的新入門tensorflow的人寫CNN網絡,使用的時候要注意修改數據的路徑,模型保存路徑,圖片的SIZE,類別NUM_CLASSES等一些參數,這些參數我都放在了每個PY文件的最前面,方便找到並進行修改。

我在其中加入了保存混淆矩陣和圖片分類錯誤路徑的功能,可以有更多途徑分析網絡訓練的效果。

此外,內置的網絡模型是lenet-5,這個只是爲了看起來簡單,我規範化代碼的目的就是方便調用別的網絡,比如tensorflow裏自帶的其他網絡,接下來給出調用的方式:

比如我要調用Alexnet等網絡:

#import的方式
from tensorflow.contrib.slim.nets import resnet_v1 #import resnet_v1
from tensorflow.contrib.slim.nets import alexnet # import alexnet
#這裏還有VGG,inception等網絡,都可以直接引用

接下來是調用這些網絡:

#首先找到train.py或者test.py裏面的如下代碼
logits=model.model(images)#調用模型,返回預測的概率矩陣

#這句話,替換爲如下語句,這裏以alexnet爲例,當然上面的import不能少
logits,_=alexnet.alexnet_v2(images,num_classes=model.NUM_CLASSES,is_training=True,dropout_keep_prob=0.5)
#其中dropout_keep_prob=0.5是dropout的概率,具體解釋,請自行看源碼上的註釋,train和test有一點點區別
#這裏有一個不好的地方就是,調用這些模型,他是限制了圖片的輸入大小的,例如alexnet限制的是224*224,不同的網絡不同,可以在源碼中找到要求。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章