int ofm_index;
int ifm_index;
int ker_index;
- int stride_width;
- int stride_height;
+ Padding padding;
+ Stride stride;
};
Param param;
param.ifm_index = ifm_index.asInt();
param.ker_index = ker_index.asInt();
- param.stride_width = hstride;
- param.stride_height = vstride;
+ param.stride.horizontal = hstride;
+ param.stride.vertical = vstride;
+
+ param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
+ ? same_padding(ifm_shape, param.stride, ker_shape.W, ker_shape.H)
+ : valid_padding();
auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
// Only rank 4 is supported
const int rank = 4;
- fn->configure(ifm_alloc, ker_alloc, ofm_alloc, param.stride_width, param.stride_height,
- getARMComputeAxises(rank));
+
+ auto tconv_info = asPadStringInfo(param.padding, param.stride);
+
+ fn->configure(ifm_alloc, ker_alloc, ofm_alloc, tconv_info, getARMComputeAxises(rank));
builder.append("TransposeConv", std::move(fn));
};
#include <arm_compute/runtime/CL/CLScheduler.h>
// Caches transpose-convolution parameters for a later run() call; no
// computation or validation happens here.
// This hunk replaces the two raw stride arguments with a single
// ::arm_compute::PadStrideInfo, from which both strides and the
// left/top padding are extracted into member fields.
void SimpleTransposeConv::configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *weights,
-                                    ::arm_compute::ITensor *output, int32_t stride_width,
-                                    int32_t stride_height, ::arm_compute::Coordinates axises)
+                                    ::arm_compute::ITensor *output,
+                                    ::arm_compute::PadStrideInfo &tconv_info,
+                                    ::arm_compute::Coordinates axises)
{
  auto rank = axises.num_dimensions();
  _input = input;
  _weights = weights;
  _output = output;
-  _stride_width = stride_width;
-  _stride_height = stride_height;
+  // PadStrideInfo::stride() returns a (horizontal, vertical) pair.
+  _stride_width = tconv_info.stride().first;
+  _stride_height = tconv_info.stride().second;
+  // Only the left/top pads are kept; presumably padding is symmetric
+  // (same_padding/valid_padding at the call site) — TODO confirm.
+  _pad_width = tconv_info.pad_left();
+  _pad_height = tconv_info.pad_top();
  _axises = axises;
}
template <typename T>
-inline void ApplyTransposeConv(const ::arm_compute::TensorShape &input_shape,
- const ::arm_compute::ITensor *input_data,
- const ::arm_compute::TensorShape &filter_shape,
- const ::arm_compute::ITensor *filter_data,
- const ::arm_compute::TensorShape &output_shape,
- const ::arm_compute::ITensor *output_data,
- const int32_t stride_width, const int32_t stride_height,
- const ::arm_compute::Coordinates axises)
+inline void ApplyTransposeConv(
+ const ::arm_compute::TensorShape &input_shape, const ::arm_compute::ITensor *input_data,
+ const ::arm_compute::TensorShape &filter_shape, const ::arm_compute::ITensor *filter_data,
+ const ::arm_compute::TensorShape &output_shape, const ::arm_compute::ITensor *output_data,
+ const int32_t stride_width, const int32_t stride_height, const int32_t pad_width,
+ const int32_t pad_height, const ::arm_compute::Coordinates axises)
{
const int batches = input_shape[axises[0]];
const int input_height = input_shape[axises[1]];
const int output_width = output_shape[axises[2]];
const int output_depth = output_shape[axises[3]];
- const int pad_width = (stride_width * (input_width - 1) + filter_width - output_width) / 2;
- const int pad_height = (stride_height * (input_height - 1) + filter_height - output_height) / 2;
-
// Although transpose convolution simplifies to convolution with transposed
// weights for strides of 1, non-unitary striding complicates matters. To
// keep this reference implementation as clear as possible, we use a
switch (_input->info()->data_type())
{
case ::arm_compute::DataType::S32:
- ApplyTransposeConv<int32_t>(
- _input->info()->tensor_shape(), _input, _weights->info()->tensor_shape(), _weights,
- _output->info()->tensor_shape(), _output, _stride_width, _stride_height, _axises);
+ ApplyTransposeConv<int32_t>(_input->info()->tensor_shape(), _input,
+ _weights->info()->tensor_shape(), _weights,
+ _output->info()->tensor_shape(), _output, _stride_width,
+ _stride_height, _pad_width, _pad_height, _axises);
break;
case ::arm_compute::DataType::F32:
- ApplyTransposeConv<float>(
- _input->info()->tensor_shape(), _input, _weights->info()->tensor_shape(), _weights,
- _output->info()->tensor_shape(), _output, _stride_width, _stride_height, _axises);
+ ApplyTransposeConv<float>(_input->info()->tensor_shape(), _input,
+ _weights->info()->tensor_shape(), _weights,
+ _output->info()->tensor_shape(), _output, _stride_width,
+ _stride_height, _pad_width, _pad_height, _axises);
break;
default:
ARM_COMPUTE_ERROR("DataType not supported");
{
public:
// Default constructor: null tensor pointers, zero strides/pads.
// configure() must be called before run(); this hunk adds the new
// padding members to the init list so they start well-defined.
SimpleTransposeConv()
-    : _input(nullptr), _weights(nullptr), _output(nullptr), _stride_width(0), _stride_height(0)
+    : _input(nullptr), _weights(nullptr), _output(nullptr), _stride_width(0), _stride_height(0),
+      _pad_width(0), _pad_height(0)
{
  // DO NOTHING
}
 * @param[in] input First tensor input.
 * @param[in] weights Weights
 * @param[out] output Output tensor.
- * @param[in] stride_width Stride width value
- * @param[in] stride_height Stride height value
+ * @param[in] tconv_info Contains padding and policies to be used in the deconvolution,
+ *            this is described in @ref PadStrideInfo.
 * @param[in] axises Axises of rank 4
 */
void configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *weights,
-               ::arm_compute::ITensor *output, int32_t stride_width, int32_t stride_height,
+               ::arm_compute::ITensor *output, ::arm_compute::PadStrideInfo &tconv_info,
               ::arm_compute::Coordinates axises = getARMComputeAxises(4));
void run() override;
::arm_compute::ITensor *_output;
int32_t _stride_width;  // horizontal stride, taken from PadStrideInfo::stride().first
int32_t _stride_height; // vertical stride, taken from PadStrideInfo::stride().second
+int32_t _pad_width;  // left padding (PadStrideInfo::pad_left())
+int32_t _pad_height; // top padding (PadStrideInfo::pad_top())
::arm_compute::Coordinates _axises;
};