Introduce addShapeConstr considering broadcasting for ADD, SUB, DIV, MUL (#1852)
author최성진/동작제어Lab(SR)/Principal Engineer/삼성전자 <lotieye.choi@samsung.com>
Thu, 26 Jul 2018 01:59:33 +0000 (10:59 +0900)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Thu, 26 Jul 2018 01:59:33 +0000 (10:59 +0900)
* Introduce addShapeConstr considering broadcasting

This commit introduces addShapeConstr considering broadcasting for ADD, SUB, DIV, MUL

Signed-off-by: SungJin Choi <lotieye.choi@samsung.com>
* Delete useless cout statement

This commit deletes useless cout statement

Signed-off-by: SungJin Choi <lotieye.choi@samsung.com>
* Modify description and typos

This commit modifies description and typos.

Signed-off-by: SungJin Choi <lotieye.choi@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc
runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h

index 4bd219f..ae54487 100644 (file)
@@ -208,9 +208,12 @@ struct IPlanBuilder
 
   virtual void addShapeConstr(const ::internal::tflite::operand::Index &ind,
                               const ::arm_compute::TensorInfo &info) = 0;
-  virtual void addShapeConstr(const ::internal::tflite::operand::Index &ind,
-                              const ::arm_compute::TensorInfo &info,
-                              const nnfw::util::tensor::Shape &shape) = 0;
+  virtual void addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+                              const ::internal::tflite::operand::Object &lhs_obj,
+                              const nnfw::util::tensor::Shape &lhs_shape,
+                              const ::internal::tflite::operand::Index &rhs_ind,
+                              const ::internal::tflite::operand::Object &rhs_obj,
+                              const nnfw::util::tensor::Shape &rhs_shape) = 0;
   virtual void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
                                     const ::internal::tflite::operand::Index &base,
                                     const ::arm_compute::Coordinates &offset,
@@ -372,10 +375,8 @@ private:
 void Planner::visit(const ::internal::tflite::op::Add::Node &node)
 {
   const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
-
   const ::internal::tflite::operand::Index lhs_index{node.param().lhs_index};
   const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
-
   const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
 
   const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
@@ -384,41 +385,9 @@ void Planner::visit(const ::internal::tflite::op::Add::Node &node)
 
   // TODO Should move to the place where the operand is handled, if it is possible.
   // Set Shape Constraints and TensorInfo
-  _builder.addShapeConstr(ofm_index,
-                          asTensorInfo(ofm_shape, _ctx.at(ofm_index).type(),
-                                       _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
-
-  if (lhs_shape.rank() == 4 && rhs_shape.rank() < 4)
-  {
-    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
-                                                    _ctx.at(lhs_index).scale(),
-                                                    _ctx.at(lhs_index).zeroPoint()));
-    _builder.addShapeConstr(rhs_index,
-                            asTensorInfoForBroadcast(rhs_shape, _ctx.at(rhs_index).type(),
-                                                     _ctx.at(ofm_index).scale(),
-                                                     _ctx.at(ofm_index).zeroPoint()),
-                            rhs_shape);
-  }
-  else if (rhs_shape.rank() == 4 && lhs_shape.rank() < 4)
-  {
-    _builder.addShapeConstr(lhs_index,
-                            asTensorInfoForBroadcast(lhs_shape, _ctx.at(lhs_index).type(),
-                                                     _ctx.at(lhs_index).scale(),
-                                                     _ctx.at(lhs_index).zeroPoint()),
-                            lhs_shape);
-    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
-                                                    _ctx.at(rhs_index).scale(),
-                                                    _ctx.at(rhs_index).zeroPoint()));
-  }
-  else
-  {
-    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
-                                                    _ctx.at(lhs_index).scale(),
-                                                    _ctx.at(lhs_index).zeroPoint()));
-    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
-                                                    _ctx.at(rhs_index).scale(),
-                                                    _ctx.at(rhs_index).zeroPoint()));
-  }
+  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+  _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+                          rhs_shape);
 
   // Construct operation parameters
   struct Param
