tensorflow transpose 的c 實現

1.Tensorflow 的 transpose 調用的是 eigen 庫的 tensor類的 tensor.shuffle 

2.

3.Eigen 庫中tensor.shuffle 中最核心的代碼是下面這樣的。沒有進行內存搬運,而是修改了幾個tensor 的成員變量

最重要的是修改了m_inputStrides 這個成員變量。

 

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device)
  {
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    const Shuffle& shuffle = op.shufflePermutation();
    for (int i = 0; i < NumDims; ++i) {
      m_dimensions[i] = input_dims[shuffle[i]];
    }

    array<Index, NumDims> inputStrides;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        inputStrides[i] = inputStrides[i - 1] * input_dims[i - 1];
        m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
      }
    } else {
      inputStrides[NumDims - 1] = 1;
      m_outputStrides[NumDims - 1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        inputStrides[i] = inputStrides[i + 1] * input_dims[i + 1];
        m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
      }
    }

    for (int i = 0; i < NumDims; ++i) {
      m_inputStrides[i] = inputStrides[shuffle[i]];
    }
  }

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章