#include "loco/IR/Dialect.h"
#include "loco/IR/Graph.h"
+#include <map>
+
/**
* @file This file implements dialect-agnostic type inference framework.
*
bool infer(const Node *, DataType &) const final;
};
+/**
+ * @brief Type Inference Rule for multiple dialects
+ */
+class MultiDialectTypeInferenceRule final : public TypeInferenceRule
+{
+public:
+ bool recognize(const Dialect *) const final;
+ bool infer(const Node *, DataType &) const final;
+
+ /// @brief Bind a specific rule to a Dialect
+ MultiDialectTypeInferenceRule &bind(const Dialect *d, const TypeInferenceRule *rule);
+
+private:
+ std::map<const Dialect *, const TypeInferenceRule *> _rules;
+};
+
class TypeInferenceSession
{
public:
return true;
}
+bool MultiDialectTypeInferenceRule::recognize(const Dialect *d) const
+{
+ const auto found = _rules.find(d);
+
+ if (found == _rules.cend())
+ return false;
+
+ auto rule = found->second;
+ auto result = rule->recognize(d);
+
+ assert(result);
+
+ return result;
+}
+
+bool MultiDialectTypeInferenceRule::infer(const Node *node, DataType &dtype) const
+{
+ const auto found = _rules.find(node->dialect());
+
+ if (found == _rules.cend())
+ return false;
+
+ auto rule = found->second;
+ if (rule->infer(node, dtype))
+ return true;
+
+ return false;
+}
+
+MultiDialectTypeInferenceRule &MultiDialectTypeInferenceRule::bind(const Dialect *d,
+ const TypeInferenceRule *rule)
+{
+ assert(_rules.find(d) == _rules.end());
+ assert(rule->recognize(d));
+
+ _rules[d] = rule;
+
+ return (*this);
+}
+
} // namespace loco
ASSERT_TRUE(loco::dtype_known(relu6_node));
ASSERT_EQ(loco::dtype_get(relu6_node), loco::DataType::FLOAT32);
}
+
+// mockup for MultiDialectTypeInferenceRule
+// OpNode of a specific loco datatype (defined in template) will be used.
+// And a Dialect for the OpNode and its inference rules are created.
+#include <loco/IR/Dialect.h>
+
+namespace
+{
+
+template <loco::DataType N> class TestDialect final : public loco::Dialect
+{
+public:
+ static Dialect *get(void)
+ {
+ static TestDialect<N> d;
+ return &d;
+ }
+};
+
+template <loco::DataType N>
+struct TestOpNode final : public loco::FixedArityNode<1, loco::Node>,
+ public loco::NodeMixin<loco::NodeTrait::DataType>
+{
+ void input(Node *node) { at(0)->node(node); }
+ const loco::Dialect *dialect(void) const final { return TestDialect<N>::get(); }
+ uint32_t opnum(void) const final { return static_cast<uint32_t>(N); }
+};
+
+template <loco::DataType N> struct TestTypeInferenceRule final : public loco::TypeInferenceRule
+{
+public:
+ bool recognize(const loco::Dialect *d) const final { return (d == TestDialect<N>::get()); }
+
+ bool infer(const loco::Node *node, loco::DataType &dtype) const final
+ {
+ assert(node->dialect() == TestDialect<N>::get());
+ auto test_node = dynamic_cast<const TestOpNode<N> *>(node);
+ assert(test_node != nullptr);
+
+ dtype = N;
+ return true;
+ }
+};
+
+} // namespace
+
+TEST(MultiDialectTypeInferenceRuleTest, test1)
+{
+ // Create a simple network : Pull - S8 - U8 - Push
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+ pull_node->dtype(loco::DataType::FLOAT32);
+
+ auto s8_node = g->nodes()->create<TestOpNode<loco::DataType::S8>>();
+ s8_node->input(pull_node);
+
+ auto u8_node = g->nodes()->create<TestOpNode<loco::DataType::U8>>();
+ u8_node->input(s8_node);
+
+ auto push_node = g->nodes()->create<loco::Push>();
+ push_node->from(u8_node);
+
+ auto graph_input = g->inputs()->create();
+ graph_input->name("input");
+ loco::link(graph_input, pull_node);
+
+ auto graph_output = g->outputs()->create();
+ graph_output->name("output");
+ loco::link(graph_output, push_node);
+
+ // initially they don't have type info
+ ASSERT_FALSE(loco::dtype_known(s8_node));
+ ASSERT_FALSE(loco::dtype_known(u8_node));
+
+ // Run Type Inference
+ TestTypeInferenceRule<loco::DataType::U8> u8_rule;
+ TestTypeInferenceRule<loco::DataType::S8> s8_rule;
+
+ loco::MultiDialectTypeInferenceRule rules;
+
+ rules.bind(TestDialect<loco::DataType::S8>::get(), &s8_rule)
+ .bind(TestDialect<loco::DataType::U8>::get(), &u8_rule);
+
+ loco::apply(&rules).to(g.get());
+
+ // Verify!
+ ASSERT_TRUE(loco::dtype_known(s8_node));
+ ASSERT_EQ(loco::dtype_get(s8_node), loco::DataType::S8);
+
+ ASSERT_TRUE(loco::dtype_known(u8_node));
+ ASSERT_EQ(loco::dtype_get(u8_node), loco::DataType::U8);
+}