[ Tensor ] remove rank 2 limitation for dot op
author jijoong.moon <jijoong.moon@samsung.com>
Fri, 10 Jun 2022 10:50:30 +0000 (19:50 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Tue, 14 Jun 2022 04:07:40 +0000 (13:07 +0900)
This patch removes the rank-2 limitation for the dot op.
The dot op is expected to treat a 4D tensor as a 2D matrix with
[BxCxH, W] dimensions.
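
For illustration, a minimal sketch of the intended behaviour (shapes and the
trans_m flag mirror the updated unit tests below; flattening a 4D tensor to
[BxCxH, W] is the assumption this patch relies on):

  nntrainer::Tensor input(2, 3, 4, 5); // viewed as a 24 x 5 matrix
  nntrainer::Tensor m(1, 3, 4, 5);     // viewed as a 12 x 5 matrix

  // (24 x 5) . (12 x 5) has mismatched inner dimensions, so input.dot(m)
  // still throws, now std::runtime_error rather than exception::not_supported.
  // Transposing m gives (24 x 5) . (5 x 12) = 24 x 12, which now succeeds.
  nntrainer::Tensor result = input.dot(m, false, true);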

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
nntrainer/tensor/tensor.cpp
test/unittest/unittest_nntrainer_tensor.cpp

index 0c31891f3862be050e38f24ff3ef9a358915ca65..daa4f897dbdc73fca846e8a46d91326791d5a09b 100644 (file)
@@ -1176,10 +1176,12 @@ Tensor &Tensor::dot(Tensor const &m, Tensor &result, bool trans, bool trans_m,
   NNTR_THROW_IF(!contiguous, std::invalid_argument)
     << getName() << " is not contiguous. Cannot dot product.";
 
-  if (m.dim.rank() > 2) {
-    throw exception::not_supported("Error: support only for rank of dot "
-                                   "matrix <= 2");
-  }
+  // Commented out to support the calculation along the batch and height
+  // directions: this->dim and m.dim are both treated as [ BxCxH, W ].
+  // if (m.dim.rank() > 2) {
+  //   throw exception::not_supported("Error: support only for rank of dot "
+  //                                  "matrix <= 2");
+  // }
 
   // Comment out with intension to support the calculation wrt. batch and height
   // direction of this tensor. It is OK as long as m is 2D
index 97aca7d14ff6e8d50979ace21c6c12ab58143f55..33b3d69a50c3077ad1ef73e9c0dddc11bb559896 100644 (file)
@@ -2247,22 +2247,26 @@ TEST(nntrainer_Tensor, average_multiple_axes_01_n) {
 TEST(nntrainer_Tensor, dot_01_n) {
   nntrainer::Tensor input(2, 3, 4, 5);
   nntrainer::Tensor m(1, 3, 4, 5);
-  EXPECT_THROW(nntrainer::Tensor result = input.dot(m),
-               nntrainer::exception::not_supported);
+  EXPECT_THROW(nntrainer::Tensor result = input.dot(m), std::runtime_error);
 }
 
 TEST(nntrainer_Tensor, dot_02_n) {
   nntrainer::Tensor input(2, 3, 4, 5);
   nntrainer::Tensor m(1, 3, 4, 5);
   EXPECT_THROW(nntrainer::Tensor result = input.dot(m, true),
-               nntrainer::exception::not_supported);
+               std::runtime_error);
 }
 
-TEST(nntrainer_Tensor, dot_03_n) {
+TEST(nntrainer_Tensor, dot_02_p) {
+  nntrainer::Tensor input(2, 3, 4, 5);
+  nntrainer::Tensor m(1, 3, 4, 5);
+  EXPECT_NO_THROW(nntrainer::Tensor result = input.dot(m, false, true));
+}
+
+TEST(nntrainer_Tensor, dot_03_p) {
   nntrainer::Tensor input(1, 3, 4, 5);
   nntrainer::Tensor m(1, 3, 4, 5);
-  EXPECT_THROW(nntrainer::Tensor result = input.dot(m, true),
-               nntrainer::exception::not_supported);
+  EXPECT_NO_THROW(nntrainer::Tensor result = input.dot(m, true));
 }
 
 TEST(nntrainer_Tensor, dot_04_n) {