Implement Split operator (#2448)
author윤지영/동작제어Lab(SR)/Engineer/삼성전자 <jy910.yun@samsung.com>
Mon, 27 Aug 2018 09:28:09 +0000 (18:28 +0900)
committer박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 27 Aug 2018 09:28:09 +0000 (18:28 +0900)
* Implement Split operator

Split operator is implemented using CLSubTensor.
In order to avoid unnecessary copy around data,
it uses shared memroy information.
It does not use separate ACL function.

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
* Handle negative axis case on Split operator

This commit handles a negative axis case on Split operator.
If the axis is negative, it updates the axis value by input tensor's rank.

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
* Remove default parameter of derived class

To avoid setting different default parameters between parent and derived class,
this commit removes the default paramter of derived class.

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc

index 5f57405..af35074 100644 (file)
@@ -267,7 +267,8 @@ struct IPlanBuilder
   virtual void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
                                     const ::internal::tflite::operand::Index &base,
                                     const ::arm_compute::Coordinates &offset,
-                                    const ::arm_compute::TensorShape &shape) = 0;
+                                    const ::arm_compute::TensorShape &shape,
+                                    bool extend_parent = false) = 0;
   virtual void addInitializer(const ::internal::tflite::operand::Index &ind,
                               const Initializer &initializer) = 0;
   virtual void addStage(const Stage &) = 0;
@@ -3317,7 +3318,47 @@ void Planner::visit(const ::internal::tflite::op::Split::Node &node)
 {
   VERBOSE(Split) << "Configure Split operation" << std::endl;
 
-  throw std::runtime_error("Not supported, yet");
+  const ::internal::tflite::operand::Index axis_index{node.param().axis_index};
+  const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+
+  const auto ifm_shape = _ctx.at(ifm_index).shape();
+  int32_t axis = _ctx.at(axis_index).asScalar<int32_t>();
+
+  // Handle negative axis
+  if (axis < 0)
+  {
+    axis += ifm_shape.rank();
+  }
+
+  const int32_t num_split = node.param().ofm_indexes.size();
+  const auto input_size = ifm_shape.dim(axis);
+  assert(input_size % num_split == 0);
+  const int32_t slice_size = input_size / num_split;
+
+  // Set Shape Constraints and TensorInfo (for input)
+  _builder.addShapeConstr(ifm_index, asTensorInfo(ifm_shape, _ctx.at(ifm_index).type()));
+
+  // Set Shape Constraints and TensorInfo (for output)
+  const auto rank = ifm_shape.rank();
+  const uint32_t coord_index = ToARMComputeAxis(rank, axis).value();
+  uint32_t depth = 0;
+
+  ::arm_compute::Coordinates coordinates;
+  coordinates.set_num_dimensions(rank);
+
+  for (const auto &index : node.param().ofm_indexes)
+  {
+    const ::internal::tflite::operand::Index ofm_index{index};
+    const auto ofm_shape = _ctx.at(ofm_index).shape();
+
+    coordinates[coord_index] = depth;
+
+    _builder.addSubsumptionConstr(ofm_index, ifm_index, coordinates, asTensorShape(ofm_shape),
+                                  true);
+    depth += slice_size;
+  }
+
+  // NOTE Split has no actual operation!
 }
 
 class AllocationContext final : public IAllocationContext
@@ -3389,7 +3430,7 @@ public:
   void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
                             const ::internal::tflite::operand::Index &base,
                             const ::arm_compute::Coordinates &offset,
-                            const ::arm_compute::TensorShape &shape) override;
+                            const ::arm_compute::TensorShape &shape, bool extend_parent) override;
 
 public:
   void addInitializer(const ::internal::tflite::operand::Index &ind,
@@ -3415,8 +3456,9 @@ private:
   {
   public:
     Subsumption(const ::internal::tflite::operand::Index &base,
-                const ::arm_compute::Coordinates &offset, const ::arm_compute::TensorShape &shape)
-        : _base{base}, _offset{offset}, _shape{shape}
+                const ::arm_compute::Coordinates &offset, const ::arm_compute::TensorShape &shape,
+                bool extend_parent)
+        : _base{base}, _offset{offset}, _shape{shape}, _extend_parent{extend_parent}
     {
       // DO NOTHING
     }
@@ -3425,11 +3467,13 @@ private:
     const ::internal::tflite::operand::Index &base(void) const { return _base; }
     const ::arm_compute::Coordinates &offset(void) const { return _offset; }
     const ::arm_compute::TensorShape &shape(void) const { return _shape; }
+    const bool extend_parent(void) const { return _extend_parent; }
 
   private:
     const ::internal::tflite::operand::Index _base;
     const ::arm_compute::Coordinates _offset;
     const ::arm_compute::TensorShape _shape;
+    const bool _extend_parent;
   };
 
 private:
@@ -3493,9 +3537,9 @@ void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &lhs_i
 void PlanBuilder::addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
                                        const ::internal::tflite::operand::Index &base,
                                        const ::arm_compute::Coordinates &offset,
-                                       const ::arm_compute::TensorShape &shape)
+                                       const ::arm_compute::TensorShape &shape, bool extend_parent)
 {
-  _subsumption_ctx[ind.asInt()] = std::make_shared<Subsumption>(base, offset, shape);
+  _subsumption_ctx[ind.asInt()] = std::make_shared<Subsumption>(base, offset, shape, extend_parent);
 }
 
 void PlanBuilder::addInitializer(const ::internal::tflite::operand::Index &ind,
@@ -3537,7 +3581,7 @@ void PlanBuilder::finalize(void) const
     assert(base_tensor != nullptr);
 
     auto curr_tensor = std::make_shared<::arm_compute::CLSubTensor>(
-        CAST_CL(base_tensor), sub_info.shape(), sub_info.offset());
+        CAST_CL(base_tensor), sub_info.shape(), sub_info.offset(), sub_info.extend_parent());
 
     _plan.operands().set(::internal::tflite::operand::Index{curr}, curr_tensor);
   };