[loco] MaxPool2D Shape Inference (#6287)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 7 Aug 2019 00:44:41 +0000 (09:44 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 7 Aug 2019 00:44:41 +0000 (09:44 +0900)
CanonicalShapeInferenceRule is now able to infer the shape of MaxPool2D
nodes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
compiler/loco/src/Service/GraphTestcase.h

index 3620273..df5abf9 100644 (file)
@@ -219,7 +219,26 @@ public:
 
   // TODO Support FilterEncode
   // TODO Support FixedReshape
-  // TODO Support MaxPool2D
+
+  // CASE: MaxPool2D
+  loco::NodeShape visit(const loco::MaxPool2D *node) final
+  {
+    PlaneInference infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(node->window());
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = loco::shape_get(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+    auto output_feature_shape = input_feature_shape; // MaxPool2D does not change count/depth
+
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
 
   // CASE: Push
   loco::NodeShape visit(const loco::Push *node) final
index c10cea5..670fffb 100644 (file)
@@ -128,3 +128,40 @@ TEST(CanonicalShapeInferenceRuleTest, avgpool2d)
   ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().height(), 4);
   ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().width(), 2);
 }
+
+TEST(CanonicalShapeInferenceRuleTest, maxpool2d)
+{
+  using namespace loco;
+
+  // Create a sample network
+  GraphTestcase<GraphCode::MaxPool2D> testcase;
+
+  auto perm = make_NHWC_perm<Domain::Feature>();
+
+  testcase.pull_node->shape({1, 8, 4, 3});
+
+  testcase.encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+
+  testcase.maxpool2d_node->window()->vertical(2);
+  testcase.maxpool2d_node->window()->horizontal(2);
+
+  testcase.maxpool2d_node->stride()->vertical(2);
+  testcase.maxpool2d_node->stride()->horizontal(2);
+
+  testcase.decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  //
+  // NOTE MaxPool2D testcase assumes NHWC layout
+  ASSERT_TRUE(loco::shape_known(testcase.maxpool2d_node));
+  ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).domain(), loco::Domain::Feature);
+  ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().count(), 1);
+  ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().depth(), 3);
+  ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().height(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().width(), 2);
+}
index 0145b7f..4e36db6 100644 (file)
@@ -12,6 +12,7 @@ enum class GraphCode
   Relu,
   FeatureCodec,
   AvgPool2D,
+  MaxPool2D,
 };
 
 template <GraphCode Code> class GraphTestcase;
@@ -208,6 +209,58 @@ private:
   std::unique_ptr<loco::Graph> _graph;
 };
 
+template <> class GraphTestcase<GraphCode::MaxPool2D> final
+{
+public:
+  GraphTestcase()
+  {
+    using namespace loco;
+
+    // Create a sample network
+    _graph = make_graph();
+
+    // Create Graph Input/Output
+    auto graph_input = _graph->inputs()->create();
+    auto graph_output = _graph->outputs()->create();
+
+    graph_input->name("input");
+    graph_output->name("output");
+
+    // Create and connect nodes
+    pull_node = _graph->nodes()->create<Pull>();
+    pull_node->index(0);
+
+    encode_node = _graph->nodes()->create<FeatureEncode>();
+    encode_node->input(pull_node);
+
+    maxpool2d_node = _graph->nodes()->create<MaxPool2D>();
+    maxpool2d_node->ifm(encode_node);
+
+    decode_node = _graph->nodes()->create<FeatureDecode>();
+    decode_node->input(maxpool2d_node);
+
+    push_node = _graph->nodes()->create<loco::Push>();
+    push_node->index(0);
+    push_node->from(decode_node);
+
+    // Create a link between input/output and corresponding nodes
+    loco::link(graph_input, pull_node);
+    loco::link(graph_output, push_node);
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *pull_node = nullptr;
+  loco::FeatureEncode *encode_node = nullptr;
+  loco::MaxPool2D *maxpool2d_node = nullptr;
+  loco::FeatureDecode *decode_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
 namespace
 {