[exo-tflite] Shape inf. for EltwiseSub and EltwiseDiv (#6264)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 6 Aug 2019 03:53:31 +0000 (12:53 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 6 Aug 2019 03:53:31 +0000 (12:53 +0900)
This will enable shape inference of EltwiseSub and EltwiseDiv node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/exo-tflite/src/ShapeInference.cpp

index 4d37440..f14486d 100644 (file)
@@ -103,6 +103,8 @@ public:
   NODE(FeatureBiasAdd)
   NODE(EltwiseAdd)
   NODE(EltwiseMul)
+  NODE(EltwiseSub)
+  NODE(EltwiseDiv)
 #undef NODE
   // TODO Put all the visit method implementations inside this class declaration
   ShapeDescription visit(loco::ReLU6 *node) { return gd._node_to_shape[node->input()]; }
@@ -469,6 +471,25 @@ ShapeDescription ShapeGetter::visit(loco::EltwiseMul *node)
   return lhs_shape;
 }
 
+ShapeDescription ShapeGetter::visit(loco::EltwiseSub *node)
+{
+  const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
+  const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
+
+  assert(lhs_shape._dims == rhs_shape._dims);
+
+  return lhs_shape;
+}
+
+ShapeDescription ShapeGetter::visit(loco::EltwiseDiv *node)
+{
+  const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
+  const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
+
+  assert(lhs_shape._dims == rhs_shape._dims);
+
+  return lhs_shape;
+}
 } // namespace
 
 namespace