[ Tensor ] Support NHWC for dot, add/multiply_strided and other ops
authorAdwaith Anand <adwaith.a@samsung.com>
Wed, 28 Jun 2023 10:19:43 +0000 (15:49 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 21 Aug 2023 06:29:23 +0000 (15:29 +0900)
This PR includes changes of Tensor and TensorDim to support NHWC
computation for dot, add_strided, multiply_strided, cat, split,
and transpose. It also includes unittests to evaluate.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Adwaith Anand <adwaith.a@samsung.com>
Signed-off-by: Manohara HK <manohara.hk@samsung.com>
Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h

index 8cf7e8c..52fdc8c 100644 (file)
@@ -98,7 +98,7 @@ struct Tensor::BroadcastInfo {
 
 Tensor::Tensor(const TensorDim &d, bool alloc_now, Tensor::Initializer init,
                std::string name_) :
-  Tensor(name_) {
+  Tensor(name_, d.getFormat()) {
   if (d.getDataLen() != 0) {
     dim = d;
     strides = d.computeStrides();
index 72cb53b..23b5dcc 100644 (file)
@@ -324,7 +324,7 @@ public:
   /**
    * @brief     Constructor of Tensor
    * @note      This constructor copies vector again. needs refactoring
-   * @param[in] d data for the Tensor
+   * @param[in] d data for the Tensor. It needs to set format properly.
    */
   Tensor(std::vector<std::vector<std::vector<_FP16>>> const &d,
          ml::train::TensorDim::TensorType t_type) :