From: SungJin Choi/Motion Control Lab(SR)/Principal Engineer/Samsung Electronics
Date: Thu, 26 Jul 2018 01:59:33 +0000 (+0900)
Subject: Introduce addShapeConstr considering broadcasting for ADD, SUB, DIV, MUL (#1852)
X-Git-Tag: 0.2~390
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4f97a145a010ef7e8dee1cbcba523f7fd1192947;p=platform%2Fcore%2Fml%2Fnnfw.git

Introduce addShapeConstr considering broadcasting for ADD, SUB, DIV, MUL (#1852)

* Introduce addShapeConstr considering broadcasting

This commit introduces addShapeConstr considering broadcasting for ADD,
SUB, DIV, and MUL. Standalone sketches of the broadcast dispatch and the
shape mappings appear after the diff below.

Signed-off-by: SungJin Choi

* Delete useless cout statement

This commit deletes a useless cout statement.

Signed-off-by: SungJin Choi

* Modify description and typos

This commit modifies descriptions and fixes typos.

Signed-off-by: SungJin Choi
---

diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc
index 4bd219f..ae54487 100644
--- a/runtimes/pure_arm_compute/src/compilation.cc
+++ b/runtimes/pure_arm_compute/src/compilation.cc
@@ -208,9 +208,12 @@ struct IPlanBuilder
   virtual void addShapeConstr(const ::internal::tflite::operand::Index &ind,
                               const ::arm_compute::TensorInfo &info) = 0;
-  virtual void addShapeConstr(const ::internal::tflite::operand::Index &ind,
-                              const ::arm_compute::TensorInfo &info,
-                              const nnfw::util::tensor::Shape &shape) = 0;
+  virtual void addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+                              const ::internal::tflite::operand::Object &lhs_obj,
+                              const nnfw::util::tensor::Shape &lhs_shape,
+                              const ::internal::tflite::operand::Index &rhs_ind,
+                              const ::internal::tflite::operand::Object &rhs_obj,
+                              const nnfw::util::tensor::Shape &rhs_shape) = 0;
 
   virtual void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
                                     const ::internal::tflite::operand::Index &base,
                                     const ::arm_compute::Coordinates &offset,
@@ -372,10 +375,8 @@ private:
 void Planner::visit(const ::internal::tflite::op::Add::Node &node)
 {
   const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
-
   const ::internal::tflite::operand::Index lhs_index{node.param().lhs_index};
   const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
-
   const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
 
   const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
@@ -384,41 +385,9 @@ void Planner::visit(const ::internal::tflite::op::Add::Node &node)
 
   // TODO Should move to the place where the operand is handled, if it is possible.
   // Set Shape Constraints and TensorInfo
-  _builder.addShapeConstr(ofm_index,
-                          asTensorInfo(ofm_shape, _ctx.at(ofm_index).type(),
-                                       _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
-
-  if (lhs_shape.rank() == 4 && rhs_shape.rank() < 4)
-  {
-    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
-                                                    _ctx.at(lhs_index).scale(),
-                                                    _ctx.at(lhs_index).zeroPoint()));
-    _builder.addShapeConstr(rhs_index,
-                            asTensorInfoForBroadcast(rhs_shape, _ctx.at(rhs_index).type(),
-                                                     _ctx.at(ofm_index).scale(),
-                                                     _ctx.at(ofm_index).zeroPoint()),
-                            rhs_shape);
-  }
-  else if (rhs_shape.rank() == 4 && lhs_shape.rank() < 4)
-  {
-    _builder.addShapeConstr(lhs_index,
-                            asTensorInfoForBroadcast(lhs_shape, _ctx.at(lhs_index).type(),
-                                                     _ctx.at(lhs_index).scale(),
-                                                     _ctx.at(lhs_index).zeroPoint()),
-                            lhs_shape);
-    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
-                                                    _ctx.at(rhs_index).scale(),
-                                                    _ctx.at(rhs_index).zeroPoint()));
-  }
-  else
-  {
-    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
-                                                    _ctx.at(lhs_index).scale(),
-                                                    _ctx.at(lhs_index).zeroPoint()));
-    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
-                                                    _ctx.at(rhs_index).scale(),
-                                                    _ctx.at(rhs_index).zeroPoint()));
-  }
+  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+  _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+                          rhs_shape);
 
   // Construct operation parameters
   struct Param
@@ -479,10 +448,8 @@ void Planner::visit(const ::internal::tflite::op::Add::Node &node)
 void Planner::visit(const ::internal::tflite::op::Sub::Node &node)
 {
   const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
-
   const ::internal::tflite::operand::Index lhs_index{node.param().lhs_index};
   const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
-
   const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
 
   const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
@@ -490,8 +457,8 @@ void Planner::visit(const ::internal::tflite::op::Sub::Node &node)
   const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
 
   _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
-  _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
-  _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
+  _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+                          rhs_shape);
 
   // Construct operation parameters
   struct Param
@@ -537,43 +504,13 @@ void Planner::visit(const ::internal::tflite::op::Mul::Node &node)
   const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
   const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
 
-  int32_t ofm_rank = _ctx.at(ofm_index).shape().rank();
-  int32_t lhs_rank = _ctx.at(lhs_index).shape().rank();
-  int32_t rhs_rank = _ctx.at(rhs_index).shape().rank();
   const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
   const auto lhs_shape = _ctx.at(lhs_index).shape().asTensor();
   const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
 
-  // not tested cases below
-  assert(!(ofm_rank == 0 && lhs_rank == 0 && rhs_rank == 0));
-  assert(ofm_rank < 4 && lhs_rank < 4 && rhs_rank < 4);
-
-  if (ofm_rank > 3)
-  {
-    throw std::runtime_error("Not supported, yet");
-  }
-
-  _builder.addShapeConstr(ofm_index,
-                          asTensorInfo(ofm_shape, _ctx.at(ofm_index).type(),
-                                       _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
-
-  if (lhs_rank > 3)
-  {
-    throw std::runtime_error("Not supported, yet");
-  }
-
-  _builder.addShapeConstr(lhs_index,
-                          asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
-                                       _ctx.at(lhs_index).scale(), _ctx.at(lhs_index).zeroPoint()));
-
-  if (rhs_rank > 3)
-  {
-    throw std::runtime_error("Not supported, yet");
-  }
-
-  _builder.addShapeConstr(rhs_index,
-                          asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
-                                       _ctx.at(rhs_index).scale(), _ctx.at(rhs_index).zeroPoint()));
+  _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+  _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+                          rhs_shape);
 
   struct Param
   {
@@ -2671,9 +2608,12 @@ public:
 public:
   void addShapeConstr(const ::internal::tflite::operand::Index &ind,
                       const ::arm_compute::TensorInfo &info) override;
-  void addShapeConstr(const ::internal::tflite::operand::Index &ind,
-                      const ::arm_compute::TensorInfo &info,
-                      const nnfw::util::tensor::Shape &shape) override;
+  void addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+                      const ::internal::tflite::operand::Object &lhs_obj,
+                      const nnfw::util::tensor::Shape &lhs_shape,
+                      const ::internal::tflite::operand::Index &rhs_ind,
+                      const ::internal::tflite::operand::Object &rhs_obj,
+                      const nnfw::util::tensor::Shape &rhs_shape) override;
 
 public:
   void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
@@ -2735,41 +2675,45 @@ void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &ind,
 {
   _tensor_info_ctx[ind.asInt()] = info;
 }
 
-void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &ind,
-                                 const ::arm_compute::TensorInfo &info,
-                                 const nnfw::util::tensor::Shape &shape)
-{
-  // ACL tensor info
-  _tensor_info_ctx[ind.asInt()] = info;
-  // broadcasting tensor shape
-  internal::tflite::operand::Shape broadcastShape(4);
-  if (shape.rank() == 1)
-  {
-    broadcastShape.dim(0) = 1;
-    broadcastShape.dim(1) = 1;
-    broadcastShape.dim(2) = 1;
-    broadcastShape.dim(3) = shape.dim(0);
-  }
-  else if (shape.rank() == 2)
-  {
-    broadcastShape.dim(0) = 1;
-    broadcastShape.dim(1) = 1;
-    broadcastShape.dim(2) = shape.dim(0);
-    broadcastShape.dim(3) = shape.dim(1);
+// Add tensor shape constraints considering broadcasting
+void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+                                 const ::internal::tflite::operand::Object &lhs_obj,
+                                 const nnfw::util::tensor::Shape &lhs_shape,
+                                 const ::internal::tflite::operand::Index &rhs_ind,
+                                 const ::internal::tflite::operand::Object &rhs_obj,
+                                 const nnfw::util::tensor::Shape &rhs_shape)
+{
+  // right-side broadcasting
+  if (lhs_shape.rank() > rhs_shape.rank())
+  {
+    // ACL tensor info
+    _tensor_info_ctx[lhs_ind.asInt()] = asTensorInfo(lhs_shape, lhs_obj.type());
+    _tensor_info_ctx[rhs_ind.asInt()] =
+        asTensorInfoForBroadcast(rhs_shape, rhs_obj.type(), lhs_shape.rank());
+
+    // TFLite broadcasting tensor shape
+    if (lhs_shape.rank() == 4)
+      _broadcasting_tensor_shape.emplace(rhs_ind.asInt(),
+                                         asTensorShapeForTFLiteBroadcast(rhs_shape));
   }
-  else if (shape.rank() == 3)
+  // left-side broadcasting
+  else if (lhs_shape.rank() < rhs_shape.rank())
   {
-    broadcastShape.dim(0) = 1;
-    broadcastShape.dim(1) = shape.dim(0);
-    broadcastShape.dim(2) = shape.dim(1);
-    broadcastShape.dim(3) = shape.dim(2);
+    _tensor_info_ctx[lhs_ind.asInt()] =
+        asTensorInfoForBroadcast(lhs_shape, lhs_obj.type(), rhs_shape.rank());
+    _tensor_info_ctx[rhs_ind.asInt()] = asTensorInfo(rhs_shape, rhs_obj.type());
+
+    if (rhs_shape.rank() == 4)
+      _broadcasting_tensor_shape.emplace(lhs_ind.asInt(),
+                                         asTensorShapeForTFLiteBroadcast(lhs_shape));
   }
+  // no broadcasting
   else
   {
-    throw std::runtime_error("Not supported, yet");
+    _tensor_info_ctx[lhs_ind.asInt()] = asTensorInfo(lhs_shape, lhs_obj.type());
+    _tensor_info_ctx[rhs_ind.asInt()] = asTensorInfo(rhs_shape, rhs_obj.type());
   }
-  _broadcasting_tensor_shape.emplace(ind.asInt(), broadcastShape);
 }
 
 void PlanBuilder::addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
diff --git a/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h b/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h
index 6acfcf3..35b07e9 100644
--- a/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h
+++ b/runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h
@@ -49,33 +49,91 @@
   }
 }
 
-// in case of NHWC
+// ACL Broadcasting style in case of NHWC
 // TODO HCHW
-::arm_compute::TensorShape asTensorShapeForBroadcast(const nnfw::util::tensor::Shape &shape)
+::arm_compute::TensorShape asTensorShapeForBroadcast(const nnfw::util::tensor::Shape &shape,
+                                                     const size_t baseRank)
 {
-  if (shape.rank() == 1)
+  // When the larger rank (baseRank) is 4 and the smaller rank is less than 4, the shape has to be
+  // transformed into a broadcasting TensorInfo because the dimension order differs.
+  if (baseRank == 4)
   {
-    return ::arm_compute::TensorShape(1, 1, shape.dim(0), 1);
+    if (shape.rank() == 0)
+    {
+      return ::arm_compute::TensorShape(1);
+    }
+    else if (shape.rank() == 1)
+    {
+      return ::arm_compute::TensorShape(1, 1, shape.dim(0), 1);
+    }
+    else if (shape.rank() == 2)
+    {
+      return ::arm_compute::TensorShape(shape.dim(0), 1, shape.dim(1), 1); // w c -> w h c n
+    }
+    else if (shape.rank() == 3)
+    {
+      return ::arm_compute::TensorShape(shape.dim(1), shape.dim(0), shape.dim(2),
+                                        1); // h w c -> w h c n
+    }
+    else if (shape.rank() == 4)
+    {
+      assert(shape.dim(0) ==
+             1); // ACL OpenCL supports only up to 3D inputs for ADD, SUB, MUL, and DIV.
+      return ::arm_compute::TensorShape(shape.dim(2), shape.dim(1), shape.dim(3),
+                                        shape.dim(0)); // n h w c -> W H C N
+    }
+    else
+    {
+      throw std::runtime_error("Not supported, yet");
+    }
   }
-  else if (shape.rank() == 2)
+  // When the larger rank is <= 3, no transformation is needed because the broadcast shape is the
+  // same as the original. For example, ::arm_compute::TensorShape(shape.dim(0), 1, 1) ==
+  // ::arm_compute::TensorShape(shape.dim(0)).
+  else
   {
-    return ::arm_compute::TensorShape(shape.dim(0), 1, shape.dim(1), 1); // w c -> w h c n
+    return asTensorShape(shape);
   }
-  else if (shape.rank() == 3)
+}
+
+// TFLite broadcasting style: used for reading an input as a broadcast shape
+internal::tflite::operand::Shape
+asTensorShapeForTFLiteBroadcast(const nnfw::util::tensor::Shape &shape)
+{
+  internal::tflite::operand::Shape broadcastShape(4);
+  if (shape.rank() == 0)
   {
-    return ::arm_compute::TensorShape(shape.dim(1), shape.dim(0), shape.dim(2),
-                                      1); // h w c -> w h c n
+    broadcastShape.dim(0) = 1;
+    broadcastShape.dim(1) = 1;
+    broadcastShape.dim(2) = 1;
+    broadcastShape.dim(3) = 1;
   }
-  else if (shape.rank() == 4)
+  else if (shape.rank() == 1)
   {
-    assert(shape.dim(0) == 1); // In case of ADD, SUB, 3D inputs are supported.
-    return ::arm_compute::TensorShape(shape.dim(2), shape.dim(1), shape.dim(3),
-                                      shape.dim(0)); // n h w c -> W H C N
+    broadcastShape.dim(0) = 1;
+    broadcastShape.dim(1) = 1;
+    broadcastShape.dim(2) = 1;
+    broadcastShape.dim(3) = shape.dim(0);
+  }
+  else if (shape.rank() == 2)
+  {
+    broadcastShape.dim(0) = 1;
+    broadcastShape.dim(1) = 1;
+    broadcastShape.dim(2) = shape.dim(0);
+    broadcastShape.dim(3) = shape.dim(1);
+  }
+  else if (shape.rank() == 3)
+  {
+    broadcastShape.dim(0) = 1;
+    broadcastShape.dim(1) = shape.dim(0);
+    broadcastShape.dim(2) = shape.dim(1);
+    broadcastShape.dim(3) = shape.dim(2);
   }
   else
   {
     throw std::runtime_error("Not supported, yet");
   }
+  return broadcastShape;
 }
 
 inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape)
@@ -146,10 +204,11 @@ inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand:
 }
 
 ::arm_compute::TensorInfo asTensorInfoForBroadcast(const nnfw::util::tensor::Shape &shape,
-                                                   const int32_t type, const float scale = 0.0f,
+                                                   const int32_t type, const size_t baseRank,
+                                                   const float scale = 0.0f,
                                                    const int32_t zeroPoint = 0)
 {
-  return ::arm_compute::TensorInfo(asTensorShapeForBroadcast(shape), 1, asDataType(type),
+  return ::arm_compute::TensorInfo(asTensorShapeForBroadcast(shape, baseRank), 1, asDataType(type),
                                    asQuantizationInfo(scale, zeroPoint));
 }
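
For reference, the decision made by the new PlanBuilder::addShapeConstr reduces
to comparing the ranks of the two operands: the operand with the smaller rank
receives the broadcast TensorInfo, and the TFLite-style broadcast shape is
registered only when the larger rank is 4. The following minimal sketch mirrors
that dispatch; Side and pickBroadcastSide are illustrative names that do not
exist in the nnfw codebase:

  #include <cstddef>
  #include <iostream>

  // Which operand receives the broadcast TensorInfo in addShapeConstr:
  // the one with the smaller rank. Equal ranks mean no broadcasting.
  enum class Side
  {
    Left,
    Right,
    None
  };

  Side pickBroadcastSide(std::size_t lhs_rank, std::size_t rhs_rank)
  {
    if (lhs_rank > rhs_rank)
      return Side::Right; // right-side broadcasting
    if (lhs_rank < rhs_rank)
      return Side::Left; // left-side broadcasting
    return Side::None;
  }

  int main()
  {
    // A 4D feature map combined with a 1D operand broadcasts the 1D side;
    // only in the rank-4 case is the TFLite broadcast shape registered too.
    std::cout << (pickBroadcastSide(4, 1) == Side::Right) << '\n'; // prints 1
    std::cout << (pickBroadcastSide(2, 2) == Side::None) << '\n';  // prints 1
  }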
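
The reordering that asTensorShapeForBroadcast performs when baseRank == 4 maps
TFLite-ordered dimensions onto ACL's (W, H, C, N) constructor order. Below is a
sketch of the same rank-by-rank mapping, assuming plain vectors in place of
::arm_compute::TensorShape; Dims and toWHCN are illustrative names, not ACL or
nnfw functions:

  #include <cassert>
  #include <cstdint>
  #include <stdexcept>
  #include <vector>

  using Dims = std::vector<uint32_t>; // illustrative stand-in for a tensor shape

  // Reorder TFLite-ordered dims into ACL constructor order (W, H, C, N),
  // following the rank-by-rank cases of asTensorShapeForBroadcast.
  Dims toWHCN(const Dims &shape)
  {
    switch (shape.size())
    {
      case 0: // scalar
        return {1};
      case 1: // c -> w h c n
        return {1, 1, shape[0], 1};
      case 2: // w c -> w h c n
        return {shape[0], 1, shape[1], 1};
      case 3: // h w c -> w h c n
        return {shape[1], shape[0], shape[2], 1};
      case 4: // n h w c -> w h c n; batch must be 1 on ACL OpenCL
        assert(shape[0] == 1);
        return {shape[2], shape[1], shape[3], shape[0]};
      default:
        throw std::runtime_error("Not supported, yet");
    }
  }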
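
asTensorShapeForTFLiteBroadcast, by contrast, keeps TFLite dimension order and
simply left-pads the shape with 1s up to rank 4. A sketch under the same
caveat, with expandTo4D as an illustrative name:

  #include <cstddef>
  #include <cstdint>
  #include <stdexcept>
  #include <vector>

  // Left-pad a TFLite shape with 1s up to rank 4, as the rank-0..3 branches
  // of asTensorShapeForTFLiteBroadcast do. Ranks above 3 are rejected,
  // matching the "Not supported, yet" branch.
  std::vector<uint32_t> expandTo4D(const std::vector<uint32_t> &shape)
  {
    if (shape.size() > 3)
      throw std::runtime_error("Not supported, yet");
    std::vector<uint32_t> out(4, 1); // start from {1, 1, 1, 1}
    for (std::size_t i = 0; i < shape.size(); ++i)
      out[4 - shape.size() + i] = shape[i]; // copy dims into trailing slots
    return out;
  }

  // expandTo4D({})        -> {1, 1, 1, 1}
  // expandTo4D({3})       -> {1, 1, 1, 3}
  // expandTo4D({5, 3})    -> {1, 1, 5, 3}
  // expandTo4D({2, 5, 3}) -> {1, 2, 5, 3}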