2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ir/Padding.h"
19 #include "util/Utils.h"
31 inline ExplicitPadding validPadding(void)
34 // ANEURALNETWORKS_PADDING_VALID
36 // VALID padding. No padding.
38 // When the input size is not evenly divisible by the filter size,
39 // the input at the end that could not fill the whole filter tile
40 // will simply be ignored.
42 ExplicitPadding padding;
52 inline ExplicitPadding samePaddingUsingIFM(const FeatureShape &ifm_shape, const Stride &stride,
53 uint32_t kw, uint32_t kh, uint32_t dwf, uint32_t dhf)
55 ExplicitPadding padding;
57 // ANEURALNETWORKS_PADDING_SAME (from NNAPI spec)
59 // SAME padding. Padding on both ends are the "same":
61 // padding_to_beginning = total_padding / 2
62 // padding_to_end = (total_padding + 1)/2.
64 const int32_t effective_filter_h_size = (kh - 1) * dhf + 1;
65 const int32_t effective_filter_w_size = (kw - 1) * dwf + 1;
67 const int32_t vertical_expected_output = (ifm_shape.H + stride.vertical - 1) / stride.vertical;
68 const int32_t horizontal_expected_output =
69 (ifm_shape.W + stride.horizontal - 1) / stride.horizontal;
71 const int32_t vertical_needed_input =
72 (vertical_expected_output - 1) * stride.vertical + effective_filter_h_size;
73 const int32_t vertical_total_padding = std::max(0, vertical_needed_input - ifm_shape.H);
75 const int32_t horizontal_needed_input =
76 (horizontal_expected_output - 1) * stride.horizontal + effective_filter_w_size;
77 const int32_t horizontal_total_padding = std::max(0, horizontal_needed_input - ifm_shape.W);
79 padding.top = vertical_total_padding / 2;
80 padding.bottom = (vertical_total_padding + 1) / 2;
81 padding.left = horizontal_total_padding / 2;
82 padding.right = (horizontal_total_padding + 1) / 2;
87 inline ExplicitPadding samePadding(const FeatureShape &ifm_shape, const FeatureShape &ofm_shape,
88 const Stride &stride, uint32_t kw, uint32_t kh, uint32_t dwf,
91 const int32_t vertical_expected_output = (ifm_shape.H + stride.vertical - 1) / stride.vertical;
92 const int32_t horizontal_expected_output =
93 (ifm_shape.W + stride.horizontal - 1) / stride.horizontal;
94 assert(vertical_expected_output == ofm_shape.H);
95 assert(horizontal_expected_output == ofm_shape.W);
97 UNUSED_RELEASE(ofm_shape);
98 UNUSED_RELEASE(vertical_expected_output);
99 UNUSED_RELEASE(horizontal_expected_output);
101 return samePaddingUsingIFM(ifm_shape, stride, kw, kh, dwf, dhf);
106 inline std::string to_string(const PaddingType type)
110 case PaddingType::EXPLICIT:
111 return "Padding::EXPLICIT";
112 case PaddingType::SAME:
113 return "Padding::SAME";
114 case PaddingType::VALID:
115 return "Padding::VALID";
117 throw std::runtime_error{"Fail to convert string: wrong padding type"};
121 Padding::Padding(void) : type{PaddingType::EXPLICIT}, param{0, 0, 0, 0}
126 Padding::Padding(PaddingType paddingType) : type{paddingType}, param{0, 0, 0, 0}
128 assert(paddingType != PaddingType::EXPLICIT);
131 Padding::Padding(uint32_t left, uint32_t right, uint32_t top, uint32_t bottom)
132 : type{PaddingType::EXPLICIT}, param{left, right, top, bottom}
137 const ExplicitPadding calculatePadding(const Padding &padding, const FeatureShape &ifm_shape,
138 const FeatureShape &ofm_shape, const Stride &stride,
139 uint32_t kw, uint32_t kh, uint32_t dwf, uint32_t dhf)
141 if (padding.type == PaddingType::EXPLICIT)
143 return padding.param;
145 else if (padding.type == PaddingType::SAME)
147 return samePadding(ifm_shape, ofm_shape, stride, kw, kh, dwf, dhf);
149 else if (padding.type == PaddingType::VALID)
151 return validPadding();
155 throw std::runtime_error{"Cannot handle padding type"};