From 6ecba2774dd305b7c2ae7f07d159f67f7fc61c6a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 2 Aug 2019 08:30:36 +0900 Subject: [PATCH] [loco] Introduce ShapeInference infrastructure (#6102) * [loco] Introduce ShapeInference infrastructure This commit introduces dialect-agnostic ShapeInference infrastructure. Signed-off-by: Jonghyun Park * Remove unnecessary return --- .../loco/include/loco/Service/ShapeInference.h | 66 ++++++++++++++++ .../loco/include/loco/Service/ShapeInferenceRule.h | 48 ++++++++++++ compiler/loco/src/Service/GraphTestcase.h | 52 +++++++++++++ compiler/loco/src/Service/ShapeInference.cpp | 76 +++++++++++++++++++ compiler/loco/src/Service/ShapeInference.test.cpp | 87 ++++++++++++++++++++++ compiler/loco/src/Service/ShapeInferenceRule.cpp | 19 +++++ 6 files changed, 348 insertions(+) create mode 100644 compiler/loco/include/loco/Service/ShapeInference.h create mode 100644 compiler/loco/include/loco/Service/ShapeInferenceRule.h create mode 100644 compiler/loco/src/Service/GraphTestcase.h create mode 100644 compiler/loco/src/Service/ShapeInference.cpp create mode 100644 compiler/loco/src/Service/ShapeInference.test.cpp create mode 100644 compiler/loco/src/Service/ShapeInferenceRule.cpp diff --git a/compiler/loco/include/loco/Service/ShapeInference.h b/compiler/loco/include/loco/Service/ShapeInference.h new file mode 100644 index 0000000..405eff7 --- /dev/null +++ b/compiler/loco/include/loco/Service/ShapeInference.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCO_SERVICE_SHAPE_INFERENCE_H__ +#define __LOCO_SERVICE_SHAPE_INFERENCE_H__ + +#include "loco/Service/ShapeInferenceRule.h" +#include "loco/IR/Graph.h" + +/** + * @file This file implements dialect-agnostic shape inference framework + * + * HOW TO USE: + * + * loco::Graph *g = ...; + * loco::ShapeInferenceRule *rule = ...; + * loco::apply(rule).to(g); + * + */ +namespace loco +{ + +class ShapeInferenceSession +{ +public: + ShapeInferenceSession(const ShapeInferenceRule *rule) : _rule{rule} + { + // DO NOTHING + } + +public: + void to(Graph *g) const; + +private: + const ShapeInferenceRule *_rule; +}; + +inline ShapeInferenceSession apply(ShapeInferenceRule *r) { return ShapeInferenceSession{r}; } + +struct ShapeInference +{ + static bool known(const Node *); + static NodeShape get(const Node *); + static void erase(Node *); +}; + +inline bool shape_known(const Node *node) { return ShapeInference::known(node); } +inline NodeShape shape_get(const Node *node) { return ShapeInference::get(node); } +inline void shape_erase(Node *node) { ShapeInference::erase(node); } + +} // namespace loco + +#endif // __LOCO_SERVICE_SHAPE_INFERENCE_H__ diff --git a/compiler/loco/include/loco/Service/ShapeInferenceRule.h b/compiler/loco/include/loco/Service/ShapeInferenceRule.h new file mode 100644 index 0000000..eeea3aa --- /dev/null +++ b/compiler/loco/include/loco/Service/ShapeInferenceRule.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCO_SERVICE_SHAPE_INFERENCE_RULE_H__ +#define __LOCO_SERVICE_SHAPE_INFERENCE_RULE_H__ + +#include "loco/IR/Domain.h" +#include "loco/IR/Dialect.h" +#include "loco/IR/Node.h" +#include "loco/IR/NodeShape.h" + +namespace loco +{ + +struct ShapeInferenceRule +{ + virtual ~ShapeInferenceRule() = default; + + /// @brief Return true if this rule recognizes a given dialect + virtual bool recognize(const Dialect *) const = 0; + + /** + * @brief Infer node's shape + * + * WARNING!! + * + * Implementation SHOULD return true only when it succeeds in inference! + * + */ + virtual bool infer(const Node *, NodeShape &) const = 0; +}; + +} // namespace loco + +#endif // __LOCO_SERVICE_SHAPE_INFERENCE_RULE_H__ diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h new file mode 100644 index 0000000..27cd90e --- /dev/null +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -0,0 +1,52 @@ +#ifndef __GRAPH_TESTCASE_H__ +#define __GRAPH_TESTCASE_H__ + +#include "loco/IR/Graph.h" + +enum class GraphCode +{ + Identity, +}; + +template class GraphTestcase; + +template <> class GraphTestcase final +{ +public: + GraphTestcase() + { + // Create a sample network + _graph = loco::make_graph(); + + // Create Nodes + pull_node = _graph->nodes()->create(); + push_node = _graph->nodes()->create(); + + push_node->from(pull_node); + + // Create Graph Input + auto graph_input = _graph->inputs()->create(); + + graph_input->name("input"); + graph_input->node(pull_node); + pull_node->index(0); + + // Create Graph Output + auto graph_output = _graph->outputs()->create(); + + graph_output->name("output"); + graph_output->node(push_node); + push_node->index(0); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr _graph; +}; + +#endif // __GRAPH_TESTCASE_H__ diff --git a/compiler/loco/src/Service/ShapeInference.cpp b/compiler/loco/src/Service/ShapeInference.cpp new file mode 100644 index 0000000..d8eb545 --- /dev/null +++ b/compiler/loco/src/Service/ShapeInference.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/ShapeInference.h" +#include "loco/IR/Algorithm.h" + +#include + +#include + +// +// Infrastructure +// +namespace +{ + +struct ShapeAnnotation : public loco::NodeAnnotation +{ +public: + ShapeAnnotation(const loco::NodeShape &shape) : _shape{shape} + { + // DO NOTHING + } + +public: + const loco::NodeShape &shape(void) const { return _shape; } + +private: + loco::NodeShape _shape; +}; + +} // namespace + +namespace loco +{ + +void ShapeInferenceSession::to(Graph *g) const +{ + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + if (_rule->recognize(node->dialect())) + { + loco::NodeShape shape; + + if (_rule->infer(node, shape)) + { + node->annot(stdex::make_unique(shape)); + } + } + } +} + +bool ShapeInference::known(const Node *node) { return node->annot() != nullptr; } + +NodeShape ShapeInference::get(const Node *node) +{ + assert(known(node)); + return node->annot()->shape(); +} + +void ShapeInference::erase(Node *node) { node->annot(nullptr); } + +} // namespace loco diff --git a/compiler/loco/src/Service/ShapeInference.test.cpp b/compiler/loco/src/Service/ShapeInference.test.cpp new file mode 100644 index 0000000..e10b988 --- /dev/null +++ b/compiler/loco/src/Service/ShapeInference.test.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/ShapeInference.h" +#include "GraphTestcase.h" + +#include + +#include + +// This test validates whether framework works as expected. +TEST(ShapeInferenceTest, framework) +{ + // Mock-up Shape Inference Rule + struct SampleShapeInferenceRule final : public loco::ShapeInferenceRule + { + public: + SampleShapeInferenceRule(std::vector *nodes) : _nodes{nodes} + { + // DO NOTHING + } + + public: + // Accept all the dialects + bool recognize(const loco::Dialect *) const final { return true; } + + bool infer(const loco::Node *node, loco::NodeShape &shape) const final + { + // Record the order of inference + _nodes->emplace_back(node); + + if (_nodes->size() != 1) + { + return false; + } + + // Set the first node as Tensor<1> + loco::TensorShape tensor_shape; + + tensor_shape.rank(1); + tensor_shape.dim(0) = 4; + + shape.set(tensor_shape); + + return true; + } + + private: + std::vector *_nodes; + }; + + GraphTestcase testcase; + + std::vector nodes; + + SampleShapeInferenceRule rule{&nodes}; + + loco::apply(&rule).to(testcase.graph()); + + // Framework SHOULD visit all the nodes + ASSERT_EQ(nodes.size(), 2); + // Framework SHOULD visit "pull" before "push" + ASSERT_EQ(nodes.at(0), testcase.pull_node); + ASSERT_EQ(nodes.at(1), testcase.push_node); + + // Framework SHOULD make an annotation if "rule" returns TRUE + ASSERT_TRUE(loco::shape_known(testcase.pull_node)); + ASSERT_EQ(loco::shape_get(testcase.pull_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.pull_node).as().rank(), 1); + ASSERT_EQ(loco::shape_get(testcase.pull_node).as().dim(0), 4); + + // Framework SHOULD NOT make any annotation if "rule" returns FALSE + ASSERT_FALSE(loco::shape_known(testcase.push_node)); +} diff --git a/compiler/loco/src/Service/ShapeInferenceRule.cpp b/compiler/loco/src/Service/ShapeInferenceRule.cpp new file mode 100644 index 0000000..7020858 --- /dev/null +++ b/compiler/loco/src/Service/ShapeInferenceRule.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/ShapeInferenceRule.h" + +// This file validates "ShapeInferenceRule.h". Please DO NOT remove this file. -- 2.7.4