[loco] Implement DepthwiseConv2D shape inference (#6569)
author채성우/On-Device Lab(SR)/Engineer/삼성전자 <sw4670.chae@samsung.com>
Tue, 20 Aug 2019 06:42:07 +0000 (15:42 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 20 Aug 2019 06:42:07 +0000 (15:42 +0900)
* [loco] Implement DepthwiseConv2D shape inference

This commit extends CanonicalShapeInferenceRule to accept
DepthwiseConv2D nodes.

Signed-off-by: seongwoo <sw4670.chae@samsung.com>
* fix a typo.

compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp

index 440a35d..d561878 100644 (file)
@@ -83,6 +83,16 @@ loco::Window<2> window_of(const loco::FilterShape &filter_shape)
   return window;
 }
 
+loco::Window<2> window_of(const loco::DepthwiseFilterShape &depthwise_filter_shape)
+{
+  loco::Window<2> window;
+
+  window.vertical(depthwise_filter_shape.height().value());
+  window.horizontal(depthwise_filter_shape.width().value());
+
+  return window;
+}
+
 class PlaneInference final
 {
 public:
@@ -236,7 +246,35 @@ public:
     return loco::NodeShape{output_feature_shape};
   }
 
-  // TODO Support DepthwiseConv2D
+  // CASE: DepthwiseConv2D
+  loco::NodeShape visit(const loco::DepthwiseConv2D *node) final
+  {
+    auto depthwise_filter_shape = loco::shape_get(node->ker()).as<loco::DepthwiseFilterShape>();
+    auto dpethwise_filter_window = window_of(depthwise_filter_shape);
+
+    PlaneInference infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(&dpethwise_filter_window);
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = loco::shape_get(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+    loco::FeatureShape output_feature_shape;
+
+    // "COUNT" does not change
+    output_feature_shape.count() = input_feature_shape.count();
+    // "DEPTH" depends on [in_channels * channel_multiplier] of filters
+    output_feature_shape.depth() = loco::Dimension(depthwise_filter_shape.depth().value() *
+                                                   depthwise_filter_shape.multiplier().value());
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
+
   // CASE: DepthwiseFilterEncode
   loco::NodeShape visit(const loco::DepthwiseFilterEncode *node) final
   {
index ac07b3f..f69d971 100644 (file)
@@ -148,6 +148,37 @@ TEST(CanonicalShapeInferenceRuleTest, avgpool2d)
   ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().width(), 2);
 }
 
+TEST(CanonicalShapeInferenceRuleTest, depthwiseconv2d)
+{
+  using namespace loco;
+
+  // Create a sample network
+  GraphTestcase<GraphCode::DepthwiseConv2D> testcase;
+
+  testcase.pull_node->shape({1, 4, 4, 3});
+
+  testcase.const_node->dtype(loco::DataType::FLOAT32);
+  testcase.const_node->shape({2, 2, 3, 2});
+
+  testcase.depthwiseconv2d_node->stride()->vertical(1);
+  testcase.depthwiseconv2d_node->stride()->horizontal(1);
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  //
+  // NOTE DepthwiseConv2D testcase assumes NHWC layout
+  ASSERT_TRUE(loco::shape_known(testcase.depthwiseconv2d_node));
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).domain(), loco::Domain::Feature);
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().count(), 1);
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().depth(), 6);
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().height(), 3);
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().width(), 3);
+}
+
 TEST(CanonicalShapeInferenceRuleTest, maxpool2d)
 {
   using namespace loco;