#include "mir/Shape.h"
#include "mir/ShapeRange.h"
#include "mir/Tensor.h"
-#include "mir/TensorUtil.h"
using namespace ::tflite;
auto bias = inputs.at(2);
// OHWI -> HWIO
- // TODO Insert TransposeOp instead when ACL backend is ready for that.
- const auto &kernel_tensor = mir::transposeTensor<1, 2, 3, 0>(extractTensor(kernel));
- kernel = createOp<ops::ConstantOp>(kernel_tensor)->getOutput(0);
+ const std::vector<std::size_t> axis_order{1, 2, 3, 0};
+ kernel = createOp<ops::TransposeOp>(kernel, axis_order)->getOutput(0);
Shape strides{opts->stride_h(), opts->stride_w()};
std::vector<int32_t> padding_before(2);
Shape output_shape(convertIntTensorToVector<int32_t>(output_shape_tensor));
// OHWI -> HWOI
- // TODO Insert TransposeOp instead when ACL backend is ready for that.
- const auto &kernel_tensor = mir::transposeTensor<1, 2, 0, 3>(extractTensor(kernel));
- kernel = createOp<ops::ConstantOp>(kernel_tensor)->getOutput(0);
+ const std::vector<std::size_t> axis_order{1, 2, 0, 3};
+ kernel = createOp<ops::TransposeOp>(kernel, axis_order)->getOutput(0);
auto result =
- createOp<ops::DeConv2DOp>(input, kernel, strides, paddingMap[opts->padding()], output_shape);
- return {result->getOutput(0)};
+ createOp<ops::DeConv2DOp>(input, kernel, strides, paddingMap[opts->padding()], output_shape)
+ ->getOutput(0);
+ return {result};
}
std::vector<mir::Operation::Output *>
int32_t inner_size = input_shape.numElements() / outer_size;
auto flatten = createOp<ops::ReshapeOp>(input, Shape{outer_size, inner_size});
- // TODO Insert TransposeOp instead when ACL backend is ready for that.
- const auto &weights_tensor = mir::transposeTensor<1, 0>(extractTensor(weights));
- weights = createOp<ops::ConstantOp>(weights_tensor)->getOutput(0);
+ // Transpose the weights.
+ const std::vector<std::size_t> axis_order{1, 0};
+ weights = createOp<ops::TransposeOp>(weights, axis_order)->getOutput(0);
auto result = createOp<ops::FullyConnectedOp>(flatten->getOutput(0), weights)->getOutput(0);
result = createOp<ops::AddOp>(result, bias)->getOutput(0);