[Tensor] Add optional output tensor for tensor concatenation
authorDonghyeon Jeong <dhyeon.jeong@samsung.com>
Mon, 12 Aug 2024 02:13:29 +0000 (11:13 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 16 Aug 2024 01:22:43 +0000 (10:22 +0900)
This PR adds an optional feature in Tensor::cat to pass the output tensor to the function.
This change allows the user-given tensor to store the result of the concatenation without creating a new tensor.

**Changes proposed in this PR:**
- Add optional argument output (the output tensor) to the cat function.
- Add negative test cases for tensor concatenation.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
nntrainer/tensor/float_tensor.cpp
nntrainer/tensor/float_tensor.h
nntrainer/tensor/half_tensor.cpp
nntrainer/tensor/half_tensor.h
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h
nntrainer/tensor/tensor_base.cpp
nntrainer/tensor/tensor_base.h
test/unittest/unittest_nntrainer_tensor.cpp

index c925995fc9678be1ecb6279dccd16dd982cc931a..b35894f9ec6c960502cc3ad3bc875f5c19eea09e 100644 (file)
@@ -1027,28 +1027,10 @@ std::vector<Tensor> FloatTensor::split(std::vector<size_t> sizes, int axis) {
   return ret;
 }
 
-Tensor FloatTensor::concat(const std::vector<Tensor> &tensors, int axis) {
-  if (axis == -1) {
-    axis = 3;
-  }
+Tensor FloatTensor::concat(const std::vector<Tensor> &tensors, int axis,
+                           Tensor &output) {
+  bool is_format_nchw = (tensors.front().getDim().getFormat() == Tformat::NCHW);
 
-  auto ref_dim = tensors.front().getDim();
-  bool is_format_nchw = (ref_dim.getFormat() == Tformat::NCHW);
-  ref_dim.setTensorDim(axis, 1);
-  NNTR_THROW_IF(!std::all_of(tensors.begin(), tensors.end(),
-                             [&ref_dim, axis](const Tensor &t) {
-                               auto cur_dim = t.getDim();
-                               cur_dim.setTensorDim(axis, 1);
-                               return ref_dim == cur_dim;
-                             }),
-                std::invalid_argument)
-    << " all tensor must have the same dimension except for the axis, ref_dim: "
-    << ref_dim << " axis : " << axis;
-
-  auto axis_dim = std::accumulate(tensors.begin(), tensors.end(), 0u,
-                                  [axis](unsigned cur, const Tensor &t) {
-                                    return cur += t.getDim().getTensorDim(axis);
-                                  });
   auto iter_value =
     [is_format_nchw](std::array<unsigned, 4> &loc,
                      const std::array<unsigned, 4> &start_loc, Tensor &t,
@@ -1068,45 +1050,38 @@ Tensor FloatTensor::concat(const std::vector<Tensor> &tensors, int axis) {
     return value;
   };
 
-  auto ret_dim = ref_dim;
-  ret_dim.setTensorDim(axis, axis_dim);
-
-  Tensor ret = Tensor(ret_dim);
-
   std::array<unsigned, 4> loc = {0, 0, 0, 0};
   for (auto &t : tensors) {
     std::array<unsigned, 4> start_loc = loc;
     std::array<unsigned, 4> tensor_dim_arr;
-    if (is_format_nchw) {
-      tensor_dim_arr[0] = t.getDim().getTensorDim(0);
-      tensor_dim_arr[1] = t.getDim().getTensorDim(1);
-      tensor_dim_arr[2] = t.getDim().getTensorDim(2);
-      tensor_dim_arr[3] = t.getDim().getTensorDim(3);
-    } else {
-      tensor_dim_arr[0] = t.getDim().getTensorDim(0);
-      tensor_dim_arr[1] = t.getDim().getTensorDim(2);
-      tensor_dim_arr[2] = t.getDim().getTensorDim(3);
-      tensor_dim_arr[3] = t.getDim().getTensorDim(1);
-    }
+    TensorDim curr_dim = t.getDim();
+
+    tensor_dim_arr[0] = curr_dim.getTensorDim(0);
+    tensor_dim_arr[1] =
+      is_format_nchw ? curr_dim.getTensorDim(1) : curr_dim.getTensorDim(2);
+    tensor_dim_arr[2] =
+      is_format_nchw ? curr_dim.getTensorDim(2) : curr_dim.getTensorDim(3);
+    tensor_dim_arr[3] =
+      is_format_nchw ? curr_dim.getTensorDim(3) : curr_dim.getTensorDim(1);
 
     for (size_t i = 0u, sz = t.size(); i < sz; ++i) {
-      iter_value(loc, start_loc, ret, tensor_dim_arr) = t.getValue<float>(i);
+      iter_value(loc, start_loc, output, tensor_dim_arr) = t.getValue<float>(i);
     }
 
     if (is_format_nchw) {
-      loc[axis] += t.getDim().getTensorDim(axis);
+      loc[axis] += curr_dim.getTensorDim(axis);
     } else {
       if (axis == 0) {
-        loc[0] += t.getDim().getTensorDim(axis);
+        loc[0] += curr_dim.getTensorDim(axis);
       } else if (axis == 1) {
-        loc[3] += t.getDim().getTensorDim(axis);
+        loc[3] += curr_dim.getTensorDim(axis);
       } else if (axis == 2 || axis == 3) {
-        loc[axis - 1] += t.getDim().getTensorDim(axis);
+        loc[axis - 1] += curr_dim.getTensorDim(axis);
       }
     }
   }
 
-  return ret;
+  return output;
 }
 
 void FloatTensor::print(std::ostream &out) const {
index 017433e7c9feed5d0757eb4dbc79aad8489deb1c..2818f18cee702bc272e54f717045dfcbca53810c 100644 (file)
@@ -407,9 +407,10 @@ public:
   std::vector<Tensor> split(std::vector<size_t> sizes, int axis) override;
 
   /**
-   * @copydoc Tensor::cat(const std::vector<Tensor> &tensors, int axis)
+   * @copydoc Tensor::concat()
    */
-  Tensor concat(const std::vector<Tensor> &tensors, int axis) override;
+  Tensor concat(const std::vector<Tensor> &tensors, int axis,
+                Tensor &output) override;
 
   /**
    * @copydoc Tensor::copy(const Tensor &from)
index aa43dda0482076d3f6da79fb39989212810ca02f..bea483df54f7c1147036249ad2419d865c1cdf08 100644 (file)
@@ -836,27 +836,10 @@ std::vector<Tensor> HalfTensor::split(std::vector<size_t> sizes, int axis) {
   return ret;
 }
 
-Tensor HalfTensor::concat(const std::vector<Tensor> &tensors, int axis) {
-  if (axis == -1) {
-    axis = 3;
-  }
-  auto ref_dim = tensors.front().getDim();
-  bool is_format_nchw = (ref_dim.getFormat() == Tformat::NCHW);
-  ref_dim.setTensorDim(axis, 1);
-  NNTR_THROW_IF(!std::all_of(tensors.begin(), tensors.end(),
-                             [&ref_dim, axis](const Tensor &t) {
-                               auto cur_dim = t.getDim();
-                               cur_dim.setTensorDim(axis, 1);
-                               return ref_dim == cur_dim;
-                             }),
-                std::invalid_argument)
-    << " all tensor must have the same dimension except for the axis, ref_dim: "
-    << ref_dim << " axis : " << axis;
+Tensor HalfTensor::concat(const std::vector<Tensor> &tensors, int axis,
+                          Tensor &output) {
+  bool is_format_nchw = (tensors.front().getDim().getFormat() == Tformat::NCHW);
 
-  auto axis_dim = std::accumulate(tensors.begin(), tensors.end(), 0u,
-                                  [axis](unsigned cur, const Tensor &t) {
-                                    return cur += t.getDim().getTensorDim(axis);
-                                  });
   auto iter_value =
     [is_format_nchw](std::array<unsigned, 4> &loc,
                      const std::array<unsigned, 4> &start_loc, Tensor &t,
@@ -876,40 +859,33 @@ Tensor HalfTensor::concat(const std::vector<Tensor> &tensors, int axis) {
     return value;
   };
 
-  auto ret_dim = ref_dim;
-  ret_dim.setTensorDim(axis, axis_dim);
-
-  Tensor output = Tensor(ret_dim);
-
   std::array<unsigned, 4> loc = {0, 0, 0, 0};
   for (auto &t : tensors) {
     std::array<unsigned, 4> start_loc = loc;
     std::array<unsigned, 4> tensor_dim_arr;
-    if (is_format_nchw) {
-      tensor_dim_arr[0] = t.getDim().getTensorDim(0);
-      tensor_dim_arr[1] = t.getDim().getTensorDim(1);
-      tensor_dim_arr[2] = t.getDim().getTensorDim(2);
-      tensor_dim_arr[3] = t.getDim().getTensorDim(3);
-    } else {
-      tensor_dim_arr[0] = t.getDim().getTensorDim(0);
-      tensor_dim_arr[1] = t.getDim().getTensorDim(2);
-      tensor_dim_arr[2] = t.getDim().getTensorDim(3);
-      tensor_dim_arr[3] = t.getDim().getTensorDim(1);
-    }
+    TensorDim curr_dim = t.getDim();
+
+    tensor_dim_arr[0] = curr_dim.getTensorDim(0);
+    tensor_dim_arr[1] =
+      is_format_nchw ? curr_dim.getTensorDim(1) : curr_dim.getTensorDim(2);
+    tensor_dim_arr[2] =
+      is_format_nchw ? curr_dim.getTensorDim(2) : curr_dim.getTensorDim(3);
+    tensor_dim_arr[3] =
+      is_format_nchw ? curr_dim.getTensorDim(3) : curr_dim.getTensorDim(1);
 
     for (size_t i = 0u, sz = t.size(); i < sz; ++i) {
       iter_value(loc, start_loc, output, tensor_dim_arr) = t.getValue<_FP16>(i);
     }
 
     if (is_format_nchw) {
-      loc[axis] += t.getDim().getTensorDim(axis);
+      loc[axis] += curr_dim.getTensorDim(axis);
     } else {
       if (axis == 0) {
-        loc[0] += t.getDim().getTensorDim(axis);
+        loc[0] += curr_dim.getTensorDim(axis);
       } else if (axis == 1) {
-        loc[3] += t.getDim().getTensorDim(axis);
+        loc[3] += curr_dim.getTensorDim(axis);
       } else if (axis == 2 || axis == 3) {
-        loc[axis - 1] += t.getDim().getTensorDim(axis);
+        loc[axis - 1] += curr_dim.getTensorDim(axis);
       }
     }
   }
index 8db09c0cce00db89e40e58a201d4945846a47617..7849d04d19fb9fdd3f50990fc7664f746eb24113 100644 (file)
@@ -397,9 +397,10 @@ public:
   std::vector<Tensor> split(std::vector<size_t> sizes, int axis) override;
 
   /**
-   * @copydoc Tensor::cat(const std::vector<Tensor> &tensors, int axis)
+   * @copydoc Tensor::concat()
    */
-  Tensor concat(const std::vector<Tensor> &tensors, int axis) override;
+  Tensor concat(const std::vector<Tensor> &tensors, int axis,
+                Tensor &output) override;
 
   /**
    * @copydoc Tensor::copy(const Tensor &from)
index ac45520108d671df9c01d57d645628dd2328588a..8ffe57929d310417d9d369eb791ea9c9c01af0c8 100644 (file)
@@ -17,7 +17,6 @@
 #ifdef ENABLE_FP16
 #include <half_tensor.h>
 #endif
-
 namespace nntrainer {
 
 Tensor::Tensor(std::string name_, Tformat fm, Tdatatype d_type) {
@@ -822,19 +821,37 @@ std::vector<Tensor> Tensor::split(std::vector<size_t> sizes, int axis) {
   return itensor->split(sizes, axis);
 }
 
-Tensor Tensor::concat(const std::vector<Tensor> &tensors, int axis) {
-  NNTR_THROW_IF(!(-1 <= axis && axis < 4), std::invalid_argument)
-    << "cannot split axis of axis: " << axis;
+Tensor Tensor::concat(const std::vector<Tensor> &tensors, int axis,
+                      Tensor &output) {
+  return itensor->concat(tensors, axis, output);
+}
+
+Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
+  if (axis == -1) {
+    axis = 3;
+  }
 
-  NNTR_THROW_IF(tensors.empty(), std::invalid_argument)
-    << "given tensor vector is empty";
+  // Create an output tensor to store the concatenation result
+  TensorDim out_dim = Tensor::calculateConcatOutputDim(tensors, axis);
+  Tensor output = Tensor(out_dim);
 
-  return itensor->concat(tensors, axis);
+  return output.concat(tensors, axis, output);
 }
 
-Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis) {
-  Tensor input = tensors[0];
-  return input.concat(tensors, axis);
+Tensor Tensor::cat(const std::vector<Tensor> &tensors, int axis,
+                   Tensor &output) {
+  if (axis == -1) {
+    axis = 3;
+  }
+
+  // Check if the given output tensor dimension is valid
+  TensorDim out_dim = Tensor::calculateConcatOutputDim(tensors, axis);
+
+  NNTR_THROW_IF(out_dim != output.getDim(), std::invalid_argument)
+    << "invalid output dim for concatenation " << output.getDim()
+    << "expected output dim " << out_dim;
+
+  return output.concat(tensors, axis, output);
 }
 
 void Tensor::print(std::ostream &out) const {
@@ -1088,6 +1105,46 @@ void Tensor::setTensorVar(TensorDim d, void *buf, size_t offset) {
   itensor->setTensorVar(d, buf, offset);
 }
 
+TensorDim Tensor::calculateConcatOutputDim(const std::vector<Tensor> &tensors,
+                                           int axis) {
+  // Check axis, in which the tensors are concatenated, is valid.
+  NNTR_THROW_IF(!(-1 <= axis && axis < 4), std::invalid_argument)
+    << "cannot concatenate tensors along an axis: " << axis;
+
+  // Check if the number of input tensors is valid.
+  NNTR_THROW_IF(tensors.size() <= 1, std::invalid_argument)
+    << "received an invalid tensor vector. size must be greater than 1.";
+
+  auto out_dim = tensors.front().getDim();
+
+  // Check if all tensor data types are the same.
+  for (auto &t : tensors) {
+    NNTR_THROW_IF(t.getDataType() != out_dim.getDataType(),
+                  std::invalid_argument)
+      << "cannot concatenate tensors with different data types.";
+  }
+
+  // Compute the dimensions of an output tensor.
+  out_dim.setTensorDim(axis, 1);
+  NNTR_THROW_IF(!std::all_of(tensors.begin(), tensors.end(),
+                             [&out_dim, axis](const Tensor &t) {
+                               auto cur_dim = t.getDim();
+                               cur_dim.setTensorDim(axis, 1);
+                               return out_dim == cur_dim;
+                             }),
+                std::invalid_argument)
+    << " all tensor must have the same dimension except for the axis, out_dim: "
+    << out_dim << " axis : " << axis;
+
+  auto axis_dim = std::accumulate(tensors.begin(), tensors.end(), 0u,
+                                  [axis](unsigned cur, const Tensor &t) {
+                                    return cur += t.getDim().getTensorDim(axis);
+                                  });
+
+  out_dim.setTensorDim(axis, axis_dim);
+  return out_dim;
+}
+
 std::ostream &operator<<(std::ostream &out, Tensor const &input) {
   input.print(out);
   return out;
index b564b4316802c1cfbd597502374c6a266dd788ef..472d694f4b61675875be2af58a1adddfd7a3aa03 100644 (file)
@@ -1202,9 +1202,12 @@ public:
    *
    * @param tensors tensors to be concatenated to the first tensor
    * @param axis axis
+   * @param output output tensor to store the result
    * @return Tensor concatenated tensor
+   *
+   * @note  This function should not be used directly. Please use cat() instead.
    */
-  Tensor concat(const std::vector<Tensor> &tensors, int axis = 0);
+  Tensor concat(const std::vector<Tensor> &tensors, int axis, Tensor &output);
 
   /**
    * @brief concatenate tensors along axis
@@ -1215,6 +1218,17 @@ public:
    */
   static Tensor cat(const std::vector<Tensor> &tensors, int axis = 0);
 
+  /**
+   * @brief concatenate tensors along axis
+   *
+   * @param tensors tensors to be concatenated to the first tensor
+   * @param axis axis
+   * @param output output tensor to store the result
+   * @return Tensor concatenated tensor
+   */
+  static Tensor cat(const std::vector<Tensor> &tensors, int axis,
+                    Tensor &output);
+
   /**
    * @brief     Print element
    * @param[in] out out stream
@@ -1546,6 +1560,16 @@ private:
    * @param[in] offset offset to be used
    */
   void setTensorVar(TensorDim d, void *buf, size_t offset);
+
+  /**
+   * @brief Calculate the output tensor dimension of the concatenating a list of
+   * tensors as an input.
+   *
+   * @param[in] tensors tensors to be concatenated to the first tensor
+   * @param[in] axis axis
+   */
+  static TensorDim calculateConcatOutputDim(const std::vector<Tensor> &tensors,
+                                            int axis);
 };
 
 /**
index d982a4147e308805432c9780e599ec9c4de81906..8471dcb6e2bc0a802e7acac7f4cf0f1d52b32edc 100644 (file)
@@ -542,7 +542,8 @@ std::vector<Tensor> TensorBase::split(std::vector<size_t> sizes, int axis) {
   return ret;
 }
 
-Tensor TensorBase::concat(const std::vector<Tensor> &tensors, int axis) {
+Tensor TensorBase::concat(const std::vector<Tensor> &tensors, int axis,
+                          Tensor &output) {
   throw std::invalid_argument(
     "Tensor::concat() is currently not supported in tensor data type " +
     getStringDataType());
index 8caaeadd340a84765a29b1ee05febf5087ae5726..f8a043eb4f7613744b27a61e031bf837a9ee3757 100644 (file)
@@ -397,9 +397,10 @@ public:
   virtual std::vector<Tensor> split(std::vector<size_t> sizes, int axis);
 
   /**
-   * @copydoc Tensor::concat(const std::vector<Tensor> &tensors, int axis)
+   * @copydoc Tensor::concat()
    */
-  virtual Tensor concat(const std::vector<Tensor> &tensors, int axis);
+  virtual Tensor concat(const std::vector<Tensor> &tensors, int axis,
+                        Tensor &output);
 
   /**
    * @copydoc Tensor::print(std::ostream &out)
index 71b0a384ae7f4aa61a25f59d2f13c4126e239990..e07dc8bafc0af3d363bc81b57b7025cda5c9af07 100644 (file)
@@ -3898,7 +3898,7 @@ TEST(nntrainer_Tensor, cat_01_p) {
       18, 54, 55, 56, 36, 37, 19, 57, 58, 59, 38, 39, 20, 60, 61, 62, 40, 41,
       21, 63, 64, 65, 42, 43, 22, 66, 67, 68, 44, 45, 23, 69, 70, 71, 46, 47};
     nntrainer::Tensor answer(ml::train::TensorDim{3, 2, 4, 6}, answer_data);
-    EXPECT_EQ(nntrainer::Tensor::cat(inputs, 3), answer);
+    EXPECT_EQ(nntrainer::Tensor::cat(inputs, -1), answer);
   }
 }
 
@@ -3912,6 +3912,51 @@ TEST(nntrainer_Tensor, cat_02_n) {
   }
 }
 
+// concatenate an empty list of tensors
+TEST(nntrainer_Tensor, cat_03_n) {
+  std::vector<nntrainer::Tensor> inputs;
+  EXPECT_THROW(nntrainer::Tensor::cat(inputs, 0), std::invalid_argument);
+}
+
+// concatenate a single tensor
+TEST(nntrainer_Tensor, cat_04_n) {
+  std::vector<nntrainer::Tensor> inputs;
+  inputs.reserve(1);
+  inputs.emplace_back(nntrainer::Tensor(2, 1, 1, 2));
+  EXPECT_THROW(nntrainer::Tensor::cat(inputs, 0), std::invalid_argument);
+}
+
+// concatenate tensors with different data types
+TEST(nntrainer_Tensor, cat_05_n) {
+  std::vector<nntrainer::Tensor> inputs;
+  inputs.reserve(2);
+  inputs.emplace_back(nntrainer::Tensor(2, 1, 1, 2));
+  inputs.emplace_back(nntrainer::Tensor(
+    2, 1, 1, 2, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::QINT8}));
+  EXPECT_THROW(nntrainer::Tensor::cat(inputs, 0), std::invalid_argument);
+}
+
+// incorrect output tensor dimension
+TEST(nntrainer_Tensor, cat_06_n) {
+  std::vector<nntrainer::Tensor> inputs;
+  inputs.reserve(2);
+  inputs.emplace_back(nntrainer::Tensor(3, 2, 4, 1));
+  inputs.emplace_back(nntrainer::Tensor(3, 2, 4, 3));
+  nntrainer::Tensor output(3, 2, 4, 5);
+  EXPECT_THROW(nntrainer::Tensor::cat(inputs, 3, output),
+               std::invalid_argument);
+}
+
+// tensors not having the same shape except for the axis
+TEST(nntrainer_Tensor, cat_07_n) {
+  std::vector<nntrainer::Tensor> inputs;
+  inputs.reserve(2);
+  inputs.emplace_back(nntrainer::Tensor(3, 2, 4, 1));
+  inputs.emplace_back(nntrainer::Tensor(3, 1, 4, 3));
+  EXPECT_THROW(nntrainer::Tensor::cat(inputs, 1), std::invalid_argument);
+  EXPECT_THROW(nntrainer::Tensor::cat(inputs, 3), std::invalid_argument);
+}
+
 TEST(nntrainer_Tensor, zoneout_mask_01_n) {
   const float zoneout_rate = 0.3f;
   nntrainer::Tensor t(10, 10, 10, 10);