@@ -479,10 +448,8 @@ void Planner::visit(const ::internal::tflite::op::Add::Node &node)
 void Planner::visit(const ::internal::tflite::op::Sub::Node &node)
 {
   const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
-
   const ::internal::tflite::operand::Index lhs_index{node.param().lhs_index};
   const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
-
   const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
 
   const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
@@ -490,8 +457,8 @@ void Planner::visit(const ::internal::tflite::op::Sub::Node &node)
   const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
 
   _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
-  _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
-  _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
+  _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+                          rhs_shape);
 
   // Construct operation parameters
   struct Param
@@ -537,43 +504,13 @@ void Planner::visit(const ::internal::tflite::op::Mul::Node &node)
   const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
   const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
 
-  int32_t ofm_rank = _ctx.at(ofm_index).shape().rank();
-  int32_t lhs_rank = _ctx.at(lhs_index).shape().rank();
-  int32_t rhs_rank = _ctx.at(rhs_index).shape().rank();
   const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
   const auto lhs_shape = _ctx.at(lhs_index).shape().asTensor();
   const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
 
-  // not tested cases below
-  assert(!(ofm_rank == 0 && lhs_rank == 0 && rhs_rank == 0));
-  assert(ofm_rank < 4 && lhs_rank < 4 && rhs_rank < 4);
-
-  if (ofm_rank > 3)
-  {
-    throw std::runtime_error("Not supported, yet");
-  }
-
-  _builder.addShapeConstr(ofm_index,
-                          asTensorInfo(ofm_shape, _ctx.at(ofm_index).type(),
-                                       _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
-
-  if (lhs_rank > 3)
-  {
-    throw std::runtime_error("Not supported, yet");
-  }
-
-  _builder.addShapeConstr(lhs_index,
-                          asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
-                                       _ctx.at(lhs_index).scale(), _ctx.at(lhs_index).zeroPoint()));
-
-  if (rhs_rank > 3)
-  {
-    throw std::runtime_error("Not supported, yet");
-  }
-
-  _builder.addShapeConstr(rhs_index,
-                          asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
-                                       _ctx.at(rhs_index).scale(), _ctx.at(rhs_index).zeroPoint()));
+  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+  _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+                          rhs_shape);
 
   struct Param
   {
@@ -2671,9 +2608,12 @@ public:
 public:
   void addShapeConstr(const ::internal::tflite::operand::Index &ind,
                       const ::arm_compute::TensorInfo &info) override;
-  void addShapeConstr(const ::internal::tflite::operand::Index &ind,
-                      const ::arm_compute::TensorInfo &info,
-                      const nnfw::util::tensor::Shape &shape) override;
+  void addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+                      const ::internal::tflite::operand::Object &lhs_obj,
+                      const nnfw::util::tensor::Shape &lhs_shape,
+                      const ::internal::tflite::operand::Index &rhs_ind,
+                      const ::internal::tflite::operand::Object &rhs_obj,
+                      const nnfw::util::tensor::Shape &rhs_shape) override;
 
 public:
   void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
@@ -2735,41 +2675,45 @@ void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &ind,
 {
   _tensor_info_ctx[ind.asInt()] = info;
 }
