Pure CL: 2D Concat (#1636)
author오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 11 Jun 2018 02:52:31 +0000 (11:52 +0900)
committer서상민/동작제어Lab(SR)/Staff Engineer/삼성전자 <sangmin7.seo@samsung.com>
Mon, 11 Jun 2018 02:52:31 +0000 (11:52 +0900)
Support 2D concat (axis: row)

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc

index fb22c16..ac79353 100644 (file)
@@ -976,29 +976,74 @@ void Planner::visit(const ::internal::tflite::op::Concat::Node &node)
 
   // NOTE This implementation assumes that inputs and output are a feature
   // TODO Remove this assumption
-  const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
+  uint32_t input_rank = _ctx.at(ofm_index).shape().rank();
+  assert(input_rank == 4 || input_rank == 2);
 
-  // NOTE This implementation assumes concat over feature depth
   // TODO Remove this assumption
-  assert(_ctx.at(::internal::tflite::operand::Index{node.param().axis_index}).asScala<int32_t>() ==
-         3);
+  if (input_rank == 4)
+  {
+    // NOTE This implementation assumes concat over feature depth
+    assert(
+        _ctx.at(::internal::tflite::operand::Index{node.param().axis_index}).asScala<int32_t>() ==
+        3);
 
-  // TODO Should move to the place where the operand is handled, if it is possible.
-  // Set Shape Constraints and TensorInfo (for output)
-  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+    const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
+
+    // TODO Should move to the place where the operand is handled, if it is possible.
+    // Set Shape Constraints and TensorInfo (for output)
+    _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
 
-  // Set Shape Constraints (for input)
-  uint32_t depth = 0;
+    // Set Shape Constraints (for input)
+    uint32_t depth = 0;
 
-  for (const auto &index : node.param().ifm_indexes)
+    for (const auto &index : node.param().ifm_indexes)
+    {
+      const ::internal::tflite::operand::Index ifm_index{index};
+      const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
+
+      _builder.addSubsumptionConstr(ifm_index, ofm_index,
+                                    ::arm_compute::Coordinates{0, 0, depth, 0},
+                                    asTensorShape(ifm_shape));
+
+      depth += ifm_shape.C;
+    }
+  }
+  else if (input_rank == 2)
   {
-    const ::internal::tflite::operand::Index ifm_index{index};
-    const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
+    // NOTE This implementation assumes concat over matrix row
+    assert(
+        _ctx.at(::internal::tflite::operand::Index{node.param().axis_index}).asScala<int32_t>() ==
+        0);
 
-    _builder.addSubsumptionConstr(ifm_index, ofm_index, ::arm_compute::Coordinates{0, 0, depth, 0},
-                                  asTensorShape(ifm_shape));
+    const auto ofm_shape = _ctx.at(ofm_index).shape();
+    const auto ofm_rows = ofm_shape.dim(0);
+    const auto ofm_cols = ofm_shape.dim(1);
 
-    depth += ifm_shape.C;
+    // TODO Should move to the place where the operand is handled, if it is possible.
+    // Set Shape Constraints and TensorInfo (for output)
+    _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_rows, ofm_cols, _ctx.at(ofm_index).type()));
+
+    // Set Shape Constraints (for input)
+    uint32_t row_offset = 0;
+
+    for (const auto &index : node.param().ifm_indexes)
+    {
+      const ::internal::tflite::operand::Index ifm_index{index};
+      const auto ifm_shape = _ctx.at(ifm_index).shape();
+      const auto ifm_rows = ifm_shape.dim(0);
+      const auto ifm_cols = ifm_shape.dim(1);
+
+      _builder.addSubsumptionConstr(ifm_index, ofm_index,
+                                    ::arm_compute::Coordinates{0, row_offset, 0, 0},
+                                    asTensorShape(ifm_rows, ifm_cols));
+
+      row_offset += ifm_rows;
+    }
+  }
+  else
+  {
+    // Not implemented yet
+    throw std::runtime_error("Not supported, yet");
   }
 
   // NOTE Concat has no actual operation!