Notes on outputting a model graph in Keras - how to plot the LSTM model graph and save it to a file

In Keras, you can draw a model with the model visualization utilities. The following description comes from the Keras documentation.

Model visualization

Keras provides utility functions to plot a Keras model (using graphviz).

This will plot a graph of the model and save it to a file:

from keras.utils import plot_model
plot_model(model, to_file='model.png')

plot_model takes four optional arguments:

  • show_shapes (defaults to False) controls whether output shapes are shown in the graph.
  • show_layer_names (defaults to True) controls whether layer names are shown in the graph.
  • expand_nested (defaults to False) controls whether to expand nested models into clusters in the graph.
  • dpi (defaults to 96) controls image dpi.

You can also directly obtain the pydot.Graph object and render it yourself, for example to show it in an IPython notebook:

from IPython.display import SVG
from keras.utils import model_to_dot

SVG(model_to_dot(model).create(prog='dot', format='svg'))

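The biggest hassle in this process is installing the two packages, pydot and graphviz.

Below is a summary of the pitfalls I ran into, which I hope will save you some trouble.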

1. Installing pydot

First, open Anaconda (the Anaconda Prompt) in administrator mode.

Then install pydot from that prompt.
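After installing, it is worth confirming that the Python interpreter you run Keras with can actually see pydot, since Anaconda makes it easy to install into one environment and run from another. A minimal sketch using the standard library's `importlib.util.find_spec`; the helper name `module_available` is my own and is not part of Keras or pydot:

```python
import importlib.util


def module_available(name):
    """Return True if the top-level module `name` can be found by this interpreter."""
    # find_spec only locates the module; it does not import or execute it
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    print("pydot importable:", module_available("pydot"))
```

If this prints False, pydot most likely went into a different environment than the one running your script.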

2. Installing graphviz

Next, install the graphviz package; this is also easy to do.
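pydot does not draw anything itself: it hands the drawing to Graphviz's `dot` program, so having pydot importable is not enough, and `dot` must also be findable on PATH by the Python process. A minimal check with the standard library's `shutil.which` (the helper name `find_graphviz_dot` is hypothetical); the optional `path` argument lets you test a candidate Graphviz folder before touching PATH:

```python
import shutil


def find_graphviz_dot(path=None):
    """Return the full path of Graphviz's `dot` executable, or None if it is not found.

    If `path` is None, the directories in the PATH environment variable are searched;
    otherwise `path` is used as the search list (entries separated by os.pathsep).
    """
    # On Windows, shutil.which also tries the extensions in PATHEXT (e.g. dot.exe)
    return shutil.which("dot", path=path)


if __name__ == "__main__":
    location = find_graphviz_dot()
    if location is None:
        print("Graphviz 'dot' not found on PATH")
    else:
        print("Graphviz 'dot' found at", location)
```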

 

When you run the code from the documentation, you may get the following error:

ImportError: Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.

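In principle this error should not appear, but I searched for a long time without finding a fix. In the end a post on Stack Overflow helped a lot: the environment variable was simply not set correctly, i.e. the Graphviz executables were not on PATH.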

3. Solution

The fix is to add the following two lines to your program, before plot_model is called:

import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz-2.38/release/bin/'  # use the folder where you extracted Graphviz after downloading it; adjust it to your own setup
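Appending with `+=` adds the folder again every time that cell is re-run in a notebook session, and silently adds a folder that may not exist. A slightly more defensive sketch, assuming the same kind of Graphviz folder as above (the helper name `add_to_path` is hypothetical), only appends the directory if it exists and is not already listed. Like the original two lines, it changes PATH only for the current Python process and the programs it launches, such as the `dot` subprocess pydot starts:

```python
import os


def add_to_path(directory):
    """Append `directory` to PATH for this process if it exists and is not already listed.

    Returns True if PATH contains the directory afterwards, False if the directory does not exist.
    """
    if not os.path.isdir(directory):
        return False
    wanted = os.path.normcase(os.path.normpath(directory))
    current = os.environ.get("PATH", "")
    # Compare normalized entries so "C:/a/b/" and "C:\\a\\b" count as the same folder
    already = any(
        os.path.normcase(os.path.normpath(entry)) == wanted
        for entry in current.split(os.pathsep)
        if entry
    )
    if not already:
        os.environ["PATH"] = current + os.pathsep + directory if current else directory
    return True


if __name__ == "__main__":
    # Adjust to wherever you extracted Graphviz
    print(add_to_path("C:/Program Files (x86)/Graphviz-2.38/release/bin"))
```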

4. Results

Here is the complete script:

from keras.utils import plot_model
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz-2.38/release/bin/'
from math import sin
from math import pi
from math import exp
from random import random
from random import randint
from random import uniform
from numpy import array
from matplotlib import pyplot
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
# generate a damped sine wave with the given length, period and decay rate
def generate_sequence(length, period, decay):
    return [0.5 + 0.5 * sin(2 * pi * i / period) * exp(-decay * i) for i in range(length)]

# generate input and output pairs of damped sine waves
def generate_examples(length, n_patterns, output):
    X, y = list(), list()
    for _ in range(n_patterns):
        p = randint(10, 20)
        d = uniform(0.01, 0.1)
        sequence = generate_sequence(length + output, p, d)
        X.append(sequence[:-output])
        y.append(sequence[-output:])
    X = array(X).reshape(n_patterns, length, 1)
    y = array(y).reshape(n_patterns, output)
    return X, y

# configure problem
length = 50
output = 5

# define model
model = Sequential()
model.add(LSTM(20, return_sequences=True, input_shape=(length, 1)))
model.add(LSTM(20))
model.add(Dense(output))
model.compile(loss='mae', optimizer='adam')
model.summary()  # summary() prints the table itself and returns None, so no print() is needed
plot_model(model, to_file='model3.png', show_shapes=True)
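As a quick sanity check that `plot_model` really wrote an image, you can confirm that `model3.png` exists and begins with the standard 8-byte PNG signature. A small stdlib-only sketch; `has_png_signature` and `looks_like_png` are hypothetical helpers, not part of Keras:

```python
import os

# Every PNG file starts with these 8 bytes
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def has_png_signature(data):
    """Return True if the byte string `data` starts with the PNG signature."""
    return data.startswith(PNG_SIGNATURE)


def looks_like_png(path):
    """Return True if `path` is an existing file whose first bytes are the PNG signature."""
    if not os.path.isfile(path):
        return False
    with open(path, "rb") as f:
        return has_png_signature(f.read(len(PNG_SIGNATURE)))


if __name__ == "__main__":
    print("model3.png looks valid:", looks_like_png("model3.png"))
```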

Finally, running it produces the model diagram, saved as model3.png.

There you go!

To render the graph inline in an IPython notebook instead, you can use:
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
SVG(model_to_dot(model).create(prog='dot', format='svg'))
