From 742f059cba1542fe28c0bbde17d76a2bad0017d4 Mon Sep 17 00:00:00 2001
From: 김수진/동작제어Lab(SR)/Engineer/삼성전자
Date: Tue, 10 Jul 2018 14:22:21 +0900
Subject: [PATCH] [Refactor] Remove if/else(switch case) of rank for addShapeConstr (#1839)

* Remove if/else(switch case) of rank for addShapeConstr

Related: #1828

This commit removes the per-rank if/else (switch-case) branches around
`addShapeConstr`. They are no longer needed because `asTensorShape`, which is
used to build the `addShapeConstr` parameter, already contains the if/else
logic that handles shapes of each rank.

Signed-off-by: sjsujinkim

* Apply #1741 cases

* Update rank check for Cast
---
 runtimes/pure_arm_compute/src/compilation.cc | 182 +++++----------------------
 1 file changed, 31 insertions(+), 151 deletions(-)

diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc
index 1240165..d3c1ca5 100644
--- a/runtimes/pure_arm_compute/src/compilation.cc
+++ b/runtimes/pure_arm_compute/src/compilation.cc
@@ -506,82 +506,35 @@ void Planner::visit(const ::internal::tflite::op::Mul::Node &node)
   int32_t ofm_rank = _ctx.at(ofm_index).shape().rank();
   int32_t lhs_rank = _ctx.at(lhs_index).shape().rank();
   int32_t rhs_rank = _ctx.at(rhs_index).shape().rank();
+  const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
+  const auto lhs_shape = _ctx.at(lhs_index).shape().asTensor();
+  const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
 
   // not tested cases below
   assert(!(ofm_rank == 0 && lhs_rank == 0 && rhs_rank == 0));
   assert(ofm_rank < 4 && lhs_rank < 4 && rhs_rank < 4);
 
-  if (ofm_rank == 1)
-  {
-    const auto ofm_shape = _ctx.at(ofm_index).shape().asVector();
-    _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
-  }
-  else if (ofm_rank == 2)
-  {
-    const auto ofm_shape = _ctx.at(ofm_index).shape().asMatrix();
-    _builder.addShapeConstr(ofm_index,
-                            asTensorInfo(ofm_shape.H, ofm_shape.W, _ctx.at(ofm_index).type()));
-  }
-  else if (ofm_rank == 3)
-  {
-    const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
-    _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
-  }
-  else
+  if (ofm_rank > 3)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
-  if (lhs_rank == 0)
-  {
-    _builder.addShapeConstr(lhs_index, asTensorInfo(1, _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_rank == 1)
-  {
-    const auto lhs_shape = _ctx.at(lhs_index).shape().asVector();
-    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_rank == 2)
-  {
-    const auto lhs_shape = _ctx.at(lhs_index).shape().asMatrix();
-    _builder.addShapeConstr(lhs_index,
-                            asTensorInfo(lhs_shape.H, lhs_shape.W, _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_rank == 3)
-  {
-    const auto lhs_shape = _ctx.at(lhs_index).shape().asTensor();
-    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
-  }
-  else
+  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+
+  if (lhs_rank > 3)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
-  if (rhs_rank == 0)
-  {
-    _builder.addShapeConstr(rhs_index, asTensorInfo(1, _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_rank == 1)
-  {
-    const auto rhs_shape = _ctx.at(rhs_index).shape().asVector();
-    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_rank == 2)
-  {
-    const auto rhs_shape = _ctx.at(rhs_index).shape().asMatrix();
-    _builder.addShapeConstr(rhs_index,
-                            asTensorInfo(rhs_shape.H, rhs_shape.W, _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_rank == 3)
-  {
-    const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
-    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
-  }
-  else
+  _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
+
+  if (rhs_rank > 3)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
+  _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
+
   struct Param
   {
     int ofm_index;
@@ -633,65 +586,31 @@ void Planner::visit(const ::internal::tflite::op::Div::Node &node)
   // or the operand's dimension size is one.
   const auto ofm_shape = _ctx.at(ofm_index).shape();
   const auto ofm_shape_rank = ofm_shape.rank();
-  if (ofm_shape_rank == 4)
-  {
-    _builder.addShapeConstr(ofm_index,
-                            asTensorInfo(ofm_shape.asFeature(), _ctx.at(ofm_index).type()));
-  }
-  else if (ofm_shape_rank == 1)
-  {
-    _builder.addShapeConstr(ofm_index,
-                            asTensorInfo(ofm_shape.asVector(), _ctx.at(ofm_index).type()));
-  }
-  else
+  if (ofm_shape_rank > 4)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
+  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape.asTensor(), _ctx.at(ofm_index).type()));
+
   const auto lhs_shape = _ctx.at(lhs_index).shape();
   const auto lhs_shape_rank = lhs_shape.rank();
-  if (lhs_shape_rank == 4)
-  {
-    _builder.addShapeConstr(lhs_index,
-                            asTensorInfo(lhs_shape.asFeature(), _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_shape_rank == 1)
-  {
-    _builder.addShapeConstr(lhs_index,
-                            asTensorInfo(lhs_shape.asVector(), _ctx.at(lhs_index).type()));
-  }
-  else if (lhs_shape_rank == 0)
-  {
-    // scalar
-    _builder.addShapeConstr(lhs_index, asTensorInfo(1, _ctx.at(lhs_index).type()));
-  }
-  else
+  if (lhs_shape_rank > 4)
  {
     throw std::runtime_error("Not supported, yet");
   }
 
+  _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape.asTensor(), _ctx.at(lhs_index).type()));
+
   const auto rhs_shape = _ctx.at(rhs_index).shape();
   const auto rhs_shape_rank = rhs_shape.rank();
-  if (rhs_shape_rank == 4)
-  {
-    _builder.addShapeConstr(rhs_index,
-                            asTensorInfo(rhs_shape.asFeature(), _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_shape_rank == 1)
-  {
-    _builder.addShapeConstr(rhs_index,
-                            asTensorInfo(rhs_shape.asVector(), _ctx.at(rhs_index).type()));
-  }
-  else if (rhs_shape_rank == 0)
-  {
-    // scalar
-    _builder.addShapeConstr(rhs_index, asTensorInfo(1, _ctx.at(rhs_index).type()));
-  }
-  else
+  if (rhs_shape_rank > 4)
   {
     throw std::runtime_error("Not supported, yet");
   }
 
+  _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape.asTensor(), _ctx.at(rhs_index).type()));
+
   // Construct operation parameters
   struct Param
   {
@@ -1783,56 +1702,17 @@ void Planner::visit(const ::internal::tflite::op::Cast::Node &node)
   const ::internal::tflite::operand::Index output_index{node.param().output_index};
   const ::internal::tflite::operand::Index input_index{node.param().input_index};
 
-  const auto output_shape = _ctx.at(output_index).shape();
-  const auto input_shape = _ctx.at(input_index).shape();
-  assert(output_shape.rank() == input_shape.rank());
-  for (uint32_t n = 0; n < input_shape.rank(); ++n)
-  {
-    assert(output_shape.dim(n) == input_shape.dim(n));
-  }
+  const auto output_shape = _ctx.at(output_index).shape().asTensor();
+  const auto input_shape = _ctx.at(input_index).shape().asTensor();
 
-  // TODO Should move to the place where the operand is handled, if it is possible.
-  // Set Shape Constraints and TensorInfo
-  switch (input_shape.rank())
-  {
-    case 0: // scalar
-    {
-      _builder.addShapeConstr(output_index, asTensorInfo(1, _ctx.at(output_index).type(),
-                                                         _ctx.at(output_index).scale(),
-                                                         _ctx.at(output_index).zeroPoint()));
-      _builder.addShapeConstr(input_index, asTensorInfo(1, _ctx.at(input_index).type(),
-                                                        _ctx.at(input_index).scale(),
-                                                        _ctx.at(input_index).zeroPoint()));
-      break;
-    }
-    case 1: // vector
-    {
-      _builder.addShapeConstr(output_index,
-                              asTensorInfo(input_shape.asVector(), _ctx.at(output_index).type(),
-                                           _ctx.at(output_index).scale(),
-                                           _ctx.at(output_index).zeroPoint()));
-      _builder.addShapeConstr(input_index,
-                              asTensorInfo(output_shape.asVector(), _ctx.at(input_index).type(),
-                                           _ctx.at(input_index).scale(),
-                                           _ctx.at(input_index).zeroPoint()));
-      break;
-    }
-    case 4: // feature
-    {
-      _builder.addShapeConstr(output_index,
-                              asTensorInfo(input_shape.asFeature(), _ctx.at(output_index).type(),
-                                           _ctx.at(output_index).scale(),
-                                           _ctx.at(output_index).zeroPoint()));
-      _builder.addShapeConstr(input_index,
-                              asTensorInfo(output_shape.asFeature(), _ctx.at(input_index).type(),
-                                           _ctx.at(input_index).scale(),
-                                           _ctx.at(input_index).zeroPoint()));
-      break;
-    }
-    default:
-      throw std::runtime_error("Not supported, yet");
-      break;
-  }
+  assert(output_shape == input_shape);
+
+  _builder.addShapeConstr(output_index, asTensorInfo(input_shape, _ctx.at(output_index).type(),
+                                                     _ctx.at(output_index).scale(),
+                                                     _ctx.at(output_index).zeroPoint()));
+  _builder.addShapeConstr(input_index, asTensorInfo(output_shape, _ctx.at(input_index).type(),
+                                                    _ctx.at(input_index).scale(),
+                                                    _ctx.at(input_index).zeroPoint()));
 
   // Construct operation parameters
   struct Param
--
2.7.4
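Note (not part of the patch): the refactor rests on the shape-conversion helpers (`asTensorShape` / `asTensorInfo`) doing the per-rank dispatch internally, so each visit() only keeps a rank-bound guard plus one addShapeConstr call per operand. Below is a minimal standalone C++ sketch of that idea; the `Shape` struct and the return type are simplified placeholders invented for illustration, not the runtime's real arm_compute-based types.

// Sketch only: centralizing rank dispatch in one conversion helper.
// Hypothetical simplified types; the real helper builds an arm_compute shape.
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

struct Shape
{
  std::vector<uint32_t> dims; // dims.size() is the rank; empty means scalar
  uint32_t rank() const { return static_cast<uint32_t>(dims.size()); }
};

// Every supported rank is handled here, in one place,
// instead of by if/else chains at each call site.
std::vector<uint32_t> asTensorShape(const Shape &shape)
{
  if (shape.rank() > 4)
  {
    throw std::runtime_error("Not supported, yet");
  }
  if (shape.rank() == 0)
  {
    return {1}; // scalar is mapped to a one-element tensor
  }
  return shape.dims; // ranks 1..4 pass through unchanged
}

int main()
{
  const Shape scalar{{}};                // rank 0
  const Shape feature{{1, 3, 224, 224}}; // rank 4
  std::printf("scalar -> %zu dim(s), feature -> %zu dim(s)\n",
              asTensorShape(scalar).size(), asTensorShape(feature).size());
  return 0;
}

With the dispatch centralized like this, the call sites shrink to the pattern visible in the hunks above: reject ranks beyond the supported bound, then register the constraint once with `_builder.addShapeConstr(index, asTensorInfo(shape, type))`.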