[loco/service] ShapeInferenceRule for multiple dialects (#6587)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 14 Aug 2019 06:58:50 +0000 (15:58 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 14 Aug 2019 06:58:50 +0000 (15:58 +0900)
* [loco/service] ShapeInferenceRule for multiple dialects

This adds a loco service, a ShapeInferenceRule for multiple dialects.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* fix typo

* modified include

* add #include <map>

* Remove always true assert

compiler/loco/include/loco/Service/MultiDialectShapeInferenceRule.h [new file with mode: 0644]
compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp [new file with mode: 0644]
compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp [new file with mode: 0644]

diff --git a/compiler/loco/include/loco/Service/MultiDialectShapeInferenceRule.h b/compiler/loco/include/loco/Service/MultiDialectShapeInferenceRule.h
new file mode 100644 (file)
index 0000000..1a6c85b
--- /dev/null
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCO_SERVICE_MULTI_DIALECT_SHAPE_INFERENCE_RULE_H__
+#define __LOCO_SERVICE_MULTI_DIALECT_SHAPE_INFERENCE_RULE_H__
+
+#include "loco/Service/ShapeInferenceRule.h"
+
+#include <map>
+
+namespace loco
+{
+
+/**
+ * @brief Shape inference rule for multiple dialects
+ */
+class MultiDialectShapeInferenceRule final : public ShapeInferenceRule
+{
+public:
+  bool recognize(const Dialect *) const final;
+  bool infer(const Node *, NodeShape &) const final;
+
+  /// @brief Bind a specific rule to a Dialect
+  MultiDialectShapeInferenceRule &bind(const Dialect *d, const ShapeInferenceRule *rule);
+
+private:
+  std::map<const Dialect *, const ShapeInferenceRule *> _rules;
+};
+
+} // namespace loco
+
+#endif // __LOCO_SERVICE_MULTI_DIALECT_SHAPE_INFERENCE_RULE_H__
diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp
new file mode 100644 (file)
index 0000000..2178f5d
--- /dev/null
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/MultiDialectShapeInferenceRule.h"
+#include "loco/Service/ShapeInferenceRule.h"
+
+#include <loco/IR/Dialect.h>
+#include <loco/IR/Node.h>
+#include <loco/IR/NodeShape.h>
+
+#include <cassert>
+
+namespace loco
+{
+
+bool MultiDialectShapeInferenceRule::recognize(const Dialect *d) const
+{
+  const auto found = _rules.find(d);
+
+  if (found == _rules.cend())
+    return false;
+
+  auto rule = found->second;
+  auto result = rule->recognize(d);
+
+  return result;
+}
+
+bool MultiDialectShapeInferenceRule::infer(const Node *node, NodeShape &shape) const
+{
+  const auto found = _rules.find(node->dialect());
+
+  if (found == _rules.cend())
+    return false;
+
+  auto rule = found->second;
+  if (rule->infer(node, shape))
+    return true;
+
+  return false;
+}
+
+MultiDialectShapeInferenceRule &MultiDialectShapeInferenceRule::bind(const Dialect *d,
+                                                                     const ShapeInferenceRule *rule)
+{
+  assert(_rules.find(d) == _rules.end());
+  assert(rule->recognize(d));
+
+  _rules[d] = rule;
+
+  return (*this);
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp
new file mode 100644 (file)
index 0000000..5198d1d
--- /dev/null
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/MultiDialectShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include <loco/IR/Dialect.h>
+
+#include <gtest/gtest.h>
+
+#include <cassert>
+#include <vector>
+
+// mockup for MultiDialectShapeInferenceRule
+// Each class is dedicated for handling shape { D1, D2 } and D1, D2 are declared as a template
+namespace
+{
+
+template <uint32_t D1, uint32_t D2> class TestDialect final : public loco::Dialect
+{
+public:
+  static Dialect *get(void)
+  {
+    static TestDialect<D1, D2> d;
+    return &d;
+  }
+};
+
+template <uint32_t D1, uint32_t D2>
+struct TestOpNode final : public loco::FixedArityNode<1, loco::Node>,
+                          public loco::NodeMixin<loco::NodeTrait::TensorShape>
+{
+  void input(Node *node) { at(0)->node(node); }
+  const loco::Dialect *dialect(void) const final { return TestDialect<D1, D2>::get(); }
+  uint32_t opnum(void) const final { return static_cast<uint32_t>(D1); /* not used */ }
+};
+
+template <uint32_t D1, uint32_t D2>
+struct TestShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+public:
+  bool recognize(const loco::Dialect *d) const final { return (d == TestDialect<D1, D2>::get()); }
+
+  bool infer(const loco::Node *node, loco::NodeShape &node_shape) const final
+  {
+    assert(recognize(node->dialect()));
+    auto test_node = dynamic_cast<const TestOpNode<D1, D2> *>(node);
+    assert(test_node != nullptr);
+
+    loco::TensorShape ts;
+    {
+      ts.rank(2);
+      ts.dim(0) = D1;
+      ts.dim(1) = D2; // making shape : { D1, D2 }
+    }
+
+    node_shape.set(ts);
+
+    return true;
+  }
+};
+
+} // namespace
+
+TEST(MultiDialectShapeInferenceRuleTest, test1)
+{
+  // Create a simple network : Pull ------- t23<2,3> ------------ t45<4,5> ---------- Push
+  //                                  TensorShape({2, 3})    TensorShape({4, 5})
+  auto g = loco::make_graph();
+
+  auto pull_node = g->nodes()->create<loco::Pull>();
+  auto t23_node = g->nodes()->create<TestOpNode<2, 3>>();
+  auto t45_node = g->nodes()->create<TestOpNode<4, 5>>();
+  auto push_node = g->nodes()->create<loco::Push>();
+
+  t23_node->input(pull_node);
+  t45_node->input(t23_node);
+  push_node->from(t45_node);
+
+  auto graph_input = g->inputs()->create();
+  graph_input->name("input");
+  loco::link(graph_input, pull_node);
+
+  auto graph_output = g->outputs()->create();
+  graph_output->name("output");
+  loco::link(graph_output, push_node);
+
+  // initially they don't have shape info
+  ASSERT_FALSE(loco::shape_known(t23_node));
+  ASSERT_FALSE(loco::shape_known(t45_node));
+
+  // Run Type Inference
+  TestShapeInferenceRule<2, 3> t23_rule;
+  TestShapeInferenceRule<4, 5> t45_rule;
+
+  loco::MultiDialectShapeInferenceRule rules;
+
+  rules.bind(TestDialect<2, 3>::get(), &t23_rule).bind(TestDialect<4, 5>::get(), &t45_rule);
+
+  loco::apply(&rules).to(g.get());
+
+  // Verify!
+  ASSERT_TRUE(loco::shape_known(t23_node));
+  auto t23_shape = loco::shape_get(t23_node);
+  ASSERT_EQ(t23_shape.domain(), loco::Domain::Tensor);
+  ASSERT_EQ(t23_shape.as<loco::TensorShape>().rank(), 2);
+  ASSERT_EQ(t23_shape.as<loco::TensorShape>().dim(0), 2);
+  ASSERT_EQ(t23_shape.as<loco::TensorShape>().dim(1), 3);
+
+  ASSERT_TRUE(loco::shape_known(t45_node));
+  auto t45_shape = loco::shape_get(t45_node);
+  ASSERT_EQ(t45_shape.domain(), loco::Domain::Tensor);
+  ASSERT_EQ(t45_shape.as<loco::TensorShape>().rank(), 2);
+  ASSERT_EQ(t45_shape.as<loco::TensorShape>().dim(0), 4);
+  ASSERT_EQ(t45_shape.as<loco::TensorShape>().dim(1), 5);
+}