使用 TensorFlow seq2seq(tf.contrib.seq2seq)建立 decoder 時報錯:
```python
# Build the decoder RNN cell (hidden size = self.args.rnn_dim).
# NOTE(review): LSTMCell unpacks its state as a (c, h) pair -- the traceback
# after this snippet fails at `(c_prev, m_prev) = state` -- so an LSTM
# initial state must be a (c, h) tuple, not a single tensor.
decoder_cell = tf.nn.rnn_cell.LSTMCell(num_units=self.args.rnn_dim)
# Helper: look up embeddings for the teacher-forcing decoder inputs.
decoder_inputs_embedded = tf.nn.embedding_lookup(self.embeddings,
self.gen_outputs) # gen_outputs: decoder input ids, each sequence starting with SOS
# Dense output layer mapping each decoder output to self.vocab_size scores.
projection_layer = tf.layers.Dense(self.vocab_size, kernel_initializer=self.initializer)
# time_major=False: assumes decoder_inputs_embedded is batch-major (batch, time, emb) -- TODO confirm.
helper = tf.contrib.seq2seq.TrainingHelper(
decoder_inputs_embedded, self.gen_len, time_major=False) # TODO(review): unsure whether self.gen_len needs +1 (e.g. to count the SOS step)
# Decoder. `state` is defined outside this snippet; per the traceback it is a
# single Tensor, which is what triggers the TypeError inside LSTMCell.
decoder = tf.contrib.seq2seq.BasicDecoder(
decoder_cell, helper, initial_state=state,
output_layer=projection_layer)
# Dynamic decoding: keep only the decoder outputs; the other two return
# values are discarded. max_review_length is defined outside this snippet.
decoder_outputs, _,_ = tf.contrib.seq2seq.dynamic_decode(decoder,
impute_finished=True,
maximum_iterations=max_review_length)
```

報錯顯示:

```text
Traceback (most recent call last):
  File "/data1/home/qlj/pycharm27/CAML/train_CAML.py", line 771, in <module>
    exp = CFExperiment(inject_params=None)
  File "/data1/home/qlj/pycharm27/CAML/train_CAML.py", line 105, in __init__
    num_user=self.num_users)
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 89, in __init__
    self.build_graph()
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 1156, in build_graph
    self.gen_loss, self.gen_acc, self.key_word_loss = self._gen_review(q1_output, q2_output, r_input)
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 720, in _gen_review
    maximum_iterations=max_review_length)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2640, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2590, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 138, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 575, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 591, in call
    (c_prev, m_prev) = state
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
```
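原因:
傳給 BasicDecoder 的 initial_state(即 state)是單一個 Tensor,shape 爲 (batch_size, dim)。
而 decoder_cell 設置的是 LSTMCell,它的 state(預設 state_is_tuple=True)是一個 LSTMStateTuple(c, h),包含 LSTM 的 cell state c 和 hidden state h,兩者的 shape 各爲 (batch_size, dim)。這就是"2"的由來:需要的是由兩個 Tensor 組成的 tuple,而不是一個 (2, batch_size, dim) 的 Tensor。LSTMCell 在 call 裏會執行 `(c_prev, m_prev) = state` 把 state 拆成 c 和 h,而這裏的 state 是單一個 Tensor,無法被迭代拆包,所以拋出 TypeError: 'Tensor' object is not iterable。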
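解決辦法:
方法一:將 decoder_cell 設置爲 GRUCell(tf.nn.rnn_cell.GRUCell)。GRU 的 state 就是單一個 (batch_size, dim) 的 Tensor,可以直接把 state 傳入 initial_state(GRUCell 的 num_units 需與 dim 相同)。
方法二:保留 LSTMCell,把 state 包成 LSTMStateTuple 再傳入,例如 `initial_state=tf.nn.rnn_cell.LSTMStateTuple(c=state, h=state)`(同樣要求 num_units 與 dim 相同)。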