[loco] Shape inference over Conv2D-related nodes (#6349)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 8 Aug 2019 01:47:06 +0000 (10:47 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 8 Aug 2019 01:47:06 +0000 (10:47 +0900)
CanonicalShapeInferenceRule is now able to infer the shape of nodes
related with Conv2D (such as FilterEncode).

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp

index b1e49af..f865395 100644 (file)
@@ -73,6 +73,16 @@ FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
   return FeatureShapeUpdater{&feature_shape};
 }
 
+loco::Window<2> window_of(const loco::FilterShape &filter_shape)
+{
+  loco::Window<2> window;
+
+  window.vertical(filter_shape.height().value());
+  window.horizontal(filter_shape.height().value());
+
+  return window;
+}
+
 class PlaneInference final
 {
 public:
@@ -186,7 +196,34 @@ public:
     return loco::NodeShape{tensor_shape};
   }
 
-  // TODO Support Conv2D
+  // CASE: Conv2D
+  loco::NodeShape visit(const loco::Conv2D *node) final
+  {
+    auto filter_shape = loco::shape_get(node->ker()).as<loco::FilterShape>();
+    auto filter_window = window_of(filter_shape);
+
+    PlaneInference infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(&filter_window);
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = loco::shape_get(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+    loco::FeatureShape output_feature_shape;
+
+    // "COUNT" does not change
+    output_feature_shape.count() = input_feature_shape.count();
+    // "DEPTH" depends on # of filters
+    output_feature_shape.depth() = filter_shape.count();
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
+
   // TODO Support DepthwiseConv2D
   // TODO Support DepthwiseFilterEncode
 
@@ -231,7 +268,13 @@ public:
     return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
   }
 
-  // TODO Support FilterEncode
+  // CASE: FilterEncode
+  loco::NodeShape visit(const loco::FilterEncode *node) final
+  {
+    auto input_tensor_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+    return loco::NodeShape{node->encoder()->shape(input_tensor_shape)};
+  }
+
   // TODO Support FixedReshape
 
   // CASE: MaxPool2D