void KernelGenerator::visit(const model::operation::TransposeNode &node)
{
- (void)node;
- throw std::runtime_error("Not supported, yet");
+ const auto ofm_idx{node.getOutputs().at(0)};
+ const auto ifm_idx{node.getInputs().at(model::operation::TransposeNode::Input::INPUT)};
+ const auto perm{node.param().perm};
+
+ const auto rank = _ctx.at(ifm_idx).shape().rank();
+ std::vector<int32_t> pv;
+ const auto perm_base = _ctx.at(perm).data().base();
+ const int perm_size = _ctx.at(perm).shape().num_elements();
+
+ assert(perm_base != nullptr);
+ for (int32_t n = 0; n < perm_size; ++n)
+ {
+ const int32_t perm_value = *(reinterpret_cast<const int32_t *>(perm_base) + n);
+ assert(perm_value < rank);
+ pv.emplace_back(perm_value);
+ }
+
+ auto ofm_alloc = _tensor_builder->at(ofm_idx).get();
+ const auto ifm_alloc = _tensor_builder->at(ifm_idx).get();
+ const auto frontend_layout = _current_subg_layout;
+ const auto backend_layout = ifm_alloc->layout();
+
+ auto backend_pv = ::neurun::backend::acl_common::getARMComputePermutationVector(
+ rank, pv, frontend_layout, backend_layout);
+
+ std::unique_ptr<::arm_compute::IFunction> fn;
+
+ if (ifm_alloc->num_dimensions() <= 2 && ofm_alloc->num_dimensions() <= 2)
+ {
+ auto l = nnfw::cpp14::make_unique<::arm_compute::NETranspose>();
+
+ l->configure(ifm_alloc->handle(), ofm_alloc->handle());
+
+ fn = std::move(l);
+ }
+ else
+ {
+ auto l = nnfw::cpp14::make_unique<::arm_compute::NEPermute>();
+
+ l->configure(ifm_alloc->handle(), ofm_alloc->handle(), backend_pv);
+
+ fn = std::move(l);
+ }
+
+ auto acl_fn = asAclFunction(std::move(fn));
+
+ _execution_builder->append(std::move(acl_fn));
}
void KernelGenerator::visit(const model::operation::AddNode &node)
void visit(const model::operation::SubNode &) override;
void visit(const model::operation::StridedSliceNode &) override;
void visit(const model::operation::TransposeConvNode &) override;
+ void visit(const model::operation::TransposeNode &) override;
void visit(const model::operation::AddNode &) override;
void visit(const model::operation::DivNode &) override;
void visit(const model::operation::ComparisonNode &) override;