Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / NeonWorkloads / NeonPermuteWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #pragma once
7
8 #include "backends/Workload.hpp"
9 #include "backends/WorkloadData.hpp"
10 #include "backends/NeonWorkloadUtils.hpp"
11
12 #include <armnn/TypesUtils.hpp>
13 #include <arm_compute/runtime/NEON/functions/NEPermute.h>
14
15 #include <string>
16
17 namespace armnn
18 {
19 arm_compute::Status NeonPermuteWorkloadValidate(const TensorInfo& input, const TensorInfo& output,
20                                                 const PermuteDescriptor& descriptor);
21
22 template <armnn::DataType... DataTypes>
23 class NeonPermuteWorkload : public TypedWorkload<PermuteQueueDescriptor, DataTypes...>
24 {
25 public:
26     static const std::string& GetName()
27     {
28         static const std::string name = std::string("NeonPermuteWorkload");
29         return name;
30     }
31
32     NeonPermuteWorkload(const PermuteQueueDescriptor& descriptor, const WorkloadInfo& info);
33     void Execute() const override;
34
35 private:
36     using TypedWorkload<PermuteQueueDescriptor, DataTypes...>::m_Data;
37     mutable arm_compute::NEPermute m_PermuteFunction;
38 };
39
40 using NeonPermuteFloatWorkload = NeonPermuteWorkload<DataType::Float16, DataType::Float32>;
41 using NeonPermuteUint8Workload = NeonPermuteWorkload<DataType::QuantisedAsymm8>;
42
43 } // namespace armnn