[moco-tf] Introduce ShapeInferenceAlgorithm (#7871)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 1 Oct 2019 08:18:49 +0000 (17:18 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 1 Oct 2019 08:18:49 +0000 (17:18 +0900)
This will introduce ShapeInferenceAlgorithm to provide shape inferece by type of TFNode

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp

index c86963d..07cd8b3 100644 (file)
 
 namespace
 {
+
+class ShapeInferenceAlgorithm final : public moco::tf::TFNodeVisitor<loco::NodeShape>
+{
+public:
+  ShapeInferenceAlgorithm(const loco::ShapeInferenceRule::Context *ctx) : _ctx{ctx}
+  {
+    // DO NOTHING
+  }
+
+private:
+  const loco::ShapeInferenceRule::Context *_ctx;
+
+private:
+  bool shape_known(const loco::Node *node) const { return _ctx->known(node); }
+  loco::NodeShape node_shape(const loco::Node *node) const { return _ctx->get(node); }
+
+public:
+  loco::NodeShape visit(const moco::tf::TFNode *node) final
+  {
+    loco::NodeShape unknown;
+    return unknown;
+  }
+};
+
+} // namespace
+
+namespace
+{
 namespace compat
 {
 
@@ -121,7 +149,13 @@ void TFShapeInferenceRule::infer(const Context *ctx, const loco::Node *node, Sin
   assert(node->dialect() == TFDialect::get());
   assert(dynamic_cast<const TFNode *>(node) != nullptr);
 
-  sink->fail();
+  ShapeInferenceAlgorithm alg{ctx};
+  auto shape = dynamic_cast<const TFNode *>(node)->accept(&alg);
+
+  if (shape.domain() == loco::Domain::Unknown)
+    sink->fail();
+  else
+    sink->okay(shape);
 }
 
 } // namespace tf