#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+#include <tensorflow/lite/kernels/internal/reference/transpose_conv.h>
#include <stdexcept>
{
TransposeConv::TransposeConv(const Tensor *output_shape, const Tensor *filter, const Tensor *input,
- const Tensor *bias, Tensor *output, const TransposeConvParams &params)
- : KernelWithParams<TransposeConvParams>({output_shape, filter, input, bias}, {output}, params)
+ const Tensor *bias, Tensor *output, Tensor *scratch_tensor,
+ const TransposeConvParams &params)
+ : KernelWithParams<TransposeConvParams>({output_shape, filter, input, bias},
+ {output, scratch_tensor}, params)
{
}
if (input()->element_type() == DataType::U8 || input()->element_type() == DataType::S16)
{
- DataType scratch_data_type =
- input()->element_type() == DataType::S16 ? DataType::S64 : DataType::S32;
- _scratch_tensor =
- std::make_unique<Tensor>(scratch_data_type, output()->shape(), AffineQuantization{}, "");
+ auto scratch_tensor = getOutputTensors()[1];
+ scratch_tensor->resize(output()->shape());
const std::vector<double> real_multipliers =
getQuantizedConvolutionMultiplers(input()->scale(), filter()->scales(), output()->scale());
_quant_multipliers = quantizeMultipliers(real_multipliers);
}
+ else
+ {
+ auto scratch_tensor = getOutputTensors()[1];
+ scratch_tensor->set_allocatable(false);
+ }
}
void TransposeConv::execute() const
default:
throw std::runtime_error("Unsupported type.");
}
- if (!!_scratch_tensor)
- _scratch_tensor->deallocate();
}
void TransposeConv::evalFloat() const
op_params.quantized_activation_min = std::numeric_limits<uint8_t>::min();
op_params.quantized_activation_max = std::numeric_limits<uint8_t>::max();
+ auto scratch_tensor = getOutputTensors()[1];
+
tflite::reference_ops::TransposeConv(op_params, //
getTensorShape(input()), getTensorData<uint8>(input()), //
getTensorShape(filter()), getTensorData<uint8>(filter()), //
getTensorShape(bias()), getTensorData<int32_t>(bias()), //
getTensorShape(output()), getTensorData<uint8>(output()), //
tflite::RuntimeShape(), nullptr, //
- getTensorData<int32_t>(_scratch_tensor.get()));
+ getTensorData<int32_t>(scratch_tensor));
}
void TransposeConv::evalQuantizedPerChannel() const
const auto *filter_data = getTensorData<uint8_t>(filter());
const auto *bias_data = getTensorData<int32_t>(bias());
auto *output_data = getTensorData<uint8_t>(output());
- auto *scratch_data = getTensorData<int32_t>(_scratch_tensor.get());
+
+ auto scratch_tensor = getOutputTensors()[1];
+ auto *scratch_data = getTensorData<int32_t>(scratch_tensor);
const Shape &input_shape = input()->shape();
const Shape &filter_shape = filter()->shape();
int32_t activation_max{};
calculateActivationRangeQuantized(Activation::NONE, output(), &activation_min, &activation_max);
- std::memset(scratch_data, 0, _scratch_tensor->shape().num_elements() * sizeof(int32_t));
+ std::memset(scratch_data, 0, scratch_tensor->shape().num_elements() * sizeof(int32_t));
BroadcastableWrapper<ChannelQuantMultipliers> output_multipliers(_quant_multipliers);
for (int32_t batch = 0; batch < batches; ++batch)
const auto *filter_data = getTensorData<int16_t>(filter());
const auto *bias_data = getTensorData<int64_t>(bias());
auto *output_data = getTensorData<int16_t>(output());
- auto *scratch_data = getTensorData<int64_t>(_scratch_tensor.get());
+
+ auto scratch_tensor = getOutputTensors()[1];
+ auto *scratch_data = getTensorData<int64_t>(scratch_tensor);
const Shape &input_shape = input()->shape();
const Shape &filter_shape = filter()->shape();
int32_t activation_max{};
calculateActivationRangeQuantized(Activation::NONE, output(), &activation_min, &activation_max);
- std::memset(scratch_data, 0, _scratch_tensor->shape().num_elements() * sizeof(int64_t));
+ std::memset(scratch_data, 0, scratch_tensor->shape().num_elements() * sizeof(int64_t));
BroadcastableWrapper<ChannelQuantMultipliers> output_multipliers(_quant_multipliers);
for (int32_t batch = 0; batch < batches; ++batch)