关于Keras中函数式编程的建模方式其代码编写形式的理解

学习笔记:

在刚开始接触Keras中神经网络模型的函数式编程建模方式时有一点点不太理解,例如下面代码段中的第二句代码。

input = Input(shape=(config.lag, 5))
lstm = LSTM(config.units_num, input_shape=(config.lag, 5), return_sequences=True)(input)

在初步了解了Keras中模型的函数式编程建模方式后,我们可以知道,这两句代码的含义是首先创建一个模型的输入层(input)对应于代码段中的第一句代码,然后将这个输入层传入给LSTM层(lstm),作为LSTM层的输入层,对应于代码段中的第二句代码。

我的疑问是LSTM(config.units_num, input_shape=(config.lag, 5), return_sequences=True),这句代码应该会返回一个LSTM类的对象,那么直接使用对象+()即后面跟着的(input)的方式就可以完成将input设置为LSTM层的输入层这样一个任务吗,其实现的原理是什么呢。

在查阅了一些资料后发现,对于python语言而言,使用类对象+()的方式会自动调用类中的__call__()方法。这样的操作相当于把类对象当作函数在调用它,故当程序运行到LSTM(...)(input)这行代码时,实际上就完成了两个操作:第一:创建了一个LSTM类对象,第二:将input对象作为形参传入到LSTM类对象的__call__(self,inputs,...[该方法有很多参数,就不一一写出了])方法中,并调用该方法,完成LSTM层的创建。

至此Keras中函数式编程建模代码之所以可以这样写,我们已经了解清楚了。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章