[loco] Implement AvgPool2D shape inference (#6251)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 6 Aug 2019 08:16:14 +0000 (17:16 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 6 Aug 2019 08:16:14 +0000 (17:16 +0900)
* [loco] Implement AvgPool2D shape inference

This commit extends CanonicalShapeInferenceRule to accept AvgPool2D nodes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Update testcase

* Introduce FeatureShapeUpdater

compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
compiler/loco/src/Service/GraphTestcase.h

index 366c6e5..3620273 100644 (file)
 namespace
 {
 
+struct PlaneShape
+{
+  loco::Dimension height;
+  loco::Dimension width;
+};
+
+PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape)
+{
+  PlaneShape plane_shape;
+
+  plane_shape.height = feature_shape.height();
+  plane_shape.width = feature_shape.width();
+
+  return plane_shape;
+}
+
+class FeatureShapeUpdater final
+{
+public:
+  FeatureShapeUpdater(loco::FeatureShape *ptr) : _feature_shape_ptr{ptr}
+  {
+    // DO NOTHING
+  }
+
+public:
+  void with(const PlaneShape &plane_shape) const
+  {
+    _feature_shape_ptr->height() = plane_shape.height;
+    _feature_shape_ptr->width() = plane_shape.width;
+  }
+
+private:
+  loco::FeatureShape *_feature_shape_ptr;
+};
+
+/**
+ * HOW TO USE
+ *
+ *   loco::FeatureShape feature_shape = ...;
+ *
+ *   update(feature_shape).with(...)
+ */
+FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
+{
+  return FeatureShapeUpdater{&feature_shape};
+}
+
+class PlaneInference final
+{
+public:
+  PlaneShape operator()(const PlaneShape &in) const
+  {
+    assert(_pad != nullptr);
+    assert(_window != nullptr);
+    assert(_stride != nullptr);
+
+    uint32_t const raw_input_height = in.height.value();
+    uint32_t const raw_input_width = in.width.value();
+
+    uint32_t const raw_window_height = _window->vertical();
+    uint32_t const raw_window_width = _window->horizontal();
+
+    uint32_t const vertical_padding = _pad->top() + _pad->bottom();
+    uint32_t const horizontal_padding = _pad->left() + _pad->right();
+
+    uint32_t const effective_input_height = raw_input_height + vertical_padding;
+    uint32_t const effective_input_width = raw_input_width + horizontal_padding;
+
+    // NOTE To support "dilation" later
+    uint32_t const effective_window_height = raw_window_height;
+    uint32_t const effective_window_width = raw_window_width;
+
+    uint32_t const vertical_stride = _stride->vertical();
+    uint32_t const horizontal_stride = _stride->horizontal();
+
+    assert((effective_input_height - effective_window_height) % vertical_stride == 0);
+    assert((effective_input_width - effective_window_width) % horizontal_stride == 0);
+
+    PlaneShape res;
+
+    res.height = (effective_input_height - effective_window_height) / vertical_stride + 1;
+    res.width = (effective_input_width - effective_window_width) / horizontal_stride + 1;
+
+    return res;
+  }
+
+public:
+  void pad(const loco::Pad<2> *value) { _pad = value; }
+  void window(const loco::Window<2> *value) { _window = value; }
+  void stride(const loco::Stride<2> *value) { _stride = value; }
+
+public:
+  const loco::Pad<2> *_pad = nullptr;
+  const loco::Window<2> *_window = nullptr;
+  const loco::Stride<2> *_stride = nullptr;
+};
+
 loco::NodeShape eltwise_binary_node_shape(const loco::Node *node)
 {
   // This helper works only for binary node.
@@ -53,7 +150,26 @@ loco::NodeShape eltwise_binary_node_shape(const loco::Node *node)
 class ForwardShapeInferenceAlgorithm final : public loco::CanonicalNodeVisitor<loco::NodeShape>
 {
 public:
-  // TODO Support AvgPool2D
+  // CASE: AvgPool2D
+  loco::NodeShape visit(const loco::AvgPool2D *node) final
+  {
+    PlaneInference infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(node->window());
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = loco::shape_get(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+    auto output_feature_shape = input_feature_shape; // AvgPool2D does not change count/depth
+
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
+
   // TODO Support BiasEncode
   // TODO Support ConstGen
   // TODO Support Conv2D
index df0e33a..c10cea5 100644 (file)
@@ -91,3 +91,40 @@ TEST(CanonicalShapeInferenceRuleTest, feature_codec)
   ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(2), 3);
   ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(3), 4);
 }
+
+TEST(CanonicalShapeInferenceRuleTest, avgpool2d)
+{
+  using namespace loco;
+
+  // Create a sample network
+  GraphTestcase<GraphCode::AvgPool2D> testcase;
+
+  auto perm = make_NHWC_perm<Domain::Feature>();
+
+  testcase.pull_node->shape({1, 8, 4, 3});
+
+  testcase.encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+
+  testcase.avgpool2d_node->window()->vertical(2);
+  testcase.avgpool2d_node->window()->horizontal(2);
+
+  testcase.avgpool2d_node->stride()->vertical(2);
+  testcase.avgpool2d_node->stride()->horizontal(2);
+
+  testcase.decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  //
+  // NOTE AvgPool2D testcase assumes NHWC layout
+  ASSERT_TRUE(loco::shape_known(testcase.avgpool2d_node));
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).domain(), loco::Domain::Feature);
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().count(), 1);
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().depth(), 3);
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().height(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().width(), 2);
+}
index a192323..4e28539 100644 (file)
@@ -11,6 +11,7 @@ enum class GraphCode
   Identity,
   Relu,
   FeatureCodec,
