Prepare reshaping input (rank-4) into rank-2 on fc (#2505)
author Yongseop Kim/Motion Control Lab (SR)/Engineer/Samsung Electronics <yons.kim@samsung.com>
Tue, 28 Aug 2018 10:34:36 +0000 (19:34 +0900)
committer Saehie Park/Motion Control Lab (SR)/Principal Engineer/Samsung Electronics <saehie.park@samsung.com>
Tue, 28 Aug 2018 10:34:36 +0000 (19:34 +0900)
When an input with rank 4 arrives as fc's input, it needs to be reshaped
into rank 2. This commit prepares that task with skeleton code and with
checks on input_size, output_size, and batch_size.

Signed-off-by: Yongseop Kim <yons.kim@samsung.com>
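
A minimal sketch (not the pure_arm_compute code itself) of the shape bookkeeping this commit prepares: a rank-4 input (N, C, H, W) may only collapse into the rank-2 (batch_size, input_size) input expected by fc when the element counts match, with batch_size taken from the output's dim(0) and input_size from the weight's dim(1). FeatureShape, MatrixShape, and asFcInput are hypothetical names used only for illustration.

    #include <cassert>
    #include <cstdint>

    struct FeatureShape { int32_t N, C, H, W; }; // rank-4 activation
    struct MatrixShape  { int32_t H, W; };       // rank-2 fc input: (batch_size, input_size)

    // Derive the rank-2 view of a rank-4 feature map, checking that the
    // reshape preserves the total number of elements.
    MatrixShape asFcInput(const FeatureShape &ifm, int32_t batch_size, int32_t input_size)
    {
      const auto feature_size = ifm.N * ifm.C * ifm.H * ifm.W;
      assert(feature_size == batch_size * input_size);
      return MatrixShape{batch_size, input_size};
    }

    int main()
    {
      // Example: batch of 1, a 7x7x64 feature map feeding a fully connected layer.
      const FeatureShape ifm{1, 64, 7, 7};
      const int32_t batch_size = 1;          // output.dim(0)
      const int32_t input_size = 64 * 7 * 7; // weight.dim(1)
      const auto fc_in = asFcInput(ifm, batch_size, input_size);
      assert(fc_in.H == 1 && fc_in.W == 3136);
      return 0;
    }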
runtimes/pure_arm_compute/src/compilation.cc

index 181727d..4f09af5 100644
@@ -1978,15 +1978,48 @@ void Planner::visit(const ::internal::tflite::op::FullyConnected::Node &node)
   assert(_ctx.at(weight_index).shape().rank() == 2);
   assert(_ctx.at(bias_index).shape().rank() == 1);
 
+  const auto input_rank = _ctx.at(input_index).shape().rank();
+  // TODO Currently the case where the input's rank is 3 is not handled.
+  // Support for it should be added in the future.
+  assert(input_rank != 3);
+
+  const auto output_size = _ctx.at(output_index).shape().dim(1);
+  assert(_ctx.at(bias_index).shape().dim(0) == output_size);
+  assert(_ctx.at(weight_index).shape().dim(0) == output_size);
+  const auto batch_size = _ctx.at(output_index).shape().dim(0);
+  const auto input_size = _ctx.at(weight_index).shape().dim(1);
+
+  // Check whether the input's shape needs reshaping into rank-2 and prepare the reshaping
+  if (input_rank == 4)
+  {
+    nnfw::util::feature::Shape ifm_shape_feature = _ctx.at(input_index).shape().asFeature();
+    auto feature_size =
+        ifm_shape_feature.N * ifm_shape_feature.C * ifm_shape_feature.H * ifm_shape_feature.W;
+    assert(feature_size == batch_size * input_size);
+
+    // TODO Add reshaping
+    _builder.addShapeConstr(
+        input_index, asTensorInfo(ifm_shape_feature, _ctx.at(input_index).type(),
+                                  _ctx.at(input_index).scale(), _ctx.at(input_index).zeroPoint()));
+  }
+  else if (input_rank == 2)
+  {
+    auto ifm_shape = _ctx.at(input_index).shape();
+    nnfw::util::matrix::Shape ifm_shape_matrix = ifm_shape.asMatrix();
+    assert(ifm_shape.dim(0) == batch_size);
+    assert(ifm_shape.dim(1) == input_size);
+
+    _builder.addShapeConstr(input_index, asTensorInfo(ifm_shape_matrix, _ctx.at(input_index).type(),
+                                                      _ctx.at(input_index).scale(),
+                                                      _ctx.at(input_index).zeroPoint()));
+  }
+
   // TODO Should move to the place where the operand is handled, if it is possible.
   // Set Shape Constraints
   _builder.addShapeConstr(
       output_index, asTensorInfo(_ctx.at(output_index).shape(), _ctx.at(output_index).type(),
                                  _ctx.at(output_index).scale(), _ctx.at(output_index).zeroPoint()));
   _builder.addShapeConstr(
-      input_index, asTensorInfo(_ctx.at(input_index).shape(), _ctx.at(input_index).type(),
-                                _ctx.at(input_index).scale(), _ctx.at(input_index).zeroPoint()));
-  _builder.addShapeConstr(
       weight_index, asTensorInfo(_ctx.at(weight_index).shape(), _ctx.at(weight_index).type(),
                                  _ctx.at(weight_index).scale(), _ctx.at(weight_index).zeroPoint()));
   _builder.addShapeConstr(