[loco] Backward plane inference (#7458)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Tue, 17 Sep 2019 23:33:09 +0000 (08:33 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 17 Sep 2019 23:33:09 +0000 (08:33 +0900)
This commit introduces backward plane inference, which would be used for
canonical shape inference rule, especially for TransposedConv2D.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp

index 591b024..5934007 100644 (file)
@@ -151,6 +151,50 @@ private:
   const loco::Stride<2> *_stride = nullptr;
 };
 
+template <> class PlaneInference<Direction::Backward> final
+{
+public:
+  PlaneShape operator()(const PlaneShape &in) const
+  {
+    assert(_pad != nullptr);
+    assert(_window != nullptr);
+    assert(_stride != nullptr);
+
+    uint32_t const input_height = in.height.value();
+    uint32_t const input_width = in.width.value();
+
+    uint32_t const vertical_padding = _pad->top() + _pad->bottom();
+    uint32_t const horizontal_padding = _pad->left() + _pad->right();
+
+    uint32_t const raw_window_height = _window->vertical();
+    uint32_t const raw_window_width = _window->horizontal();
+
+    // TODO Support "dilation"
+    uint32_t const effective_window_height = raw_window_height;
+    uint32_t const effective_window_width = raw_window_width;
+
+    uint32_t const vertical_stride = _stride->vertical();
+    uint32_t const horizontal_stride = _stride->horizontal();
+
+    PlaneShape res;
+
+    res.height = vertical_stride * (input_height - 1) + effective_window_height - vertical_padding;
+    res.width = horizontal_stride * (input_width - 1) + effective_window_width - horizontal_padding;
+
+    return res;
+  }
+
+public:
+  void pad(const loco::Padding2D *value) { _pad = value; }
+  void window(const loco::Window<2> *value) { _window = value; }
+  void stride(const loco::Stride<2> *value) { _stride = value; }
+
+private:
+  const loco::Padding2D *_pad = nullptr;
+  const loco::Window<2> *_window = nullptr;
+  const loco::Stride<2> *_stride = nullptr;
+};
+
 /**
  * There are two possible maintenance policies.
  * - Introduce a new canonical node first, and then extend this algorithm later