From 75bff3133702e1cd19b4ddf52c6c06f9407064b4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 6 Aug 2019 12:53:31 +0900 Subject: [PATCH] [exo-tflite] Shape inf. for EltwiseSub and EltwiseDiv (#6264) This will enable shape inference of EltwiseSub and EltwiseDiv node Signed-off-by: SaeHie Park --- compiler/exo-tflite/src/ShapeInference.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/compiler/exo-tflite/src/ShapeInference.cpp b/compiler/exo-tflite/src/ShapeInference.cpp index 4d37440..f14486d 100644 --- a/compiler/exo-tflite/src/ShapeInference.cpp +++ b/compiler/exo-tflite/src/ShapeInference.cpp @@ -103,6 +103,8 @@ public: NODE(FeatureBiasAdd) NODE(EltwiseAdd) NODE(EltwiseMul) + NODE(EltwiseSub) + NODE(EltwiseDiv) #undef NODE // TODO Put all the visit method implementations inside this class declaration ShapeDescription visit(loco::ReLU6 *node) { return gd._node_to_shape[node->input()]; } @@ -469,6 +471,25 @@ ShapeDescription ShapeGetter::visit(loco::EltwiseMul *node) return lhs_shape; } +ShapeDescription ShapeGetter::visit(loco::EltwiseSub *node) +{ + const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()]; + const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()]; + + assert(lhs_shape._dims == rhs_shape._dims); + + return lhs_shape; +} + +ShapeDescription ShapeGetter::visit(loco::EltwiseDiv *node) +{ + const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()]; + const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()]; + + assert(lhs_shape._dims == rhs_shape._dims); + + return lhs_shape; +} } // namespace namespace -- 2.7.4