#include <arm_compute/runtime/CL/functions/CLCopy.h>
#include <arm_compute/runtime/CL/functions/CLDepthToSpace.h>
#include <arm_compute/runtime/CL/functions/CLDequantizationLayer.h>
#include <arm_compute/runtime/CL/functions/CLFloor.h>
#include <arm_compute/runtime/CL/functions/CLPermuteEx.h>
#include <arm_compute/runtime/CL/functions/CLRNNLayer.h>
#include <arm_compute/runtime/CL/functions/CLReductionMean.h>
#include <arm_compute/runtime/CL/functions/CLTranspose.h>
void Planner::visit(const ::internal::tflite::op::Transpose::Node &node)
{
VERBOSE(Transpose) << "Configure Transpose operation" << std::endl;
- // Transpose supports only height-wight dimention support.
- // CLPermute can be used to implement generic transpose along any axis
- // But CLPermute only implements [2,0,1], [1,2,0], [3,2,0,1]
- // TODO Implement other permutation CLPermute function and provide generic transpose
const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+ const ::internal::tflite::operand::Index permu_index{node.param().permu_index};
+
+ assert(_ctx.at(ifm_index).shape().rank() == _ctx.at(ofm_index).shape().rank());
+ assert(_ctx.at(permu_index).hasData() == true);
// Set shape constraints
_builder.addShapeConstr(
_builder.addShapeConstr(
ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(),
_ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint()));
- // NNAPI spec provides permutation vector for generic transpose
- // TODO Make the permutation vector a part of Param
+
struct Param
{
int ofm_index;
int ifm_index;
+ const int32_t *pv;
+ int rank;
};
Param param;
param.ofm_index = ofm_index.asInt();
param.ifm_index = ifm_index.asInt();
+ param.pv = reinterpret_cast<const int32_t *>(_ctx.at(permu_index).data().base());
+ param.rank = _ctx.at(ifm_index).shape().rank();
auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+
auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
const auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index});
- // CLTranspose assumes only spatial transpose, will be replaced with CLPermute
- // TODO Check the validity of permutation vector, then call CLPermute with permu vector
- auto fn = nnfw::make_unique<::arm_compute::CLTranspose>();
+ if (::internal::arm_compute::isGpuMode())
+ {
+ auto fn = nnfw::make_unique<::arm_compute::CLPermuteEx>();
+
+ fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc),
+ getARMComputePermutationVector(param.rank, param.pv));
- fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc));
+ builder.append("Transpose", std::move(fn));
+ }
+ else
+ {
+ throw std::runtime_error("Not supported, yet");
+ }
- builder.append("Transpose", std::move(fn));
};
_builder.addStage(stage);