-void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &ind,
-                                 const ::arm_compute::TensorInfo &info,
-                                 const nnfw::util::tensor::Shape &shape)
-{
-  // ACL tensor info
-  _tensor_info_ctx[ind.asInt()] = info;
 
-  // broadcasting tensor shape
-  internal::tflite::operand::Shape broadcastShape(4);
-  if (shape.rank() == 1)
-  {
-    broadcastShape.dim(0) = 1;
-    broadcastShape.dim(1) = 1;
-    broadcastShape.dim(2) = 1;
-    broadcastShape.dim(3) = shape.dim(0);
-  }
-  else if (shape.rank() == 2)
-  {
-    broadcastShape.dim(0) = 1;
-    broadcastShape.dim(1) = 1;
-    broadcastShape.dim(2) = shape.dim(0);
-    broadcastShape.dim(3) = shape.dim(1);
+// Add tensor shape constraints considering broadcasting
+void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+                                 const ::internal::tflite::operand::Object &lhs_obj,
+                                 const nnfw::util::tensor::Shape &lhs_shape,
+                                 const ::internal::tflite::operand::Index &rhs_ind,
+                                 const ::internal::tflite::operand::Object &rhs_obj,
+                                 const nnfw::util::tensor::Shape &rhs_shape)
+{
+  // right-side broadcasting
+  if (lhs_shape.rank() > rhs_shape.rank())
+  {
+    // ACL tensor info
+    _tensor_info_ctx[lhs_ind.asInt()] = asTensorInfo(lhs_shape, lhs_obj.type());
+    _tensor_info_ctx[rhs_ind.asInt()] =
+        asTensorInfoForBroadcast(rhs_shape, rhs_obj.type(), lhs_shape.rank());
+
+    // TFlite broadcasting tensor shape
+    if (lhs_shape.rank() == 4)
+      _broadcasting_tensor_shape.emplace(rhs_ind.asInt(),
+                                         asTensorShapeForTFLiteBroadcast(rhs_shape));
   }
-  else if (shape.rank() == 3)
+  // left-side broadcasting
+  else if (lhs_shape.rank() < rhs_shape.rank())
   {
-    broadcastShape.dim(0) = 1;
-    broadcastShape.dim(1) = shape.dim(0);
-    broadcastShape.dim(2) = shape.dim(1);
-    broadcastShape.dim(3) = shape.dim(2);
+    _tensor_info_ctx[lhs_ind.asInt()] =
+        asTensorInfoForBroadcast(lhs_shape, lhs_obj.type(), rhs_shape.rank());
+    _tensor_info_ctx[rhs_ind.asInt()] = asTensorInfo(rhs_shape, rhs_obj.type());
+
+    if (rhs_shape.rank() == 4)
+      _broadcasting_tensor_shape.emplace(lhs_ind.asInt(),
+                                         asTensorShapeForTFLiteBroadcast(lhs_shape));
   }
+  // no broadcasting
   else
   {
-    throw std::runtime_error("Not supported, yet");
+    _tensor_info_ctx[lhs_ind.asInt()] = asTensorInfo(lhs_shape, lhs_obj.type());
+    _tensor_info_ctx[rhs_ind.asInt()] = asTensorInfo(rhs_shape, rhs_obj.type());
   }
-  _broadcasting_tensor_shape.emplace(ind.asInt(), broadcastShape);
 }
 
 void PlanBuilder::addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
index 6acfcf3..35b07e9 100644 (file)
   }
 }
 
-// in case of NHWC
+// ACL Broadcasting style in case of NHWC
 // TODO HCHW