+  AvgPool2D,
 };
 
 template <GraphCode Code> class GraphTestcase;
@@ -155,4 +156,74 @@ private:
   std::unique_ptr<loco::Graph> _graph;
 };
 
+template <> class GraphTestcase<GraphCode::AvgPool2D> final
+{
+public:
+  GraphTestcase()
+  {
+    using namespace loco;
+
+    // Create a sample network
+    _graph = make_graph();
+
+    // Create Graph Input/Output
+    auto graph_input = _graph->inputs()->create();
+    auto graph_output = _graph->outputs()->create();
+
+    graph_input->name("input");
+    graph_output->name("output");
+
+    // Create and connect nodes
+    pull_node = _graph->nodes()->create<Pull>();
+    pull_node->index(0);
+
+    encode_node = _graph->nodes()->create<FeatureEncode>();
+    encode_node->input(pull_node);
+
+    avgpool2d_node = _graph->nodes()->create<AvgPool2D>();
+    avgpool2d_node->ifm(encode_node);
+
+    decode_node = _graph->nodes()->create<FeatureDecode>();
+    decode_node->input(avgpool2d_node);
+
+    push_node = _graph->nodes()->create<loco::Push>();
+    push_node->index(0);
+    push_node->from(decode_node);
+
+    // Create a link between input/output and corresponding nodes
+    loco::link(graph_input, pull_node);
+    loco::link(graph_output, push_node);
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *pull_node = nullptr;
+  loco::FeatureEncode *encode_node = nullptr;
+  loco::AvgPool2D *avgpool2d_node = nullptr;
+  loco::FeatureDecode *decode_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
+namespace
+{
+
+template <loco::Domain D> loco::Permutation<D> make_NHWC_perm(void);
+
+template <> loco::Permutation<loco::Domain::Feature> make_NHWC_perm(void)
+{
+  loco::Permutation<loco::Domain::Feature> perm;
+
+  perm[loco::FeatureAxis::Count] = 0;
+  perm[loco::FeatureAxis::Height] = 1;
+  perm[loco::FeatureAxis::Width] = 2;
+  perm[loco::FeatureAxis::Depth] = 3;
+
+  return perm;
+}
+
+} // namespace
 #endif // __GRAPH_TESTCASE_H__