Keras實現一個簡單的線性迴歸

這裏設置x,y的關係是y=2x+3

因爲輸入的x是1維的數字,輸出的y也是1維的,所以dense層的輸出維度爲1,總共的參數爲(1+1)*1=2個,分別是w和b

# import packages
import numpy as np
from keras.models import *
from keras.layers import *
import matplotlib.pyplot as plt

# generate data
x_train = np.linspace(-1, 1, 200)
y_train = 2 * x_train + np.random.normal(3, 0.5, (x_train.shape[-1]))
x_test = np.linspace(-1, 1, 100)
y_test = 2 * x_test + np.random.normal(3, 0.5, (x_test.shape[-1]))

# initialize model
lr_model = Sequential()
lr_model.add(Dense(1, activation='linear', input_dim=1, name='dense_1'))
lr_model.compile(loss='mse', optimizer='sgd')

# train
lr_model.fit(x_train, y_train, batch_size=32, epochs=500, verbose=1)

# evaluate
x = lr_model.evaluate(x_test, y_test, batch_size=32)
print(lr_model.layers[0].get_weights())
print(lr_model.layers[0].get_config())
lr_model.summary()

# visualization
y_pred = lr_model.predict(x_test)
plt.scatter(x_test, y_test)
plt.plot(x_test, y_pred)
plt.show()

通過model.summary()可以查看網絡的層,可以看出,總共有2個可訓練的參數。

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 1)                 2         
=================================================================
Total params: 2
Trainable params: 2
Non-trainable params: 0

 通過get_config()和get_weights()方法可以查看layer的配置和參數,可以看出計算的kernel爲2,bias爲3

 

[array([[2.0756874]], dtype=float32), array([3.0122392], dtype=float32)]
{'name': 'dense_1', 'trainable': True, 'batch_input_shape': (None, 1), 'dtype': 'float32', 'units': 1, 'activation': 'linear', 'use_bias': True, 'kernel_initializer': {'class_name': 'VarianceScaling', 'config': {'scale': 1.0, 'mode': 'fan_avg', 'distribution': 'uniform', 'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章