From: 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 Date: Fri, 2 Aug 2019 04:28:35 +0000 (+0900) Subject: [loco] Add CanonicalShapeInferenceRule (#6126) X-Git-Tag: submit/tizen/20190809.050447~227 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=29687afc5ea2c0da7270379f65ccfab721b4ce5d;p=platform%2Fcore%2Fml%2Fnnfw.git [loco] Add CanonicalShapeInferenceRule (#6126) This commit adds CanonicalShapeInferenceRule class with minimal implementation. The current implementation supports shape inference only for Push and Pull nodes. Signed-off-by: Jonghyun Park --- diff --git a/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h b/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h new file mode 100644 index 0000000..3ef6fee --- /dev/null +++ b/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__ +#define __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__ + +#include "loco/Service/ShapeInferenceRule.h" + +namespace loco +{ + +/** + * @brief Shape inference rule for canonical dialect + */ +struct CanonicalShapeInferenceRule final : public ShapeInferenceRule +{ + bool recognize(const Dialect *) const final; + bool infer(const Node *, NodeShape &) const final; +}; + +} // namespace loco + +#endif // __LOCO_SERVICE_CANONICAL_SHAPE_INFERENCE_RULE_H__ diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp new file mode 100644 index 0000000..77d887e --- /dev/null +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/CanonicalShapeInferenceRule.h" +#include "loco/Service/ShapeInference.h" + +#include +#include +#include + +#include + +namespace +{ + +/** + * There are two possible maintenance policies. + * - Introduce a new canonical node first, and then extend this algorithm later + * - Introduce a new canonical node and extend this algorithm at the same time + * + * The current implementation assumes the former one (for historical reason). + * + * TODO Evaluate the impact of the latter one + * + * NOTE "Forward" means that this algorithm computes the ouput shape from inputs shapes + */ +class ForwardShapeInferenceAlgorithm final : public loco::CanonicalNodeVisitor +{ +public: + // TODO Support AvgPool2D + // TODO Support BiasEncode + // TODO Support ConstGen + // TODO Support Conv2D + // TODO Support DepthwiseConv2D + // TODO Support DepthwiseFilterEncode + // TODO Support EltwiseAdd + // TODO Support EltwiseMul + // TODO Support Forward + // TODO Support FeatureBiasAdd + // TODO Support FeatureDecode + // TODO Support FeatureEncode + // TODO Support FilterEncode + // TODO Support FixedReshape + // TODO Support MaxPool2D + + // CASE: Push + loco::NodeShape visit(const loco::Push *node) final + { + assert(loco::shape_known(node->from())); + return loco::shape_get(node->from()); + } + + // CASE: Pull + loco::NodeShape visit(const loco::Pull *node) final + { + // Build a tensor shape from "Pull" node + loco::TensorShape tensor_shape; + + tensor_shape.rank(node->rank()); + for (uint32_t axis = 0; axis < node->rank(); ++axis) + { + tensor_shape.dim(axis) = node->dim(axis); + } + + return loco::NodeShape{tensor_shape}; + } + + // TODO Support ReLU + // TODO Support ReLU6 + // TODO Support TensorBiasAdd + // TODO SUpport TensorConcat +}; + +} // namespace + +namespace loco +{ + +bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const +{ + return CanonicalDialect::get() == d; +} + +bool CanonicalShapeInferenceRule::infer(const Node *node, NodeShape &shape) const +{ + assert(node->dialect() == loco::CanonicalDialect::get()); + assert(dynamic_cast(node) != nullptr); + + ForwardShapeInferenceAlgorithm alg; + shape = dynamic_cast(node)->accept(&alg); + + return true; +} + +} // namespace loco diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp new file mode 100644 index 0000000..c0e3b3f --- /dev/null +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/CanonicalShapeInferenceRule.h" +#include "loco/Service/ShapeInference.h" + +#include "GraphTestcase.h" + +#include + +#include + +TEST(CanonicalShapeInferenceRuleTest, minimal) +{ + // Create a sample network + GraphTestcase testcase; + + testcase.pull_node->shape({1, 2, 3, 4}); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.push_node)); + ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().rank(), 4); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(1), 2); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(2), 3); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(3), 4); +}