IVGCVSW-1898 : Ref backend folder structure
[platform/upstream/armnn.git] / src / backends / reference / RefWorkloadFactory.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <backends/CpuTensorHandle.hpp>
6 #include <backends/MemCopyWorkload.hpp>
7 #include <backends/MakeWorkloadHelper.hpp>
8 #include "RefWorkloadFactory.hpp"
9 #include "workloads/RefWorkloads.hpp"
10 #include "Layer.hpp"
11
12 #include <boost/log/trivial.hpp>
13
14 namespace armnn
15 {
16
17 template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
18 std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
19     const WorkloadInfo& info) const
20 {
21     return armnn::MakeWorkload<NullWorkload, F32Workload, U8Workload>(descriptor, info);
22 }
23
24 RefWorkloadFactory::RefWorkloadFactory()
25 {
26 }
27
28 bool RefWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
29                                           std::string& outReasonIfUnsupported)
30 {
31     return IWorkloadFactory::IsLayerSupported(Compute::CpuRef, layer, dataType, outReasonIfUnsupported);
32 }
33
34 std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
35 {
36     return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
37 }
38
39 std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
40                                                                       DataLayout dataLayout) const
41 {
42     return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
43 }
44
45 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
46                                                            const WorkloadInfo& info) const
47 {
48     if (info.m_InputTensorInfos.empty() )
49     {
50         throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length");
51     }
52     if (info.m_OutputTensorInfos.empty())
53     {
54         throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length");
55     }
56
57     if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
58     {
59         throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
60     }
61
62     return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
63 }
64
65 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
66                                                             const WorkloadInfo& info) const
67 {
68     if (info.m_InputTensorInfos.empty() )
69     {
70         throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length");
71     }
72     if (info.m_OutputTensorInfos.empty())
73     {
74         throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length");
75     }
76     if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
77     {
78         throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
79     }
80
81     return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
82 }
83
84 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
85                                                                 const WorkloadInfo&              info) const
86 {
87     return MakeWorkload<RefActivationFloat32Workload, RefActivationUint8Workload>(descriptor, info);
88 }
89
90 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
91                                                              const WorkloadInfo&           info) const
92 {
93     return MakeWorkload<RefSoftmaxFloat32Workload, RefSoftmaxUint8Workload>(descriptor, info);
94 }
95
96 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
97                                                               const WorkloadInfo&            info) const
98 {
99     return MakeWorkload<RefSplitterFloat32Workload, RefSplitterUint8Workload>(descriptor, info);
100 }
101
102 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
103                                                                    const WorkloadInfo&          info) const
104 {
105     return MakeWorkload<RefMergerFloat32Workload, RefMergerUint8Workload>(descriptor, info);
106 }
107
108 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateFullyConnected(
109     const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
110 {
111     return MakeWorkload<RefFullyConnectedFloat32Workload, RefFullyConnectedUint8Workload>(descriptor, info);
112 }
113
114 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
115                                                                     const WorkloadInfo&           info) const
116 {
117     return MakeWorkload<RefPermuteFloat32Workload, RefPermuteUint8Workload>(descriptor, info);
118 }
119
120 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
121                                                                       const WorkloadInfo&           info) const
122 {
123     return MakeWorkload<RefPooling2dFloat32Workload, RefPooling2dUint8Workload>(descriptor, info);
124 }
125
126 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConvolution2d(
127     const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
128 {
129     return MakeWorkload<RefConvolution2dFloat32Workload, RefConvolution2dUint8Workload>(descriptor, info);
130 }
131
132 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthwiseConvolution2d(
133     const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
134 {
135     return MakeWorkload<RefDepthwiseConvolution2dFloat32Workload,
136         RefDepthwiseConvolution2dUint8Workload>(descriptor, info);
137 }
138
139 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateNormalization(
140     const NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
141 {
142     return MakeWorkload<RefNormalizationFloat32Workload, NullWorkload>(descriptor, info);
143 }
144
145 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
146                                                                      const WorkloadInfo&            info) const
147 {
148     return MakeWorkload<RefAdditionFloat32Workload, RefAdditionUint8Workload>(descriptor, info);
149 }
150
151 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMultiplication(
152     const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
153 {
154     return MakeWorkload<RefMultiplicationFloat32Workload, RefMultiplicationUint8Workload>(descriptor, info);
155 }
156
157 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateBatchNormalization(
158     const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
159 {
160     return MakeWorkload<RefBatchNormalizationFloat32Workload, RefBatchNormalizationUint8Workload>(descriptor, info);
161 }
162
163 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
164                                                                     const WorkloadInfo&        info) const
165 {
166     if (descriptor.m_Inputs.empty())
167     {
168         throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor.");
169     }
170     return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
171 }
172
173 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
174                                                                     const WorkloadInfo& info) const
175 {
176     return MakeWorkload<RefResizeBilinearFloat32Workload, RefResizeBilinearUint8Workload>(descriptor, info);
177 }
178
179 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(
180     const FakeQuantizationQueueDescriptor& descriptor,
181     const WorkloadInfo& info) const
182 {
183     return MakeWorkload<RefFakeQuantizationFloat32Workload, NullWorkload>(descriptor, info);
184 }
185
186 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
187     const WorkloadInfo& info) const
188 {
189     return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info);
190 }
191
192 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
193     const WorkloadInfo& info) const
194 {
195     return MakeWorkload<RefConstantFloat32Workload, RefConstantUint8Workload>(descriptor, info);
196 }
197
198 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
199     const WorkloadInfo& info) const
200 {
201     return MakeWorkload<RefReshapeFloat32Workload, RefReshapeUint8Workload>(descriptor, info);
202 }
203
204 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
205                                                           const WorkloadInfo& info) const
206 {
207     return MakeWorkload<RefFloorFloat32Workload, NullWorkload>(descriptor, info);
208 }
209
210 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
211     const WorkloadInfo& info) const
212 {
213     return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
214 }
215
216 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
217     const ConvertFp16ToFp32QueueDescriptor& descriptor,
218     const WorkloadInfo& info) const
219 {
220     return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info);
221 }
222
223 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16(
224     const ConvertFp32ToFp16QueueDescriptor& descriptor,
225     const WorkloadInfo& info) const
226 {
227     return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info);
228 }
229
230 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateDivision(
231     const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const
232 {
233     return MakeWorkload<RefDivisionFloat32Workload, RefDivisionUint8Workload>(descriptor, info);
234 }
235
236 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction(
237     const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const
238 {
239     return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info);
240 }
241
242 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean(
243     const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) const
244 {
245     return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
246 }
247
248 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
249                                                  const WorkloadInfo& info) const
250 {
251     return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
252 }
253
254
255 } // namespace armnn