From af51d863c86adadaac1008a51bc0d9fb5c4f5e23 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 22 Aug 2019 16:21:26 +0900 Subject: [PATCH] [moco-tf] suporting custom op when inferencing shape and type (#6818) Added shape and type support for custom op. Signed-off-by: Hyun Sik Yoon --- compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp | 9 +++++++-- compiler/moco-tf/src/Transforms/TypeInferencePass.cpp | 10 ++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp b/compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp index d82cdf1..7b075e7 100644 --- a/compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp +++ b/compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp @@ -28,6 +28,9 @@ #include #include +#include +#include + namespace moco { namespace tf @@ -37,11 +40,13 @@ bool ShapeInferencePass::run(loco::Graph *graph) { loco::CanonicalShapeInferenceRule canonical_rule; TFShapeInferenceRule tf_rule; + locoex::COpShapeInferenceRule cop_rule; // rule for custop op loco::MultiDialectShapeInferenceRule rules; - rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule); - // TODO: add CustomOp shape inference + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(TFDialect::get(), &tf_rule) + .bind(locoex::COpDialect::get(), &cop_rule); loco::apply(&rules).to(graph); diff --git a/compiler/moco-tf/src/Transforms/TypeInferencePass.cpp b/compiler/moco-tf/src/Transforms/TypeInferencePass.cpp index 331c055..efdacc5 100644 --- a/compiler/moco-tf/src/Transforms/TypeInferencePass.cpp +++ b/compiler/moco-tf/src/Transforms/TypeInferencePass.cpp @@ -24,6 +24,9 @@ #include #include +#include +#include + namespace moco { namespace tf @@ -32,11 +35,14 @@ namespace tf bool TypeInferencePass::run(loco::Graph *graph) { loco::CanonicalTypeInferenceRule canonical_rule; - TFTypeInferenceRule tf_rule; // rule for TF dialect + TFTypeInferenceRule tf_rule; // rule for TF dialect + locoex::COpTypeInferenceRule cop_rule; // rule for custop op loco::MultiDialectTypeInferenceRule rules; - rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule); + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(TFDialect::get(), &tf_rule) + .bind(locoex::COpDialect::get(), &cop_rule); loco::apply(&rules).to(graph); -- 2.7.4