IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / reference / workloads / RefArithmeticWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include <armnn/Types.hpp>
#include <backendsCommon/StringMapping.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>

#include <functional>
12
13 namespace armnn
14 {
15
16 template <typename Functor,
17           typename armnn::DataType DataType,
18           typename ParentDescriptor,
19           typename armnn::StringMapping::Id DebugString>
20 class RefArithmeticWorkload
21 {
22     // Needs specialization. The default is empty on purpose.
23 };
24
25 template <typename ParentDescriptor, typename Functor>
26 class BaseFloat32ArithmeticWorkload : public Float32Workload<ParentDescriptor>
27 {
28 public:
29     using Float32Workload<ParentDescriptor>::Float32Workload;
30     void ExecuteImpl(const char * debugString) const;
31 };
32
33 template <typename Functor,
34           typename ParentDescriptor,
35           typename armnn::StringMapping::Id DebugString>
36 class RefArithmeticWorkload<Functor, armnn::DataType::Float32, ParentDescriptor, DebugString>
37     : public BaseFloat32ArithmeticWorkload<ParentDescriptor, Functor>
38 {
39 public:
40     using BaseFloat32ArithmeticWorkload<ParentDescriptor, Functor>::BaseFloat32ArithmeticWorkload;
41
42     virtual void Execute() const override
43     {
44         using Parent = BaseFloat32ArithmeticWorkload<ParentDescriptor, Functor>;
45         Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
46     }
47 };
48
49 template <typename ParentDescriptor, typename Functor>
50 class BaseUint8ArithmeticWorkload : public Uint8Workload<ParentDescriptor>
51 {
52 public:
53     using Uint8Workload<ParentDescriptor>::Uint8Workload;
54     void ExecuteImpl(const char * debugString) const;
55 };
56
57 template <typename Functor,
58           typename ParentDescriptor,
59           typename armnn::StringMapping::Id DebugString>
60 class RefArithmeticWorkload<Functor, armnn::DataType::QuantisedAsymm8, ParentDescriptor, DebugString>
61     : public BaseUint8ArithmeticWorkload<ParentDescriptor, Functor>
62 {
63 public:
64     using BaseUint8ArithmeticWorkload<ParentDescriptor, Functor>::BaseUint8ArithmeticWorkload;
65
66     virtual void Execute() const override
67     {
68         using Parent = BaseUint8ArithmeticWorkload<ParentDescriptor, Functor>;
69         Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
70     }
71 };
72
73 using RefAdditionFloat32Workload =
74     RefArithmeticWorkload<std::plus<float>,
75                           DataType::Float32,
76                           AdditionQueueDescriptor,
77                           StringMapping::RefAdditionWorkload_Execute>;
78
79 using RefAdditionUint8Workload =
80     RefArithmeticWorkload<std::plus<float>,
81                           DataType::QuantisedAsymm8,
82                           AdditionQueueDescriptor,
83                           StringMapping::RefAdditionWorkload_Execute>;
84
85
86 using RefSubtractionFloat32Workload =
87     RefArithmeticWorkload<std::minus<float>,
88                           DataType::Float32,
89                           SubtractionQueueDescriptor,
90                           StringMapping::RefSubtractionWorkload_Execute>;
91
92 using RefSubtractionUint8Workload =
93     RefArithmeticWorkload<std::minus<float>,
94                           DataType::QuantisedAsymm8,
95                           SubtractionQueueDescriptor,
96                           StringMapping::RefSubtractionWorkload_Execute>;
97
98 using RefMultiplicationFloat32Workload =
99     RefArithmeticWorkload<std::multiplies<float>,
100                           DataType::Float32,
101                           MultiplicationQueueDescriptor,
102                           StringMapping::RefMultiplicationWorkload_Execute>;
103
104 using RefMultiplicationUint8Workload =
105     RefArithmeticWorkload<std::multiplies<float>,
106                           DataType::QuantisedAsymm8,
107                           MultiplicationQueueDescriptor,
108                           StringMapping::RefMultiplicationWorkload_Execute>;
109
110 using RefDivisionFloat32Workload =
111     RefArithmeticWorkload<std::divides<float>,
112                           DataType::Float32,
113                           DivisionQueueDescriptor,
114                           StringMapping::RefDivisionWorkload_Execute>;
115
116 using RefDivisionUint8Workload =
117     RefArithmeticWorkload<std::divides<float>,
118                           DataType::QuantisedAsymm8,
119                           DivisionQueueDescriptor,
120                           StringMapping::RefDivisionWorkload_Execute>;
121
122 } // armnn