From: Adwaith Anand Date: Wed, 28 Jun 2023 10:19:43 +0000 (+0530) Subject: [ Tensor ] Support NHWC for dot, add/multiply_strided and other ops X-Git-Tag: accepted/tizen/8.0/unified/20231005.093407~78 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=938a6c9370c6941c2d10008a17435f5466ca53cb;p=platform%2Fcore%2Fml%2Fnntrainer.git [ Tensor ] Support NHWC for dot, add/multiply_strided and other ops This PR includes changes of Tensor and TensorDim to support NHWC computation for dot, add_strided, multiply_strided, cat, split, and transpose. It also includes unittests to evaluate. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Adwaith Anand Signed-off-by: Manohara HK Signed-off-by: jijoong.moon --- diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index 8cf7e8c..52fdc8c 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -98,7 +98,7 @@ struct Tensor::BroadcastInfo { Tensor::Tensor(const TensorDim &d, bool alloc_now, Tensor::Initializer init, std::string name_) : - Tensor(name_) { + Tensor(name_, d.getFormat()) { if (d.getDataLen() != 0) { dim = d; strides = d.computeStrides(); diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h index 72cb53b..23b5dcc 100644 --- a/nntrainer/tensor/tensor.h +++ b/nntrainer/tensor/tensor.h @@ -324,7 +324,7 @@ public: /** * @brief Constructor of Tensor * @note This constructor copies vector again. needs refactoring - * @param[in] d data for the Tensor + * @param[in] d data for the Tensor. It needs to set format properly. */ Tensor(std::vector>> const &d, ml::train::TensorDim::TensorType t_type) :