From 562867e32c940acd13617f4834372afbcdba9737 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 1 Aug 2019 16:51:49 +0900 Subject: [PATCH] [exo-tflite] Shapeinf for EltwiseAdd and EltwiseMul (#6099) This will add to support ShapeInference for EltwiseAdd and EltwiseMul Signed-off-by: SaeHie Park --- compiler/exo-tflite/src/ShapeInference.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/compiler/exo-tflite/src/ShapeInference.cpp b/compiler/exo-tflite/src/ShapeInference.cpp index 0109105..2891b55 100644 --- a/compiler/exo-tflite/src/ShapeInference.cpp +++ b/compiler/exo-tflite/src/ShapeInference.cpp @@ -79,6 +79,8 @@ public: NODE(BiasEncode) NODE(TensorBiasAdd) NODE(FeatureBiasAdd) + NODE(EltwiseAdd) + NODE(EltwiseMul) #undef NODE // TODO Put all the visit method implementations inside this class declaration ShapeDescription visit(loco::ReLU6 *node) { return gd._node_to_shape[node->input()]; } @@ -425,6 +427,26 @@ ShapeDescription ShapeGetter::visit(loco::FeatureBiasAdd *node) return value_shape; } +ShapeDescription ShapeGetter::visit(loco::EltwiseAdd *node) +{ + const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()]; + const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()]; + + assert(lhs_shape._dims == rhs_shape._dims); + + return lhs_shape; +} + +ShapeDescription ShapeGetter::visit(loco::EltwiseMul *node) +{ + const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()]; + const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()]; + + assert(lhs_shape._dims == rhs_shape._dims); + + return lhs_shape; +} + } // namespace namespace -- 2.7.4