[moco-tf] suporting custom op when inferencing shape and type (#6818)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Thu, 22 Aug 2019 07:21:26 +0000 (16:21 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 22 Aug 2019 07:21:26 +0000 (16:21 +0900)
Added shape and type support for custom op.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp
compiler/moco-tf/src/Transforms/TypeInferencePass.cpp

index d82cdf1..7b075e7 100644 (file)
@@ -28,6 +28,9 @@
 #include <loco/Service/CanonicalShapeInferenceRule.h>
 #include <loco/Service/MultiDialectShapeInferenceRule.h>
 
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpShapeInferenceRule.h>
+
 namespace moco
 {
 namespace tf
@@ -37,11 +40,13 @@ bool ShapeInferencePass::run(loco::Graph *graph)
 {
   loco::CanonicalShapeInferenceRule canonical_rule;
   TFShapeInferenceRule tf_rule;
+  locoex::COpShapeInferenceRule cop_rule; // rule for custop op
 
   loco::MultiDialectShapeInferenceRule rules;
 
-  rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule);
-  // TODO: add CustomOp shape inference
+  rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+      .bind(TFDialect::get(), &tf_rule)
+      .bind(locoex::COpDialect::get(), &cop_rule);
 
   loco::apply(&rules).to(graph);
 
index 331c055..efdacc5 100644 (file)
@@ -24,6 +24,9 @@
 #include <loco/IR/CanonicalDialect.h>
 #include <loco/Service/TypeInference.h>
 
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpTypeInference.h>
+
 namespace moco
 {
 namespace tf
@@ -32,11 +35,14 @@ namespace tf
 bool TypeInferencePass::run(loco::Graph *graph)
 {
   loco::CanonicalTypeInferenceRule canonical_rule;
-  TFTypeInferenceRule tf_rule; // rule for TF dialect
+  TFTypeInferenceRule tf_rule;           // rule for TF dialect
+  locoex::COpTypeInferenceRule cop_rule; // rule for custop op
 
   loco::MultiDialectTypeInferenceRule rules;
 
-  rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule);
+  rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+      .bind(TFDialect::get(), &tf_rule)
+      .bind(locoex::COpDialect::get(), &cop_rule);
 
   loco::apply(&rules).to(graph);