// src/backends/cl/ClWorkloadFactory.cpp
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "ClWorkloadFactory.hpp"
#include "ClBackendId.hpp"
#include "ClBackendModelContext.hpp"

#include <Layer.hpp>

#include <armnn/Exceptions.hpp>
#include <armnn/Utils.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/MakeWorkloadHelper.hpp>
#include <backendsCommon/MemCopyWorkload.hpp>
#include <backendsCommon/MemImportWorkload.hpp>

#include <cl/ClTensorHandle.hpp>
#include <cl/workloads/ClWorkloads.hpp>
#include <cl/workloads/ClWorkloadUtils.hpp>

#include <arm_compute/core/CL/CLKernelLibrary.h>
#include <arm_compute/runtime/CL/CLBufferAllocator.h>
#include <arm_compute/runtime/CL/CLScheduler.h>

#include <Filesystem.hpp>

namespace armnn
{

namespace
{
static const BackendId s_Id{ClBackendId()};
}

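// Both IsLayerSupported overloads delegate to the common IWorkloadFactory helper with the CL
// backend id, so the decision is made by the CL layer-support rules; the second overload also
// forwards any backend-specific ModelOptions.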
bool ClWorkloadFactory::IsLayerSupported(const Layer& layer,
                                         Optional<DataType> dataType,
                                         std::string& outReasonIfUnsupported)
{
    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
}

bool ClWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer,
                                         Optional<DataType> dataType,
                                         std::string& outReasonIfUnsupported,
                                         const ModelOptions& modelOptions)
{
    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
}

const BackendId& ClWorkloadFactory::GetBackendId() const
{
    return s_Id;
}

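// Called once all workloads have been created. When SaveCachedNetwork is set in the backend
// ModelOptions, the cached network data is meant to be written to the user-supplied file path;
// the actual saving is still tracked under the IVGCVSW-5483 story.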
void ClWorkloadFactory::AfterWorkloadsCreated()
{
    if(m_ModelContextPtr)
    {
        auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
        if (modelOptions && modelOptions->SaveCachedNetwork())
        {
            // Save map to a filepath provided in ModelOptions
            auto filePath = modelOptions->GetCachedNetworkFilePath();
            if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
            {
                ///  Saving will be implemented within IVGCVSW-5483 story.
            }
        }
    }
}

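// Helper templates that construct a workload for a queue descriptor. Any cl::Error thrown by the
// Compute Library during construction is converted into an armnn exception via WrapClError, so
// callers see a consistent exception type.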
template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
                                                           const WorkloadInfo& info,
                                                           Args&&... args)
{
    try
    {
        return MakeWorkloadHelper<FloatWorkload, Uint8Workload>(descriptor, info, std::forward<Args>(args)...);
    }
    catch (const cl::Error& clError)
    {
        throw WrapClError(clError, CHECK_LOCATION());
    }
}

template <typename Workload, typename QueueDescriptorType, typename... Args>
std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
                                                           const WorkloadInfo& info,
                                                           Args&&... args)
{
    try
    {
        return std::make_unique<Workload>(descriptor, info, std::forward<Args>(args)...);
    }
    catch (const cl::Error& clError)
    {
        throw WrapClError(clError, CHECK_LOCATION());
    }
}

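// Builds m_CLCompileContext from the default OpenCL context and device held by the Compute
// Library's CLKernelLibrary. If the ModelOptions name an existing cached-network file and saving
// is not requested, previously compiled programs are intended to be loaded from it (IVGCVSW-5483).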
void ClWorkloadFactory::InitializeCLCompileContext()
{
    // Initialize our m_CLCompileContext using default device and context
    auto context = arm_compute::CLKernelLibrary::get().context();
    auto device  = arm_compute::CLKernelLibrary::get().get_device();
    m_CLCompileContext = arm_compute::CLCompileContext(context, device);

    if (m_ModelContextPtr)
    {
        // Load saved programs if the user has set a filepath
        auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
        if (modelOptions)
        {
            auto filePath = modelOptions->GetCachedNetworkFilePath();
            if (filePath != ""
                && fs::exists(filePath)
                && fs::is_regular_file(filePath)
                && !(modelOptions->SaveCachedNetwork()))
            {
                ///  Loading will be implemented within IVGCVSW-5483 story.
            }
        }
    }
}

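// Both constructors build the CL compile context up front; the second additionally keeps the
// backend-specific ModelOptions so that cached-network and fast-math settings can be queried
// later. A minimal construction sketch (assuming, as in ClBackend, that ClMemoryManager is built
// around an arm_compute::CLBufferAllocator; treat the exact allocator wiring as an assumption):
//
//     auto memoryManager = std::make_shared<ClMemoryManager>(
//         std::make_unique<arm_compute::CLBufferAllocator>());
//     ClWorkloadFactory factory(memoryManager);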
ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager)
    : m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{})
{
    InitializeCLCompileContext();
}

ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager,
                                     const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr)
    : m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr)
{
    InitializeCLCompileContext();
}

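// Tensor handles created here are backed by Compute Library CL tensors and are registered with
// the memory manager's inter-layer memory group; the IsMemoryManaged flag is not used by this
// backend.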
std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                     const bool IsMemoryManaged) const
{
    IgnoreUnused(IsMemoryManaged);
    std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
    tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());

    return tensorHandle;
}

std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                     DataLayout dataLayout,
                                                                     const bool IsMemoryManaged) const
{
    IgnoreUnused(IsMemoryManaged);
    std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
    tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());

    return tensorHandle;
}

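// Builds a sub-tensor view over a parent CL tensor. The origin coordinates are reversed because
// Arm Compute Library indexes tensor dimensions in the opposite order to armnn; if the requested
// sub-tensor is not valid within the parent, nullptr is returned.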
std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorHandle& parent,
                                                                        TensorShape const& subTensorShape,
                                                                        unsigned int const* subTensorOrigin) const
{
    arm_compute::Coordinates coords;
    arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);

    coords.set_num_dimensions(subTensorShape.GetNumDimensions());
    for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); i++)
    {
        // Arm compute indexes tensor coords in reverse order.
        unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
        coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
    }

    const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
    if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
    {
        return nullptr;
    }

    return std::make_unique<ClSubTensorHandle>(
        PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
}

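// The Create* functions below map one layer type to one CL workload. Creators for the older layer
// types (Abs, Equal, Greater, Rsqrt, ResizeBilinear, Merger) simply repackage their descriptors
// and forward to the newer equivalents (ElementwiseUnary, Comparison, Resize, Concat).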
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
                                                        const WorkloadInfo& info) const
{
    IgnoreUnused(descriptor);

    ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
    elementwiseUnaryDescriptor.m_Parameters = ElementwiseUnaryDescriptor(UnaryOperation::Abs);

    return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
{
    return MakeWorkload<ClActivationWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
{
    return MakeWorkload<ClAdditionWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
{
    return std::make_unique<ClArgMinMaxWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchNormalization(
    const BatchNormalizationQueueDescriptor& descriptor,
    const WorkloadInfo& info) const
{
    return MakeWorkload<ClBatchNormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
                                                                   const WorkloadInfo& info) const
{
    return MakeWorkload<ClBatchToSpaceNdWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
{
    return MakeWorkload<ClComparisonWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return MakeWorkload<ClConcatWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
{
    return MakeWorkload<ClConstantWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp16ToFp32(
    const ConvertFp16ToFp32QueueDescriptor& descriptor,
    const WorkloadInfo& info) const
{
    return MakeWorkload<ClConvertFp16ToFp32Workload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp32ToFp16(
    const ConvertFp32ToFp16QueueDescriptor& descriptor,
    const WorkloadInfo& info) const
{
    return MakeWorkload<ClConvertFp32ToFp16Workload>(descriptor, info, m_CLCompileContext);
}

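// Convolution reads the FastMathEnabled flag from the backend ModelOptions; when enabled, the
// Compute Library may select faster convolution methods (for example Winograd) that can trade a
// small amount of numerical accuracy for speed.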
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    bool isFastMathEnabled = false;
    if (m_ModelContextPtr)
    {
        auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
        if (modelOptions)
        {
            isFastMathEnabled = modelOptions->IsFastMathEnabled();
        }
    }
    return MakeWorkload<ClConvolution2dWorkload>(descriptor,
                                                 info,
                                                 m_MemoryManager->GetIntraLayerManager(),
                                                 m_CLCompileContext,
                                                 isFastMathEnabled);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    return MakeWorkload<ClDepthToSpaceWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthwiseConvolution2d(
    const DepthwiseConvolution2dQueueDescriptor& descriptor,
    const WorkloadInfo& info) const
{
    return MakeWorkload<ClDepthwiseConvolutionWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
{
    return MakeWorkload<ClDequantizeWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDetectionPostProcess(
    const DetectionPostProcessQueueDescriptor& descriptor,
    const WorkloadInfo& info) const
{
    return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
{
    return MakeWorkload<ClDivisionFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

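// ElementwiseUnary dispatches on the unary operation: Abs and Rsqrt are repackaged into their
// dedicated queue descriptors, Exp, Neg and LogicalNot use the generic descriptor directly, and
// any other operation returns nullptr (unsupported on this backend).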
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
                                                                     const WorkloadInfo& info) const
{
    switch(descriptor.m_Parameters.m_Operation)
    {
        case UnaryOperation::Abs:
        {
            AbsQueueDescriptor absQueueDescriptor;
            absQueueDescriptor.m_Inputs  = descriptor.m_Inputs;
            absQueueDescriptor.m_Outputs = descriptor.m_Outputs;

            return std::make_unique<ClAbsWorkload>(absQueueDescriptor, info, m_CLCompileContext);
        }
        case UnaryOperation::Exp:
            return std::make_unique<ClExpWorkload>(descriptor, info, m_CLCompileContext);
        case UnaryOperation::Neg:
            return std::make_unique<ClNegWorkload>(descriptor, info, m_CLCompileContext);
        case UnaryOperation::Rsqrt:
        {
            RsqrtQueueDescriptor rsqrtQueueDescriptor;
            rsqrtQueueDescriptor.m_Inputs  = descriptor.m_Inputs;
            rsqrtQueueDescriptor.m_Outputs = descriptor.m_Outputs;

            return std::make_unique<ClRsqrtWorkload>(rsqrtQueueDescriptor, info, m_CLCompileContext);
        }
        case UnaryOperation::LogicalNot:
            return std::make_unique<ClLogicalNotWorkload>(descriptor, info, m_CLCompileContext);
        default:
            return nullptr;
    }
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    IgnoreUnused(descriptor);

    ComparisonQueueDescriptor comparisonDescriptor;
    comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Equal);

    return CreateComparison(comparisonDescriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFill(const FillQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::make_unique<ClFillWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return MakeWorkload<ClFloorFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
                                                                   const WorkloadInfo& info) const
{
    return MakeWorkload<ClFullyConnectedWorkload>(descriptor,
                                                  info,
                                                  m_MemoryManager->GetIntraLayerManager(),
                                                  m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return MakeWorkload<ClGatherWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    IgnoreUnused(descriptor);

    ComparisonQueueDescriptor comparisonDescriptor;
    comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Greater);

    return CreateComparison(comparisonDescriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInstanceNormalization(
    const InstanceNormalizationQueueDescriptor& descriptor,
    const WorkloadInfo& info) const
{
    return MakeWorkload<ClInstanceNormalizationWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
                                                                    const WorkloadInfo& info) const
{
    return MakeWorkload<ClL2NormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    switch(descriptor.m_Parameters.m_Operation)
    {
        case LogicalBinaryOperation::LogicalAnd:
            return std::make_unique<ClLogicalAndWorkload>(descriptor, info, m_CLCompileContext);
        case LogicalBinaryOperation::LogicalOr:
            return std::make_unique<ClLogicalOrWorkload>(descriptor, info, m_CLCompileContext);
        default:
            return nullptr;
    }
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
{
    return MakeWorkload<ClLogSoftmaxWorkload>(descriptor,
                                              info,
                                              m_MemoryManager->GetIntraLayerManager(),
                                              m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return MakeWorkload<ClLstmFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return MakeWorkload<ClMaximumWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return MakeWorkload<ClMeanWorkload>(descriptor, info, m_CLCompileContext);
}

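// MemCopy and MemImport require a valid first input tensor handle; a missing handle is rejected
// with an InvalidArgumentException before the generic copy/import workload is created.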
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
    {
        throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemCopy workload");
    }

    return MakeWorkload<CopyMemGenericWorkload>(descriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
{
    if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
    {
        throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemImport workload");
    }

    return std::make_unique<ImportMemGenericWorkload>(descriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return CreateConcat(descriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return MakeWorkload<ClMinimumWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
                                                                   const WorkloadInfo& info) const
{
    return MakeWorkload<ClMultiplicationWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return MakeWorkload<ClNormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
                                                        const WorkloadInfo& info) const
{
    return MakeWorkload<ClPadWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return MakeWorkload<ClPermuteWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
{
    return MakeWorkload<ClPooling2dWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
                                                                const WorkloadInfo& info) const
{
    return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePrelu(const PreluQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return MakeWorkload<ClPreluWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return std::make_unique<ClQLstmWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
{
    return MakeWorkload<ClQuantizeWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    return MakeWorkload<ClQuantizedLstmWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRank(const RankQueueDescriptor& descriptor,
                                                         const WorkloadInfo& info) const
{
    return std::make_unique<ClRankWorkload>(descriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return MakeWorkload<ClReshapeWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
{
    return MakeWorkload<ClResizeWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
                                                                   const WorkloadInfo& info) const
{
    ResizeQueueDescriptor resizeDescriptor;
    resizeDescriptor.m_Inputs  = descriptor.m_Inputs;
    resizeDescriptor.m_Outputs = descriptor.m_Outputs;

    resizeDescriptor.m_Parameters.m_Method       = ResizeMethod::Bilinear;
    resizeDescriptor.m_Parameters.m_DataLayout   = descriptor.m_Parameters.m_DataLayout;
    resizeDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight;
    resizeDescriptor.m_Parameters.m_TargetWidth  = descriptor.m_Parameters.m_TargetWidth;

    return CreateResize(resizeDescriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    IgnoreUnused(descriptor);

    ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
    elementwiseUnaryDescriptor.m_Parameters = ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt);

    return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return MakeWorkload<ClSliceWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
{
    return std::make_unique<ClSoftmaxWorkload>(descriptor,
                                               info,
                                               m_MemoryManager->GetIntraLayerManager(),
                                               m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
                                                                   const WorkloadInfo& info) const
{
    return MakeWorkload<ClSpaceToBatchNdWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    return MakeWorkload<ClSpaceToDepthWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
{
    return MakeWorkload<ClSplitterWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
{
    return MakeWorkload<ClStackWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    return MakeWorkload<ClStridedSliceWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
                                                                const WorkloadInfo& info) const
{
    return MakeWorkload<ClSubtractionWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
{
    return MakeWorkload<ClTransposeWorkload>(descriptor, info, m_CLCompileContext);
}

std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTransposeConvolution2d(
    const TransposeConvolution2dQueueDescriptor& descriptor,
    const WorkloadInfo& info) const
{
    return MakeWorkload<ClTransposeConvolution2dWorkload>(descriptor,
                                                          info,
                                                          m_MemoryManager->GetIntraLayerManager(),
                                                          m_CLCompileContext);
}

} // namespace armnn