From 938a6c9370c6941c2d10008a17435f5466ca53cb Mon Sep 17 00:00:00 2001 From: Adwaith Anand Date: Wed, 28 Jun 2023 15:49:43 +0530 Subject: [PATCH] [ Tensor ] Support NHWC for dot, add/multiply_strided and other ops This PR includes changes of Tensor and TensorDim to support NHWC computation for dot, add_strided, multiply_strided, cat, split, and transpose. It also includes unittests to evaluate. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Adwaith Anand Signed-off-by: Manohara HK Signed-off-by: jijoong.moon --- nntrainer/tensor/tensor.cpp | 2 +- nntrainer/tensor/tensor.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index 8cf7e8c..52fdc8c 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -98,7 +98,7 @@ struct Tensor::BroadcastInfo { Tensor::Tensor(const TensorDim &d, bool alloc_now, Tensor::Initializer init, std::string name_) : - Tensor(name_) { + Tensor(name_, d.getFormat()) { if (d.getDataLen() != 0) { dim = d; strides = d.computeStrides(); diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h index 72cb53b..23b5dcc 100644 --- a/nntrainer/tensor/tensor.h +++ b/nntrainer/tensor/tensor.h @@ -324,7 +324,7 @@ public: /** * @brief Constructor of Tensor * @note This constructor copies vector again. needs refactoring - * @param[in] d data for the Tensor + * @param[in] d data for the Tensor. It needs to set format properly. */ Tensor(std::vector>> const &d, ml::train::TensorDim::TensorType t_type) : -- 2.7.4