From 71491c08996ff885ece10b3dca34e64fad9a0365 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 29 Aug 2019 17:19:11 +0900 Subject: [PATCH] [exo-tflite] Introducing TFLShapeInference (#6991) * [exo-tflite] Introducing TFLShapeInference This adds TFLShapeInference that infer the shape of TFL nodes and TFLShapeAnnot that passes shape from canonical node. Signed-off-by: Hyun Sik Yoon * adding missing graph input * fix typo * remove Forward word from internal class naming --- .../exo-tflite/src/Dialect/Service/TFLShapeAnnot.h | 43 +++++++++++ .../src/Dialect/Service/TFLShapeInferenceRule.cpp | 78 +++++++++++++++++++ .../src/Dialect/Service/TFLShapeInferenceRule.h | 33 ++++++++ .../Dialect/Service/TFLShapeInferenceRule.test.cpp | 90 ++++++++++++++++++++++ 4 files changed, 244 insertions(+) create mode 100644 compiler/exo-tflite/src/Dialect/Service/TFLShapeAnnot.h create mode 100644 compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp create mode 100644 compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.h create mode 100644 compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeAnnot.h b/compiler/exo-tflite/src/Dialect/Service/TFLShapeAnnot.h new file mode 100644 index 0000000..32d7e82 --- /dev/null +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeAnnot.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_SERVICE_TFLSHAPE_ANNOT__ +#define __LOCOEX_SERVICE_TFLSHAPE_ANNOT__ + +#include +#include + +namespace locoex +{ + +/** + * @brief Class to annotate shape to a TFL node + */ +struct TFLShapeAnnot : public loco::NodeAnnotation +{ +public: + TFLShapeAnnot(const loco::TensorShape &shape) : _shape(shape) {} + +public: + const loco::TensorShape shape(void) const { return _shape; } + +private: + const loco::TensorShape _shape; +}; + +} // namespace locoex + +#endif // __LOCOEX_SERVICE_TFLSHAPE_ANNOT__ diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp new file mode 100644 index 0000000..92615a8 --- /dev/null +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TFLShapeInferenceRule.h" +#include "TFLShapeAnnot.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include "ShapeInference.h" + +#include + +namespace +{ + +/** + * @brief Class to infer the shape of TFLNode + * + * @note All TFLNode's inputs and outouts are always loco::Domain::Tensor + */ +class ShapeInferenceAlgorithm final : public locoex::TFLNodeVisitor +{ +public: + loco::NodeShape visit(const locoex::TFLNode *node) final + { + if (loco::shape_known(node)) // if shape was already inferred by inference rule + { + assert(loco::shape_get(node).domain() == loco::Domain::Tensor); + + return loco::shape_get(node); + } + else + { // getting Shape data that was annotated while converting canonical node to TFLNode + auto shape_hint = node->annot(); + assert(shape_hint != nullptr); + + return shape_hint->shape(); + } + } +}; + +} // namespace + +namespace locoex +{ + +bool TFLShapeInferenceRule::recognize(const loco::Dialect *d) const +{ + return TFLDialect::get() == d; +} + +bool TFLShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const +{ + assert(node->dialect() == TFLDialect::get()); + assert(dynamic_cast(node) != nullptr); + + ShapeInferenceAlgorithm alg; + shape = dynamic_cast(node)->accept(&alg); + + return true; +} + +} // namespace locoex diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.h b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.h new file mode 100644 index 0000000..434a145 --- /dev/null +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__ +#define __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__ + +#include + +namespace locoex +{ + +struct TFLShapeInferenceRule final : public loco::ShapeInferenceRule +{ + bool recognize(const loco::Dialect *) const final; + bool infer(const loco::Node *, loco::NodeShape &) const final; +}; + +} // namespace locoex + +#endif // __LOCOEX_SERVICE_TFLSHAPE_INFERENCE_RULE_H__ diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp new file mode 100644 index 0000000..de2e966 --- /dev/null +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/Service/TFLShapeAnnot.h" +#include "Dialect/Service/TFLShapeInferenceRule.h" + +#include +#include + +#include + +#include + +TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu) +{ + // Create a simple network + auto g = loco::make_graph(); + + auto pull_node = g->nodes()->create(); + + auto tfl_node = g->nodes()->create(); + tfl_node->input(pull_node); + + auto push_node = g->nodes()->create(); + push_node->from(tfl_node); + + auto input = g->inputs()->create(); + { + input->name("input"); + loco::link(input, pull_node); + } + auto output = g->outputs()->create(); + { + output->name("output"); + loco::link(output, push_node); + } + + // pre-check + ASSERT_FALSE(loco::shape_known(tfl_node)); + + // scenario. + // step 1. add annotation and run shape inference. + // TFLShapeInference will get shape info from annotated data + // step 2. then, run shape inference again + + // step 1. + loco::TensorShape ts; + { + ts.rank(2); + ts.dim(0) = 1; + ts.dim(1) = 3; + } + auto shape_annot = stdex::make_unique(ts); + tfl_node->annot(std::move(shape_annot)); + + locoex::TFLShapeInferenceRule tfl_rule; + loco::apply(&tfl_rule).to(g.get()); + + // Verify + auto check_shape = [](locoex::TFLRelu *tfl_node) { + ASSERT_TRUE(loco::shape_known(tfl_node)); + ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor); + + auto shape = loco::shape_get(tfl_node).as(); + ASSERT_EQ(shape.rank(), 2); + ASSERT_EQ(shape.dim(0), 1); + ASSERT_EQ(shape.dim(1), 3); + }; + + check_shape(tfl_node); + + // step 2. + loco::apply(&tfl_rule).to(g.get()); + + check_shape(tfl_node); +} -- 2.7.4