-::arm_compute::TensorShape asTensorShapeForBroadcast(const nnfw::util::tensor::Shape &shape)
+::arm_compute::TensorShape asTensorShapeForBroadcast(const nnfw::util::tensor::Shape &shape,
+                                                     const size_t baseRank)
 {
-  if (shape.rank() == 1)
+  // The cases that large rank(baseRank) is 4 and small rank is less than 4 need to transform to
+  // broadcasting TensorInfo because order is different.
+  if (baseRank == 4)
   {
-    return ::arm_compute::TensorShape(1, 1, shape.dim(0), 1);
+    if (shape.rank() == 0)
+    {
+      return ::arm_compute::TensorShape(1);
+    }
+    else if (shape.rank() == 1)
+    {
+      return ::arm_compute::TensorShape(1, 1, shape.dim(0), 1);
+    }
+    else if (shape.rank() == 2)
+    {
+      return ::arm_compute::TensorShape(shape.dim(0), 1, shape.dim(1), 1); // w c -> w h c n
+    }
+    else if (shape.rank() == 3)
+    {
+      return ::arm_compute::TensorShape(shape.dim(1), shape.dim(0), shape.dim(2),
+                                        1); // h w c -> w h c n
+    }
+    else if (shape.rank() == 4)
+    {
+      assert(shape.dim(0) ==
+             1); // In case of ADD, SUB, MUL and DIV at ACL OpenCL, 3D inputs are supported.
+      return ::arm_compute::TensorShape(shape.dim(2), shape.dim(1), shape.dim(3),
+                                        shape.dim(0)); // n h w c -> W H C N
+    }
+    else
+    {
+      throw std::runtime_error("Not supported, yet");
+    }
   }
-  else if (shape.rank() == 2)
+  // Other cases that larger rank <= 3 don't need to transform because broadcast shape is the same
+  // as orignal. For example, ::arm_compute::TensorShape(shape.dim(0), 1, 1) ==
+  // ::arm_compute::TensorShape(shape.dim(0).
+  else
   {
-    return ::arm_compute::TensorShape(shape.dim(0), 1, shape.dim(1), 1); // w c -> w h c n
+    return asTensorShape(shape);
   }
-  else if (shape.rank() == 3)
+}
+
+// TFLite broadcasting style: used for reading an input as broadcasting shape
+internal::tflite::operand::Shape
+asTensorShapeForTFLiteBroadcast(const nnfw::util::tensor::Shape &shape)
+{
+  internal::tflite::operand::Shape broadcastShape(4);
+  if (shape.rank() == 0)
   {
-    return ::arm_compute::TensorShape(shape.dim(1), shape.dim(0), shape.dim(2),
-                                      1); // h w c -> w h c n
+    broadcastShape.dim(0) = 1;
+    broadcastShape.dim(1) = 1;
+    broadcastShape.dim(2) = 1;
+    broadcastShape.dim(3) = 1;
   }
-  else if (shape.rank() == 4)
+  else if (shape.rank() == 1)
   {
-    assert(shape.dim(0) == 1); // In case of ADD, SUB, 3D inputs are supported.
-    return ::arm_compute::TensorShape(shape.dim(2), shape.dim(1), shape.dim(3),
-                                      shape.dim(0)); // n h w c -> W H C N
+    broadcastShape.dim(0) = 1;
+    broadcastShape.dim(1) = 1;
+    broadcastShape.dim(2) = 1;
+    broadcastShape.dim(3) = shape.dim(0);
+  }
+  else if (shape.rank() == 2)
+  {
+    broadcastShape.dim(0) = 1;
+    broadcastShape.dim(1) = 1;
+    broadcastShape.dim(2) = shape.dim(0);
+    broadcastShape.dim(3) = shape.dim(1);
+  }
+  else if (shape.rank() == 3)
+  {
+    broadcastShape.dim(0) = 1;
+    broadcastShape.dim(1) = shape.dim(0);
+    broadcastShape.dim(2) = shape.dim(1);
+    broadcastShape.dim(3) = shape.dim(2);
   }
   else
   {
     throw std::runtime_error("Not supported, yet");
   }
+  return broadcastShape;
 }
 
 inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape)
@@ -146,10 +204,11 @@ inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand:
 }
 
 ::arm_compute::TensorInfo asTensorInfoForBroadcast(const nnfw::util::tensor::Shape &shape,
-                                                   const int32_t type, const float scale = 0.0f,
+                                                   const int32_t type, const size_t baseRank,
+                                                   const float scale = 0.0f,
                                                    const int32_t zeroPoint = 0)
 {
-  return ::arm_compute::TensorInfo(asTensorShapeForBroadcast(shape), 1, asDataType(type),
+  return ::arm_compute::TensorInfo(asTensorShapeForBroadcast(shape, baseRank), 1, asDataType(type),
                                    asQuantizationInfo(scale, zeroPoint));
 }