From 4a25295e511f1b920be5e2fd96169b400c0b7e5f Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 14 Aug 2019 15:58:50 +0900 Subject: [PATCH] [loco/service] ShapeInferenceRule for multiple dialects (#6587) * [loco/service] ShapeInferenceRule for multiple dialects This adds a loco service, a ShapeInferenceRule for multiple dialects. Signed-off-by: Hyun Sik Yoon * fix typo * modified include * add #include * Remove always true assert --- .../loco/Service/MultiDialectShapeInferenceRule.h | 45 +++++++ .../src/Service/MultiDialectShapeInferenceRule.cpp | 67 +++++++++++ .../MultiDialectShapeInferenceRule.test.cpp | 129 +++++++++++++++++++++ 3 files changed, 241 insertions(+) create mode 100644 compiler/loco/include/loco/Service/MultiDialectShapeInferenceRule.h create mode 100644 compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp create mode 100644 compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp diff --git a/compiler/loco/include/loco/Service/MultiDialectShapeInferenceRule.h b/compiler/loco/include/loco/Service/MultiDialectShapeInferenceRule.h new file mode 100644 index 0000000..1a6c85b --- /dev/null +++ b/compiler/loco/include/loco/Service/MultiDialectShapeInferenceRule.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCO_SERVICE_MULTI_DIALECT_SHAPE_INFERENCE_RULE_H__ +#define __LOCO_SERVICE_MULTI_DIALECT_SHAPE_INFERENCE_RULE_H__ + +#include "loco/Service/ShapeInferenceRule.h" + +#include + +namespace loco +{ + +/** + * @brief Shape inference rule for multiple dialects + */ +class MultiDialectShapeInferenceRule final : public ShapeInferenceRule +{ +public: + bool recognize(const Dialect *) const final; + bool infer(const Node *, NodeShape &) const final; + + /// @brief Bind a specific rule to a Dialect + MultiDialectShapeInferenceRule &bind(const Dialect *d, const ShapeInferenceRule *rule); + +private: + std::map _rules; +}; + +} // namespace loco + +#endif // __LOCO_SERVICE_MULTI_DIALECT_SHAPE_INFERENCE_RULE_H__ diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp new file mode 100644 index 0000000..2178f5d --- /dev/null +++ b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/MultiDialectShapeInferenceRule.h" +#include "loco/Service/ShapeInferenceRule.h" + +#include +#include +#include + +#include + +namespace loco +{ + +bool MultiDialectShapeInferenceRule::recognize(const Dialect *d) const +{ + const auto found = _rules.find(d); + + if (found == _rules.cend()) + return false; + + auto rule = found->second; + auto result = rule->recognize(d); + + return result; +} + +bool MultiDialectShapeInferenceRule::infer(const Node *node, NodeShape &shape) const +{ + const auto found = _rules.find(node->dialect()); + + if (found == _rules.cend()) + return false; + + auto rule = found->second; + if (rule->infer(node, shape)) + return true; + + return false; +} + +MultiDialectShapeInferenceRule &MultiDialectShapeInferenceRule::bind(const Dialect *d, + const ShapeInferenceRule *rule) +{ + assert(_rules.find(d) == _rules.end()); + assert(rule->recognize(d)); + + _rules[d] = rule; + + return (*this); +} + +} // namespace loco diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp new file mode 100644 index 0000000..5198d1d --- /dev/null +++ b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/MultiDialectShapeInferenceRule.h" +#include "loco/Service/ShapeInference.h" + +#include + +#include + +#include +#include + +// mockup for MultiDialectShapeInferenceRule +// Each class is dedicated for handling shape { D1, D2 } and D1, D2 are declared as a template +namespace +{ + +template class TestDialect final : public loco::Dialect +{ +public: + static Dialect *get(void) + { + static TestDialect d; + return &d; + } +}; + +template +struct TestOpNode final : public loco::FixedArityNode<1, loco::Node>, + public loco::NodeMixin +{ + void input(Node *node) { at(0)->node(node); } + const loco::Dialect *dialect(void) const final { return TestDialect::get(); } + uint32_t opnum(void) const final { return static_cast(D1); /* not used */ } +}; + +template +struct TestShapeInferenceRule final : public loco::ShapeInferenceRule +{ +public: + bool recognize(const loco::Dialect *d) const final { return (d == TestDialect::get()); } + + bool infer(const loco::Node *node, loco::NodeShape &node_shape) const final + { + assert(recognize(node->dialect())); + auto test_node = dynamic_cast *>(node); + assert(test_node != nullptr); + + loco::TensorShape ts; + { + ts.rank(2); + ts.dim(0) = D1; + ts.dim(1) = D2; // making shape : { D1, D2 } + } + + node_shape.set(ts); + + return true; + } +}; + +} // namespace + +TEST(MultiDialectShapeInferenceRuleTest, test1) +{ + // Create a simple network : Pull ------- t23<2,3> ------------ t45<4,5> ---------- Push + // TensorShape({2, 3}) TensorShape({4, 5}) + auto g = loco::make_graph(); + + auto pull_node = g->nodes()->create(); + auto t23_node = g->nodes()->create>(); + auto t45_node = g->nodes()->create>(); + auto push_node = g->nodes()->create(); + + t23_node->input(pull_node); + t45_node->input(t23_node); + push_node->from(t45_node); + + auto graph_input = g->inputs()->create(); + graph_input->name("input"); + loco::link(graph_input, pull_node); + + auto graph_output = g->outputs()->create(); + graph_output->name("output"); + loco::link(graph_output, push_node); + + // initially they don't have shape info + ASSERT_FALSE(loco::shape_known(t23_node)); + ASSERT_FALSE(loco::shape_known(t45_node)); + + // Run Type Inference + TestShapeInferenceRule<2, 3> t23_rule; + TestShapeInferenceRule<4, 5> t45_rule; + + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(TestDialect<2, 3>::get(), &t23_rule).bind(TestDialect<4, 5>::get(), &t45_rule); + + loco::apply(&rules).to(g.get()); + + // Verify! + ASSERT_TRUE(loco::shape_known(t23_node)); + auto t23_shape = loco::shape_get(t23_node); + ASSERT_EQ(t23_shape.domain(), loco::Domain::Tensor); + ASSERT_EQ(t23_shape.as().rank(), 2); + ASSERT_EQ(t23_shape.as().dim(0), 2); + ASSERT_EQ(t23_shape.as().dim(1), 3); + + ASSERT_TRUE(loco::shape_known(t45_node)); + auto t45_shape = loco::shape_get(t45_node); + ASSERT_EQ(t45_shape.domain(), loco::Domain::Tensor); + ASSERT_EQ(t45_shape.as().rank(), 2); + ASSERT_EQ(t45_shape.as().dim(0), 4); + ASSERT_EQ(t45_shape.as().dim(1), 5); +} -- 2.7.4