From 3e240d6756fa94f05e94ba5abc877d986be99b2a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 7 Aug 2019 09:44:41 +0900 Subject: [PATCH] [loco] MaxPool2D Shape Inference (#6287) CanonicalShapeInferenceRule is now able to infer the shape of MaxPool2D nodes. Signed-off-by: Jonghyun Park --- .../src/Service/CanonicalShapeInferenceRule.cpp | 21 ++++++++- .../Service/CanonicalShapeInferenceRule.test.cpp | 37 +++++++++++++++ compiler/loco/src/Service/GraphTestcase.h | 53 ++++++++++++++++++++++ 3 files changed, 110 insertions(+), 1 deletion(-) diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 3620273..df5abf9 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -219,7 +219,26 @@ public: // TODO Support FilterEncode // TODO Support FixedReshape - // TODO Support MaxPool2D + + // CASE: MaxPool2D + loco::NodeShape visit(const loco::MaxPool2D *node) final + { + PlaneInference infer_plane_shape; + + infer_plane_shape.pad(node->pad()); + infer_plane_shape.window(node->window()); + infer_plane_shape.stride(node->stride()); + + auto input_feature_shape = loco::shape_get(node->ifm()).as(); + auto input_plane_shape = make_plane_shape(input_feature_shape); + auto output_plane_shape = infer_plane_shape(input_plane_shape); + auto output_feature_shape = input_feature_shape; // MaxPool2D does not change count/depth + + // Update the height/width of output_feature_shape with that of output_plane_shape + update(output_feature_shape).with(output_plane_shape); + + return loco::NodeShape{output_feature_shape}; + } // CASE: Push loco::NodeShape visit(const loco::Push *node) final diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp index c10cea5..670fffb 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -128,3 +128,40 @@ TEST(CanonicalShapeInferenceRuleTest, avgpool2d) ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as().height(), 4); ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as().width(), 2); } + +TEST(CanonicalShapeInferenceRuleTest, maxpool2d) +{ + using namespace loco; + + // Create a sample network + GraphTestcase testcase; + + auto perm = make_NHWC_perm(); + + testcase.pull_node->shape({1, 8, 4, 3}); + + testcase.encode_node->encoder(stdex::make_unique>(perm)); + + testcase.maxpool2d_node->window()->vertical(2); + testcase.maxpool2d_node->window()->horizontal(2); + + testcase.maxpool2d_node->stride()->vertical(2); + testcase.maxpool2d_node->stride()->horizontal(2); + + testcase.decode_node->decoder(stdex::make_unique>(perm)); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + // + // NOTE MaxPool2D testcase assumes NHWC layout + ASSERT_TRUE(loco::shape_known(testcase.maxpool2d_node)); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).domain(), loco::Domain::Feature); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as().count(), 1); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as().depth(), 3); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as().height(), 4); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as().width(), 2); +} diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index 0145b7f..4e36db6 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -12,6 +12,7 @@ enum class GraphCode Relu, FeatureCodec, AvgPool2D, + MaxPool2D, }; template class GraphTestcase; @@ -208,6 +209,58 @@ private: std::unique_ptr _graph; }; +template <> class GraphTestcase final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Create a sample network + _graph = make_graph(); + + // Create Graph Input/Output + auto graph_input = _graph->inputs()->create(); + auto graph_output = _graph->outputs()->create(); + + graph_input->name("input"); + graph_output->name("output"); + + // Create and connect nodes + pull_node = _graph->nodes()->create(); + pull_node->index(0); + + encode_node = _graph->nodes()->create(); + encode_node->input(pull_node); + + maxpool2d_node = _graph->nodes()->create(); + maxpool2d_node->ifm(encode_node); + + decode_node = _graph->nodes()->create(); + decode_node->input(maxpool2d_node); + + push_node = _graph->nodes()->create(); + push_node->index(0); + push_node->from(decode_node); + + // Create a link between input/output and corresponding nodes + loco::link(graph_input, pull_node); + loco::link(graph_output, push_node); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FeatureEncode *encode_node = nullptr; + loco::MaxPool2D *maxpool2d_node = nullptr; + loco::FeatureDecode *decode_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr _graph; +}; + namespace { -- 2.7.4