[exo-tflite] Shapeinf for EltwiseAdd and EltwiseMul (#6099)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 1 Aug 2019 07:51:49 +0000 (16:51 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 1 Aug 2019 07:51:49 +0000 (16:51 +0900)
This will add to support ShapeInference for EltwiseAdd and EltwiseMul

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/exo-tflite/src/ShapeInference.cpp

index 0109105..2891b55 100644 (file)
@@ -79,6 +79,8 @@ public:
   NODE(BiasEncode)
   NODE(TensorBiasAdd)
   NODE(FeatureBiasAdd)
+  NODE(EltwiseAdd)
+  NODE(EltwiseMul)
 #undef NODE
   // TODO Put all the visit method implementations inside this class declaration
   ShapeDescription visit(loco::ReLU6 *node) { return gd._node_to_shape[node->input()]; }
@@ -425,6 +427,26 @@ ShapeDescription ShapeGetter::visit(loco::FeatureBiasAdd *node)
   return value_shape;
 }
 
+ShapeDescription ShapeGetter::visit(loco::EltwiseAdd *node)
+{
+  const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
+  const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
+
+  assert(lhs_shape._dims == rhs_shape._dims);
+
+  return lhs_shape;
+}
+
+ShapeDescription ShapeGetter::visit(loco::EltwiseMul *node)
+{
+  const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
+  const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
+
+  assert(lhs_shape._dims == rhs_shape._dims);
+
+  return lhs_shape;
+}
+
 } // namespace
 
 namespace