[loco] TypeInferenceRule to use inference rules of multiple dialects (#6198)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Tue, 6 Aug 2019 08:13:10 +0000 (17:13 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 6 Aug 2019 08:13:10 +0000 (17:13 +0900)
* [loco] TypeInferenceRule to use inference rules of multiple dialects

This commit adds MultiDialectTypeInferenceRule that uses inference rules of multiple dialects.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* using loco::link(..)

* using loco::link(..) for output too

* assert

* bind method

* struct to class

* //// -> ///

* modifying assert

compiler/loco/include/loco/Service/TypeInference.h
compiler/loco/src/Service/TypeInference.cpp
compiler/loco/src/Service/TypeInference.test.cpp

index cb72fd8..46ad312 100644 (file)
@@ -23,6 +23,8 @@
 #include "loco/IR/Dialect.h"
 #include "loco/IR/Graph.h"
 
+#include <map>
+
 /**
  * @file This file implements dialect-agnostic type inference framework.
  *
@@ -63,6 +65,22 @@ struct CanonicalTypeInferenceRule final : public TypeInferenceRule
   bool infer(const Node *, DataType &) const final;
 };
 
+/**
+ * @brief Type Inference Rule for multiple dialects
+ */
+class MultiDialectTypeInferenceRule final : public TypeInferenceRule
+{
+public:
+  bool recognize(const Dialect *) const final;
+  bool infer(const Node *, DataType &) const final;
+
+  /// @brief Bind a specific rule to a Dialect
+  MultiDialectTypeInferenceRule &bind(const Dialect *d, const TypeInferenceRule *rule);
+
+private:
+  std::map<const Dialect *, const TypeInferenceRule *> _rules;
+};
+
 class TypeInferenceSession
 {
 public:
index 7e70f48..3c18cb3 100644 (file)
@@ -141,4 +141,44 @@ bool CanonicalTypeInferenceRule::infer(const Node *node, DataType &dtype) const
   return true;
 }
 
+bool MultiDialectTypeInferenceRule::recognize(const Dialect *d) const
+{
+  const auto found = _rules.find(d);
+
+  if (found == _rules.cend())
+    return false;
+
+  auto rule = found->second;
+  auto result = rule->recognize(d);
+
+  assert(result);
+
+  return result;
+}
+
+bool MultiDialectTypeInferenceRule::infer(const Node *node, DataType &dtype) const
+{
+  const auto found = _rules.find(node->dialect());
+
+  if (found == _rules.cend())
+    return false;
+
+  auto rule = found->second;
+  if (rule->infer(node, dtype))
+    return true;
+
+  return false;
+}
+
+MultiDialectTypeInferenceRule &MultiDialectTypeInferenceRule::bind(const Dialect *d,
+                                                                   const TypeInferenceRule *rule)
+{
+  assert(_rules.find(d) == _rules.end());
+  assert(rule->recognize(d));
+
+  _rules[d] = rule;
+
+  return (*this);
+}
+
 } // namespace loco
index 238dde3..9c65946 100644 (file)
@@ -163,3 +163,96 @@ TEST(CanonicalTypeInferenceRuleTest, relu6)
   ASSERT_TRUE(loco::dtype_known(relu6_node));
   ASSERT_EQ(loco::dtype_get(relu6_node), loco::DataType::FLOAT32);
 }
+
+// mockup for MultiDialectTypeInferenceRule
+// OpNode of a specific loco datatype (defined in template) will be used.
+// And a Dialect for the OpNode and its inference rules are created.
+#include <loco/IR/Dialect.h>
+
+namespace
+{
+
+template <loco::DataType N> class TestDialect final : public loco::Dialect
+{
+public:
+  static Dialect *get(void)
+  {
+    static TestDialect<N> d;
+    return &d;
+  }
+};
+
+template <loco::DataType N>
+struct TestOpNode final : public loco::FixedArityNode<1, loco::Node>,
+                          public loco::NodeMixin<loco::NodeTrait::DataType>
+{
+  void input(Node *node) { at(0)->node(node); }
+  const loco::Dialect *dialect(void) const final { return TestDialect<N>::get(); }
+  uint32_t opnum(void) const final { return static_cast<uint32_t>(N); }
+};
+
+template <loco::DataType N> struct TestTypeInferenceRule final : public loco::TypeInferenceRule
+{
+public:
+  bool recognize(const loco::Dialect *d) const final { return (d == TestDialect<N>::get()); }
+
+  bool infer(const loco::Node *node, loco::DataType &dtype) const final
+  {
+    assert(node->dialect() == TestDialect<N>::get());
+    auto test_node = dynamic_cast<const TestOpNode<N> *>(node);
+    assert(test_node != nullptr);
+
+    dtype = N;
+    return true;
+  }
+};
+
+} // namespace
+
+TEST(MultiDialectTypeInferenceRuleTest, test1)
+{
+  // Create a simple network : Pull - S8 - U8 - Push
+  auto g = loco::make_graph();
+
+  auto pull_node = g->nodes()->create<loco::Pull>();
+  pull_node->dtype(loco::DataType::FLOAT32);
+
+  auto s8_node = g->nodes()->create<TestOpNode<loco::DataType::S8>>();
+  s8_node->input(pull_node);
+
+  auto u8_node = g->nodes()->create<TestOpNode<loco::DataType::U8>>();
+  u8_node->input(s8_node);
+
+  auto push_node = g->nodes()->create<loco::Push>();
+  push_node->from(u8_node);
+
+  auto graph_input = g->inputs()->create();
+  graph_input->name("input");
+  loco::link(graph_input, pull_node);
+
+  auto graph_output = g->outputs()->create();
+  graph_output->name("output");
+  loco::link(graph_output, push_node);
+
+  // initially they don't have type info
+  ASSERT_FALSE(loco::dtype_known(s8_node));
+  ASSERT_FALSE(loco::dtype_known(u8_node));
+
+  // Run Type Inference
+  TestTypeInferenceRule<loco::DataType::U8> u8_rule;
+  TestTypeInferenceRule<loco::DataType::S8> s8_rule;
+
+  loco::MultiDialectTypeInferenceRule rules;
+
+  rules.bind(TestDialect<loco::DataType::S8>::get(), &s8_rule)
+      .bind(TestDialect<loco::DataType::U8>::get(), &u8_rule);
+
+  loco::apply(&rules).to(g.get());
+
+  // Verify!
+  ASSERT_TRUE(loco::dtype_known(s8_node));
+  ASSERT_EQ(loco::dtype_get(s8_node), loco::DataType::S8);
+
+  ASSERT_TRUE(loco::dtype_known(u8_node));
+  ASSERT_EQ(loco::dtype_get(u8_node), loco::DataType::U8);
+}