tensor2tensor error: 'Tensor' object is not iterable.

Building a decoder with TensorFlow's `tf.contrib.seq2seq` module raised an error. The decoder code:
```python
# Build RNN cell
decoder_cell = tf.nn.rnn_cell.LSTMCell(num_units=self.args.rnn_dim)
# Helper
decoder_inputs_embedded = tf.nn.embedding_lookup(self.embeddings,
                                                 self.gen_outputs)  # gen_outputs: decoder inputs, starting with SOS
projection_layer = tf.layers.Dense(self.vocab_size, kernel_initializer=self.initializer)
helper = tf.contrib.seq2seq.TrainingHelper(
    decoder_inputs_embedded, self.gen_len, time_major=False)  # not sure yet whether self.gen_len needs +1
# Decoder
# Note: `state` is a single tensor of shape (batch_size, dim), which LSTMCell cannot unpack
decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper, initial_state=state,
    output_layer=projection_layer)
# Dynamic decoding
decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
                                                          impute_finished=True,
                                                          maximum_iterations=max_review_length)
```

The error message:

```
Traceback (most recent call last):
  File "/data1/home/qlj/pycharm27/CAML/train_CAML.py", line 771, in <module>
    exp = CFExperiment(inject_params=None)
  File "/data1/home/qlj/pycharm27/CAML/train_CAML.py", line 105, in __init__
    num_user=self.num_users)
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 89, in __init__
    self.build_graph()
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 1156, in build_graph
    self.gen_loss, self.gen_acc, self.key_word_loss = self._gen_review(q1_output, q2_output, r_input)
  File "/data1/home/qlj/pycharm27/CAML/tf_models/model_caml.py", line 720, in _gen_review
    maximum_iterations=max_review_length)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2640, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2590, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 234, in body
    decoder_finished) = decoder.step(time, inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py", line 138, in step
    cell_outputs, cell_state = self._cell(inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/layers/base.py", line 575, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 591, in call
    (c_prev, m_prev) = state
  File "/data1/home/qlj/.conda/envs/py27/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
```

Cause:

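The `state` passed as `initial_state` is a single tensor of shape (batch_size, dim).

However, `decoder_cell` is an `LSTMCell`. With the default `state_is_tuple=True`, its state holds two parts, the cell state C and the hidden state H, as an `LSTMStateTuple` (c, h) of two (batch_size, dim) tensors, so conceptually the initial state should have size (2, batch_size, dim). That is where the 2 comes from. Inside `LSTMCell.call`, the line `(c_prev, m_prev) = state` tries to unpack this pair; a single graph-mode `Tensor` cannot be iterated in Python, so the unpacking raises `TypeError: 'Tensor' object is not iterable.`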
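The failure can be reproduced without TensorFlow. The sketch below is a hypothetical, pure-Python stand-in: `GraphTensor`, `lstm_step_state` and `gru_step_state` are illustrative names, not TensorFlow APIs, and the local `LSTMStateTuple` only mimics the (c, h) layout of `tf.nn.rnn_cell.LSTMStateTuple`. It shows that unpacking a single tensor fails, while a GRU-style cell accepts that tensor as-is and an LSTM-style cell works once it receives a (c, h) pair.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple: an LSTM state is a (c, h) pair.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


class GraphTensor(object):
    """Hypothetical stand-in for a TF 1.x graph-mode Tensor.

    Like the real class in the traceback, it refuses Python iteration,
    so tuple-unpacking it raises TypeError.
    """

    def __init__(self, shape):
        self.shape = shape

    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable.")


def lstm_step_state(state):
    """Mimics the failing line in LSTMCell.call: (c_prev, m_prev) = state."""
    (c_prev, m_prev) = state
    return c_prev, m_prev


def gru_step_state(state):
    """A GRU-style cell treats its whole state as one tensor; nothing is unpacked."""
    return state


batch_size, dim = 32, 128
encoder_state = GraphTensor((batch_size, dim))  # single tensor, shape (batch_size, dim)

# 1) Passing a single tensor to the LSTM path reproduces the error.
try:
    lstm_step_state(encoder_state)
except TypeError as err:
    print(err)  # 'Tensor' object is not iterable.

# 2) A GRU-style cell accepts the single tensor directly.
print(gru_step_state(encoder_state).shape)  # (32, 128)

# 3) An LSTM-style cell works once the state is a (c, h) pair of two tensors.
c, h = lstm_step_state(LSTMStateTuple(c=GraphTensor((batch_size, dim)), h=encoder_state))
```

In real TF 1.x code, keeping `LSTMCell` would instead mean passing `initial_state=tf.nn.rnn_cell.LSTMStateTuple(c=..., h=...)`; which tensors to use for `c` and `h` is a modeling choice this sketch does not settle.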

Solution:

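Set `decoder_cell` to a `GRUCell` (e.g. `tf.nn.rnn_cell.GRUCell(num_units=self.args.rnn_dim)`). A GRU's state is a single (batch_size, num_units) tensor, so the existing `state` can be passed as `initial_state` directly, provided its `dim` matches `num_units`.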
