Add a default param to asTensorShape() to apply conditionally (#2464)
author장지섭/동작제어Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Fri, 31 Aug 2018 08:22:05 +0000 (17:22 +0900)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Fri, 31 Aug 2018 08:22:05 +0000 (17:22 +0900)
This commit add a default param to asTensorShape() to apply conditionally dimension coorection.

In some cases, in incorrect dimensions is required.
For example, intput_size is 1 in LSTM. The input-to-input weights([num_units, input_size]) of LSTM is used as the weight of the FullyConnected.
The FullyConnected's weight must be greater or equal than 2-dimensions.
However, if the dimension correction is applied to input_to_input_weights with input_size equal to 1, it will be changed to 1-D.
So input_to_input_weights is not used by the weight of FullyConnected.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h

index 470e76c..d4c5dd0 100644 (file)
@@ -5,7 +5,8 @@
 #include "internal/Swizzle.h"
 #include "internal/Model.h"
 
-inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape)
+inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape,
+                                                bool apply_dim_correction = true)
 {
   const uint32_t rank = shape.rank();
 
@@ -15,10 +16,14 @@ inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand:
 
   for (uint32_t axis = 0; axis < rank; ++axis)
   {
-    // NOTE Do NOT update TensorShape with operator[] (in ::arm_compute::Dimensions)
-    //      TensorShape::set applies dimension correction after value update.
-    //      Various asserts in ARMCompute work only when this correction is applied.
-    res.set(ToARMComputeAxis(rank, axis).value(), shape.dim(axis));
+    // NOTE In some cases, in incorrect dimensions is required.
+    // For example, intput_size is 1 in LSTM. The input-to-input weights([num_units, input_size]) of
+    // LSTM is used as the weight of the FullyConnected.
+    // The FullyConnected's weight must be greater or equal than 2-dimensions.
+    // However, if the dimension correction is applied to input_to_input_weights with input_size
+    // equal to 1, it will be changed to 1-D.
+    // So input_to_input_weights is not used by the weight of FullyConnected.
+    res.set(ToARMComputeAxis(rank, axis).value(), shape.dim(axis), apply_dim_correction);
   }
 
   return res;