#include <loco/IR/CanonicalNode.h>
#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/IR/CanonicalDialect.h>
#include <loco/Service/TypeInference.h>
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpTypeInference.h>
+
#include <stdex/Memory.h>
#include <type_traits>
void TypeInference::run(loco::Graph *g)
{
- loco::CanonicalTypeInferenceRule rule;
- loco::apply(&rule).to(g);
+ loco::CanonicalTypeInferenceRule canonical_rule;
+ locoex::COpTypeInferenceRule cop_rule; // rule for custom op
+
+ loco::MultiDialectTypeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::COpDialect::get(), &cop_rule);
+
+ loco::apply(&rules).to(g);
}
tflite::TensorType TypeInference::get(loco::Node *node)