[loco] Shape Inference for Eltwise Arithmetic Nodes (#6256)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 6 Aug 2019 03:32:31 +0000 (12:32 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 6 Aug 2019 03:32:31 +0000 (12:32 +0900)
This commit implements shape inference for eltwise arithmetic nodes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp

index 98d4d75..366c6e5 100644 (file)
 namespace
 {
 
+loco::NodeShape eltwise_binary_node_shape(const loco::Node *node)
+{
+  // This helper works only for binary node.
+  assert(node->arity() == 2);
+
+  auto lhs_shape = loco::shape_get(node->arg(0));
+  auto rhs_shape = loco::shape_get(node->arg(1));
+
+  // ASSERT: lhs_shape == rhs_shape
+
+  return lhs_shape;
+}
+
 /**
  * There are two possible maintenance policies.
  * - Introduce a new canonical node first, and then extend this algorithm later
@@ -46,10 +59,31 @@ public:
   // TODO Support Conv2D
   // TODO Support DepthwiseConv2D
   // TODO Support DepthwiseFilterEncode
-  // TODO Support EltwiseAdd
-  // TODO Support EltwiseDiv
-  // TODO Support EltwiseMul
-  // TODO Support EltwiseSub
+
+  // CASE: EltwiseAdd
+  loco::NodeShape visit(const loco::EltwiseAdd *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
+  // CASE: EltwiseDiv
+  loco::NodeShape visit(const loco::EltwiseDiv *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
+  // CASE: EltwiseMul
+  loco::NodeShape visit(const loco::EltwiseMul *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
+  // CASE: EltwiseSub
+  loco::NodeShape visit(const loco::EltwiseSub *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
   // TODO Support Forward
   // TODO Support FeatureBiasAdd