It seems best to convert the data to float32 first (except for complex-typed data); otherwise an error may be raised.
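A minimal NumPy sketch of the "cast to float32 unless complex" rule above. The probable reason for the rule (an assumption, not confirmed in this post) is that Keras layers default to float32 (`floatx`), so feeding int or float64 tensors can trigger dtype-mismatch errors. The helper name `to_float32_unless_complex` is hypothetical.

```python
import numpy as np

def to_float32_unless_complex(arr):
    """Hypothetical helper: cast an array to float32, but leave complex arrays untouched."""
    arr = np.asarray(arr)
    if np.iscomplexobj(arr):
        # Casting complex to float32 would silently drop the imaginary part.
        return arr
    return arr.astype(np.float32)

xx = np.arange(800).reshape(8, 10, 10)              # integer dtype by default
print(to_float32_unless_complex(xx).dtype)          # float32
print(to_float32_unless_complex(xx + 1j).dtype)     # complex128
```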
>>> xx.shape
(8, 10, 10)
>>> xx2=tf.constant(xx,tf.float32)
>>> inputs=keras.Input(shape=xx.shape[1:],tensor=xx2)
>>> with tf.Session() as sess:
...     print(sess.run(inputs))
The above just inspects the most basic `inputs`. However, when I print the output of BatchNormalization (BN) directly, I get an error. What's up?
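One likely cause of the error that follows (an assumption, since the traceback is truncated): unlike the bare `Input` tensor, a `BatchNormalization` layer owns variables (`gamma`, `beta`, `moving_mean`, `moving_variance`), and in TF 1.x graph mode they must be initialized, e.g. with `sess.run(tf.global_variables_initializer())`, before `sess.run(xx3)`. The NumPy sketch below only illustrates what the layer computes in inference mode and which stored values it reads; `batchnorm_inference` is a hypothetical name, and the initial values assume Keras defaults (ones/zeros, `epsilon=1e-3`).

```python
import numpy as np

def batchnorm_inference(x, gamma, beta, moving_mean, moving_variance, epsilon=1e-3):
    """Inference-mode batch normalization, broadcast over the last axis (Keras default axis=-1)."""
    return gamma * (x - moving_mean) / np.sqrt(moving_variance + epsilon) + beta

x = np.random.default_rng(0).random((8, 10, 10), dtype=np.float32)
channels = x.shape[-1]

# Assumed Keras default initial values for the layer's four variables.
gamma = np.ones(channels, np.float32)
beta = np.zeros(channels, np.float32)
moving_mean = np.zeros(channels, np.float32)
moving_variance = np.ones(channels, np.float32)

y = batchnorm_inference(x, gamma, beta, moving_mean, moving_variance)
print(y.shape)                                  # (8, 10, 10)
print(np.allclose(y, x / np.sqrt(1 + 1e-3)))    # True: untrained BN barely rescales x
```

Without those four arrays the formula cannot be evaluated at all, which is why a BN output (but not the plain input) needs its variables initialized first.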
>>> xx3=keras.layers.BatchNormalization(input_shape=xx.shape[1:])(inputs)
>>> with tf.Session() as sess:
...     print(sess.run(xx3))
Traceback (most recent call last):
  File "D:\python\lib\site-packages\tensorflow_core\python\client\session.py", line 1365, in _do_call
    return fn(*args)
  File "D:\python\lib\site-packages\tensorflow_core\python\client\session.py", line 1350, in _run_fn
    target_list, run_metadata)
  File "D:\python\lib\site-packages\tensorflow_core\python\client\session.py", line 1443, in _call_tf_sessionrun
    run_metadata)
tensorflo