namespace
{
-coco::Padding2D get_padding(const tflite::Conv2DOptions *options, const tensor::Shape ishape,
- const tensor::Shape kshape)
+coco::Padding2D get_padding(const tensor::Shape &ifm_shape, const int kernel_w, const int kernel_h,
+ tflite::Padding padding, int stride_w, int stride_h,
+ int dilation_w_factor, int dilation_h_factor)
{
- assert(ishape.rank() == 4);
- assert(kshape.rank() == 4);
+ assert(ifm_shape.rank() == 4);
auto compute_padding = [](tflite::Padding padding, int stride, int dilation_rate, int in_size,
int filter_size) {
return value > 0 ? value : 0;
};
- tflite::Padding padding = options->padding();
// dim(2), dim(3) are from order of NCHW
- int padding_w = compute_padding(padding, options->stride_w(), options->dilation_w_factor(),
- ishape.dim(3), kshape.dim(3));
- int padding_h = compute_padding(padding, options->stride_h(), options->dilation_h_factor(),
- ishape.dim(2), kshape.dim(2));
+ int padding_w = compute_padding(padding, stride_w, dilation_w_factor, ifm_shape.dim(3), kernel_w);
+ int padding_h = compute_padding(padding, stride_h, dilation_h_factor, ifm_shape.dim(2), kernel_h);
coco::Padding2D coco_padding;
coco_padding.top(padding_h).bottom(padding_h).left(padding_w).right(padding_w);
return coco_padding;
}
+coco::Padding2D pool2D_padding(const tflite::Pool2DOptions *options, const tensor::Shape &ifm_shape,
+ const int filter_w, const int filter_h)
+{
+ return get_padding(ifm_shape, filter_w, filter_h, options->padding(), options->stride_w(),
+ options->stride_h(), 1, 1);
+}
+
+coco::Padding2D conv2D_padding(const tflite::Conv2DOptions *options, const tensor::Shape &ifm_shape,
+ const tensor::Shape &kernel_shape)
+{
+ return get_padding(ifm_shape, kernel_shape.dim(3), kernel_shape.dim(2), /* kernel layout: NCHW */
+ options->padding(), options->stride_w(), options->stride_h(),
+ options->dilation_w_factor(), options->dilation_h_factor());
+}
+
} // namespace
/**
coco_conv2d->stride()->horizontal(conv_params->stride_w());
// conv_params->padding() to left, top, right, bottom
- coco::Padding2D padding = get_padding(conv_params, ifm_shape, ker_shape);
+ coco::Padding2D padding = conv2D_padding(conv_params, ifm_shape, ker_shape);
coco_conv2d->pad()->top(padding.top());
coco_conv2d->pad()->bottom(padding.bottom());