[enco/frontend] Add padding functions for conv and avgpool (#2243)
author윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Tue, 13 Nov 2018 06:08:44 +0000 (15:08 +0900)
committer박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 13 Nov 2018 06:08:44 +0000 (15:08 +0900)
Add two more padding functions: One is for conv2d, another for avgpool.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
contrib/enco/frontend/tflite/src/Frontend.cpp

index 195f01d..8e77697 100644 (file)
@@ -308,11 +308,11 @@ private:
 namespace
 {
 
-coco::Padding2D get_padding(const tflite::Conv2DOptions *options, const tensor::Shape ishape,
-                            const tensor::Shape kshape)
+coco::Padding2D get_padding(const tensor::Shape &ifm_shape, const int kernel_w, const int kernel_h,
+                            tflite::Padding padding, int stride_w, int stride_h,
+                            int dilation_w_factor, int dilation_h_factor)
 {
-  assert(ishape.rank() == 4);
-  assert(kshape.rank() == 4);
+  assert(ifm_shape.rank() == 4);
 
   auto compute_padding = [](tflite::Padding padding, int stride, int dilation_rate, int in_size,
                             int filter_size) {
@@ -327,12 +327,9 @@ coco::Padding2D get_padding(const tflite::Conv2DOptions *options, const tensor::
     return value > 0 ? value : 0;
   };
 
-  tflite::Padding padding = options->padding();
   // dim(2), dim(3) are from order of NCHW
-  int padding_w = compute_padding(padding, options->stride_w(), options->dilation_w_factor(),
-                                  ishape.dim(3), kshape.dim(3));
-  int padding_h = compute_padding(padding, options->stride_h(), options->dilation_h_factor(),
-                                  ishape.dim(2), kshape.dim(2));
+  int padding_w = compute_padding(padding, stride_w, dilation_w_factor, ifm_shape.dim(3), kernel_w);
+  int padding_h = compute_padding(padding, stride_h, dilation_h_factor, ifm_shape.dim(2), kernel_h);
 
   coco::Padding2D coco_padding;
   coco_padding.top(padding_h).bottom(padding_h).left(padding_w).right(padding_w);
@@ -340,6 +337,21 @@ coco::Padding2D get_padding(const tflite::Conv2DOptions *options, const tensor::
   return coco_padding;
 }
 
+coco::Padding2D pool2D_padding(const tflite::Pool2DOptions *options, const tensor::Shape &ifm_shape,
+                               const int filter_w, const int filter_h)
+{
+  return get_padding(ifm_shape, filter_w, filter_h, options->padding(), options->stride_w(),
+                     options->stride_h(), 1, 1);
+}
+
+coco::Padding2D conv2D_padding(const tflite::Conv2DOptions *options, const tensor::Shape &ifm_shape,
+                               const tensor::Shape &kernel_shape)
+{
+  return get_padding(ifm_shape, kernel_shape.dim(3), kernel_shape.dim(2), /* kernel layout: NCHW */
+                     options->padding(), options->stride_w(), options->stride_h(),
+                     options->dilation_w_factor(), options->dilation_h_factor());
+}
+
 } // namespace
 
 /**
@@ -475,7 +487,7 @@ void Conv2DGraphBuilder::build(const tflite::Operator *op, GraphBuilderContext *
   coco_conv2d->stride()->horizontal(conv_params->stride_w());
 
   // conv_params->padding() to left, top, right, bottom
-  coco::Padding2D padding = get_padding(conv_params, ifm_shape, ker_shape);
+  coco::Padding2D padding = conv2D_padding(conv_params, ifm_shape, ker_shape);
 
   coco_conv2d->pad()->top(padding.top());
   coco_conv2d->pad()->bottom(padding.bottom());