#include <loco/Service/CanonicalShapeInferenceRule.h>
#include <loco/Service/MultiDialectShapeInferenceRule.h>
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpShapeInferenceRule.h>
+
namespace moco
{
namespace tf
{
loco::CanonicalShapeInferenceRule canonical_rule;
TFShapeInferenceRule tf_rule;
+ locoex::COpShapeInferenceRule cop_rule; // rule for custop op
loco::MultiDialectShapeInferenceRule rules;
- rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule);
- // TODO: add CustomOp shape inference
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(TFDialect::get(), &tf_rule)
+ .bind(locoex::COpDialect::get(), &cop_rule);
loco::apply(&rules).to(graph);
#include <loco/IR/CanonicalDialect.h>
#include <loco/Service/TypeInference.h>
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpTypeInference.h>
+
namespace moco
{
namespace tf
bool TypeInferencePass::run(loco::Graph *graph)
{
loco::CanonicalTypeInferenceRule canonical_rule;
- TFTypeInferenceRule tf_rule; // rule for TF dialect
+ TFTypeInferenceRule tf_rule; // rule for TF dialect
+ locoex::COpTypeInferenceRule cop_rule; // rule for custop op
loco::MultiDialectTypeInferenceRule rules;
- rules.bind(loco::CanonicalDialect::get(), &canonical_rule).bind(TFDialect::get(), &tf_rule);
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(TFDialect::get(), &tf_rule)
+ .bind(locoex::COpDialect::get(), &cop_rule);
loco::apply(&rules).to(graph);