From 1bac9f7bca3af605431a5f2b9fc70dc870789016 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 11 Sep 2019 14:04:37 +0900 Subject: [PATCH] [exo-tflite] shape inference for TFLAveragePool2D (#7345) Adding shape inference for TFLAveragePool2D and two test cases. Signed-off-by: Hyun Sik Yoon --- .../src/Dialect/Service/TFLShapeInferenceRule.cpp | 51 +++++++++++- .../Dialect/Service/TFLShapeInferenceRule.test.cpp | 90 ++++++++++++++++++++++ compiler/exo-tflite/src/TestGraph.h | 79 +++++++++++++++++++ 3 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 compiler/exo-tflite/src/TestGraph.h diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp index e3ba39c..0f82185 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -22,11 +22,57 @@ #include "ShapeInference.h" +#include "Check.h" + #include namespace { +// Call this for TFLAvgPool2D and TFLMaxPool2D only +template loco::NodeShape infer_pool_2d_shape(const Pool2DType *node) +{ + EXO_ASSERT(loco::shape_known(node->value()), "Shape must be known"); + + auto ifm_shape = loco::shape_get(node->value()).template as(); + + uint32_t input_height = ifm_shape.dim(1).value(); + uint32_t input_width = ifm_shape.dim(2).value(); + uint32_t stride_height = node->stride()->h(); + uint32_t stride_width = node->stride()->w(); + uint32_t window_height = node->filter()->h(); + uint32_t window_width = node->filter()->w(); + uint32_t dilation_height = 1; // dilation for TFLAvgPool2D and TFLMaxPool2D is 1 + uint32_t dilation_width = 1; + uint32_t effective_window_height = dilation_height * (window_height - 1) + 1; + uint32_t effective_window_width = dilation_width * (window_width - 1) + 1; + + uint32_t output_height; + uint32_t output_width; + + if (node->padding() == locoex::Padding::VALID) + { + output_height = (input_height + stride_height - effective_window_height) / stride_height; + output_width = (input_width + stride_width - effective_window_width) / stride_width; + } + else if (node->padding() == locoex::Padding::SAME) + { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } + else + EXO_ASSERT(false, "Wrong padding type"); + + loco::TensorShape ofm_shape; + ofm_shape.rank(4); + ofm_shape.dim(0) = ifm_shape.dim(0); + ofm_shape.dim(1) = output_height; + ofm_shape.dim(2) = output_width; + ofm_shape.dim(3) = ifm_shape.dim(3); + + return loco::NodeShape{ofm_shape}; +} + /** * @brief Class to infer the shape of TFLNode * @@ -52,7 +98,10 @@ public: // TFLAdd - // TFLAveragePool2D + loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final + { + return infer_pool_2d_shape(node); + } // TODO TFLConcatenation diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp index c5c375d..eca72e6 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "TestGraph.h" + #include "Dialect/IR/TFLNodes.h" #include "Dialect/IR/TFLDialect.h" #include "Dialect/Service/TFLShapeInferenceRule.h" @@ -81,3 +83,91 @@ TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu) ASSERT_EQ(shape.dim(1), 4); } } + +// based on the case shown in +// https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow +TEST(TFLShapeInferenceRuleTest, avgpool2d_valid) +{ + exo::test::PullPushGraph test_graph; + auto pull = test_graph.pull; + { + pull->shape({1, 4, 3, 1}); + } + auto tfl_node = test_graph.middle_node; + { + tfl_node->filter()->h(2); + tfl_node->filter()->w(2); + tfl_node->stride()->h(2); + tfl_node->stride()->w(2); + tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE); + tfl_node->padding(locoex::Padding::VALID); + } + ASSERT_FALSE(loco::shape_known(tfl_node)); + + // shape inference + locoex::TFLShapeInferenceRule tfl_rule; + loco::CanonicalShapeInferenceRule canonical_rule; + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule); + + loco::apply(&rules).to(test_graph.g.get()); + + // Verify + { + ASSERT_TRUE(loco::shape_known(tfl_node)); + ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor); + + auto shape = loco::shape_get(tfl_node).as(); + ASSERT_EQ(shape.rank(), 4); + ASSERT_EQ(shape.dim(0).value(), 1); + ASSERT_EQ(shape.dim(1).value(), 2); + ASSERT_EQ(shape.dim(2).value(), 1); + ASSERT_EQ(shape.dim(3).value(), 1); + } +} + +TEST(TFLShapeInferenceRuleTest, avgpool2d_same) +{ + exo::test::PullPushGraph test_graph; + auto pull = test_graph.pull; + { + pull->shape({1, 4, 3, 1}); + } + + auto tfl_node = test_graph.middle_node; + { + tfl_node->filter()->h(2); + tfl_node->filter()->w(2); + tfl_node->stride()->h(2); + tfl_node->stride()->w(2); + tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE); + tfl_node->padding(locoex::Padding::SAME); + } + + ASSERT_FALSE(loco::shape_known(tfl_node)); + + // shape inference + locoex::TFLShapeInferenceRule tfl_rule; + loco::CanonicalShapeInferenceRule canonical_rule; + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule); + + loco::apply(&rules).to(test_graph.g.get()); + + // Verify + { + ASSERT_TRUE(loco::shape_known(tfl_node)); + ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor); + + auto shape = loco::shape_get(tfl_node).as(); + ASSERT_EQ(shape.rank(), 4); + ASSERT_EQ(shape.dim(0).value(), 1); + ASSERT_EQ(shape.dim(1).value(), 2); + ASSERT_EQ(shape.dim(2).value(), 2); + ASSERT_EQ(shape.dim(3).value(), 1); + } +} diff --git a/compiler/exo-tflite/src/TestGraph.h b/compiler/exo-tflite/src/TestGraph.h new file mode 100644 index 0000000..11903d3 --- /dev/null +++ b/compiler/exo-tflite/src/TestGraph.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TEST_GRAPH_H__ +#define __TEST_GRAPH_H__ + +#include "Dialect/IR/TFLNodes.h" + +#include + +#include + +namespace exo +{ +namespace test +{ + +// graph to build [Pull - some node of type T - Push] +template struct PullPushGraph +{ +public: + std::unique_ptr g; + loco::Pull *pull; + loco::Push *push; + T *middle_node; + + PullPushGraph() + { + // g = Pull - T - Push + g = loco::make_graph(); + + pull = g->nodes()->create(); + + middle_node = g->nodes()->create(); + { + setInput(); + } + + push = g->nodes()->create(); + { + push->from(middle_node); + } + + auto input = g->inputs()->create(); + { + input->name("input"); + loco::link(input, pull); + } + auto output = g->outputs()->create(); + { + output->name("output"); + loco::link(output, push); + } + } + +private: + void setInput(); // set the input of T +}; + +// setInput of TFL nodes +template <> void PullPushGraph::setInput() { middle_node->value(pull); } + +} // namespace test +} // namespace exo + +#endif // __TEST_GRAPH_H__ -- 2.7.4