[Refactor] Remove if/else(switch case) of rank for addShapeConstr (#1839)
author김수진/동작제어Lab(SR)/Engineer/삼성전자 <sjsujin.kim@samsung.com>
Tue, 10 Jul 2018 05:22:21 +0000 (14:22 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 10 Jul 2018 05:22:21 +0000 (14:22 +0900)
* Remove if/else(switch case) of rank for addShapeConstr

Related : #1828

This commit removes if/else(switch cases) of rank for `addShapeConstr`.
That's because `asTensorShape` which used in `addShapeConstr` parameter includes if/else for processing shapes by each ranks.

Signed-off-by: sjsujinkim <sjsujin.kim@samsung.com>
* Apply #1741 cases

* Update rank check for Cast

runtimes/pure_arm_compute/src/compilation.cc

index 1240165..d3c1ca5 100644 (file)
@@ -506,82 +506,35 @@ void Planner::visit(const ::internal::tflite::op::Mul::Node &node)
   int32_t ofm_rank = _ctx.at(ofm_index).shape().rank();
   int32_t lhs_rank = _ctx.at(lhs_index).shape().rank();
   int32_t rhs_rank = _ctx.at(rhs_index).shape().rank();
+  const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
+  const auto lhs_shape = _ctx.at(lhs_index).shape().asTensor();
+  const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
 
   // not tested cases below
   assert(!(ofm_rank == 0 && lhs_rank == 0 && rhs_rank == 0));
   assert(ofm_rank < 4 && lhs_rank < 4 && rhs_rank < 4);
 
-  if (ofm_rank == 1)
-  {
-    const auto ofm_shape = _ctx.at(ofm_index).shape().asVector();
-    _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
-  }
-  else if (ofm_rank == 2)
-  {
-    const auto ofm_shape = _ctx.at(ofm_index).shape().asMatrix();
-    _builder.addShapeConstr(ofm_index,
-                            asTensorInfo(ofm_shape.H, ofm_shape.W, _ctx.at(ofm_index).type()));
-  }
-  else if (ofm_rank == 3)
-  {
-    const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
-    _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
-  }
-  else
+  if (ofm_rank > 3)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
-  if (lhs_rank == 0)
-  {
-    _builder.addShapeConstr(lhs_index, asTensorInfo(1, _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_rank == 1)
-  {
-    const auto lhs_shape = _ctx.at(lhs_index).shape().asVector();
-    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_rank == 2)
-  {
-    const auto lhs_shape = _ctx.at(lhs_index).shape().asMatrix();
-    _builder.addShapeConstr(lhs_index,
-                            asTensorInfo(lhs_shape.H, lhs_shape.W, _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_rank == 3)
-  {
-    const auto lhs_shape = _ctx.at(lhs_index).shape().asTensor();
-    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
-  }
-  else
+  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+
+  if (lhs_rank > 3)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
-  if (rhs_rank == 0)
-  {
-    _builder.addShapeConstr(rhs_index, asTensorInfo(1, _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_rank == 1)
-  {
-    const auto rhs_shape = _ctx.at(rhs_index).shape().asVector();
-    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_rank == 2)
-  {
-    const auto rhs_shape = _ctx.at(rhs_index).shape().asMatrix();
-    _builder.addShapeConstr(rhs_index,
-                            asTensorInfo(rhs_shape.H, rhs_shape.W, _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_rank == 3)
-  {
-    const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
-    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
-  }
-  else
+  _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
+
+  if (rhs_rank > 3)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
+  _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
+
   struct Param
   {
     int ofm_index;
@@ -633,65 +586,31 @@ void Planner::visit(const ::internal::tflite::op::Div::Node &node)
   //      or the operand's dimension size is one.
   const auto ofm_shape = _ctx.at(ofm_index).shape();
   const auto ofm_shape_rank = ofm_shape.rank();
-  if (ofm_shape_rank == 4)
-  {
-    _builder.addShapeConstr(ofm_index,
-                            asTensorInfo(ofm_shape.asFeature(), _ctx.at(ofm_index).type()));
-  }
-  else if (ofm_shape_rank == 1)
-  {
-    _builder.addShapeConstr(ofm_index,
-                            asTensorInfo(ofm_shape.asVector(), _ctx.at(ofm_index).type()));
-  }
-  else
+  if (ofm_shape_rank > 4)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
+  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape.asTensor(), _ctx.at(ofm_index).type()));
+
   const auto lhs_shape = _ctx.at(lhs_index).shape();
   const auto lhs_shape_rank = lhs_shape.rank();
-  if (lhs_shape_rank == 4)
-  {
-    _builder.addShapeConstr(lhs_index,
-                            asTensorInfo(lhs_shape.asFeature(), _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_shape_rank == 1)
-  {
-    _builder.addShapeConstr(lhs_index,
-                            asTensorInfo(lhs_shape.asVector(), _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_shape_rank == 0)
-  {
-    // scalar
-    _builder.addShapeConstr(lhs_index, asTensorInfo(1, _ctx.at(lhs_index).type()));
-  }
-  else
+  if (lhs_shape_rank > 4)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
+  _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape.asTensor(), _ctx.at(lhs_index).type()));
+
   const auto rhs_shape = _ctx.at(rhs_index).shape();
   const auto rhs_shape_rank = rhs_shape.rank();
-  if (rhs_shape_rank == 4)
-  {
-    _builder.addShapeConstr(rhs_index,
-                            asTensorInfo(rhs_shape.asFeature(), _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_shape_rank == 1)
-  {
-    _builder.addShapeConstr(rhs_index,
-                            asTensorInfo(rhs_shape.asVector(), _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_shape_rank == 0)
-  {
-    // scalar
-    _builder.addShapeConstr(rhs_index, asTensorInfo(1, _ctx.at(rhs_index).type()));
-  }
-  else
+  if (rhs_shape_rank > 4)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
+  _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape.asTensor(), _ctx.at(rhs_index).type()));
+
   // Construct operation parameters
   struct Param
   {
@@ -1783,56 +1702,17 @@ void Planner::visit(const ::internal::tflite::op::Cast::Node &node)
   const ::internal::tflite::operand::Index output_index{node.param().output_index};
   const ::internal::tflite::operand::Index input_index{node.param().input_index};
 
-  const auto output_shape = _ctx.at(output_index).shape();
-  const auto input_shape = _ctx.at(input_index).shape();
-  assert(output_shape.rank() == input_shape.rank());
-  for (uint32_t n = 0; n < input_shape.rank(); ++n)
-  {
-    assert(output_shape.dim(n) == input_shape.dim(n));
-  }
+  const auto output_shape = _ctx.at(output_index).shape().asTensor();
+  const auto input_shape = _ctx.at(input_index).shape().asTensor();
 
-  // TODO Should move to the place where the operand is handled, if it is possible.
-  // Set Shape Constraints and TensorInfo
-  switch (input_shape.rank())
-  {
-    case 0: // scalar
-    {
-      _builder.addShapeConstr(output_index, asTensorInfo(1, _ctx.at(output_index).type(),
-                                                         _ctx.at(output_index).scale(),
-                                                         _ctx.at(output_index).zeroPoint()));
-      _builder.addShapeConstr(input_index, asTensorInfo(1, _ctx.at(input_index).type(),
-                                                        _ctx.at(input_index).scale(),
-                                                        _ctx.at(input_index).zeroPoint()));
-      break;
-    }
-    case 1: // vector
-    {
-      _builder.addShapeConstr(output_index,
-                              asTensorInfo(input_shape.asVector(), _ctx.at(output_index).type(),
-                                           _ctx.at(output_index).scale(),
-                                           _ctx.at(output_index).zeroPoint()));
-      _builder.addShapeConstr(input_index,
-                              asTensorInfo(output_shape.asVector(), _ctx.at(input_index).type(),
-                                           _ctx.at(input_index).scale(),
-                                           _ctx.at(input_index).zeroPoint()));
-      break;
-    }
-    case 4: // feature
-    {
-      _builder.addShapeConstr(output_index,
-                              asTensorInfo(input_shape.asFeature(), _ctx.at(output_index).type(),
-                                           _ctx.at(output_index).scale(),
-                                           _ctx.at(output_index).zeroPoint()));
-      _builder.addShapeConstr(input_index,
-                              asTensorInfo(output_shape.asFeature(), _ctx.at(input_index).type(),
-                                           _ctx.at(input_index).scale(),
-                                           _ctx.at(input_index).zeroPoint()));
-      break;
-    }
-    default:
-      throw std::runtime_error("Not supported, yet");
-      break;
-  }
+  assert(output_shape == input_shape);
+
+  _builder.addShapeConstr(output_index, asTensorInfo(input_shape, _ctx.at(output_index).type(),
+                                                     _ctx.at(output_index).scale(),
+                                                     _ctx.at(output_index).zeroPoint()));
+  _builder.addShapeConstr(input_index, asTensorInfo(output_shape, _ctx.at(input_index).type(),
+                                                    _ctx.at(input_index).scale(),
+                                                    _ctx.at(input_index).zeroPoint()));
 
   // Construct operation parameters
   struct Param