IVGCVSW-5483 'Implement Loading and Saving to File'
[platform/upstream/armnn.git] / src / backends / cl / ClWorkloadFactory.cpp
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "ClWorkloadFactory.hpp"
6 #include "ClBackendId.hpp"
7 #include "ClBackendModelContext.hpp"
8 #include "ClContextDeserializer.hpp"
9 #include "ClContextSerializer.hpp"
10
11 #include <Layer.hpp>
12
13 #include <armnn/Exceptions.hpp>
14 #include <armnn/Utils.hpp>
15 #include <armnn/utility/IgnoreUnused.hpp>
16 #include <armnn/utility/NumericCast.hpp>
17 #include <armnn/utility/PolymorphicDowncast.hpp>
18
19 #include <backendsCommon/CpuTensorHandle.hpp>
20 #include <backendsCommon/MakeWorkloadHelper.hpp>
21 #include <backendsCommon/MemCopyWorkload.hpp>
22 #include <backendsCommon/MemImportWorkload.hpp>
23
24 #include <cl/ClTensorHandle.hpp>
25 #include <cl/workloads/ClWorkloads.hpp>
26 #include <cl/workloads/ClWorkloadUtils.hpp>
27
28 #include <arm_compute/core/CL/CLKernelLibrary.h>
29 #include <arm_compute/runtime/CL/CLBufferAllocator.h>
30 #include <arm_compute/runtime/CL/CLScheduler.h>
31
32 #include <Filesystem.hpp>
33 #include <fstream>
34
35 namespace armnn
36 {
37
38 namespace
39 {
40 static const BackendId s_Id{ClBackendId()};
41 }
42
43 bool ClWorkloadFactory::IsLayerSupported(const Layer& layer,
44                                          Optional<DataType> dataType,
45                                          std::string& outReasonIfUnsupported)
46 {
47     return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
48 }
49
50 bool ClWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer,
51                                          Optional<DataType> dataType,
52                                          std::string& outReasonIfUnsupported,
53                                          const ModelOptions& modelOptions)
54 {
55     return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
56 }
57
58 const BackendId& ClWorkloadFactory::GetBackendId() const
59 {
60     return s_Id;
61 }
62
63 void ClWorkloadFactory::AfterWorkloadsCreated()
64 {
65     if(m_ModelContextPtr)
66     {
67         auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
68         if (modelOptions->SaveCachedNetwork())
69         {
70             // Save map to a filepath provided in ModelOptions
71             auto filePath = modelOptions->GetCachedNetworkFilePath();
72             if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
73             {
74                 // Serialize ClContext to the file specified
75                 ClContextSerializer serializer;
76                 serializer.Serialize(m_CLCompileContext);
77                 std::ofstream file(filePath, std::ios::out | std::ios::binary);
78                 serializer.SaveSerializedToStream(file);
79             }
80         }
81     }
82 }
83
84 template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
85 std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
86                                                            const WorkloadInfo& info,
87                                                            Args&&... args)
88 {
89     try
90     {
91         return MakeWorkloadHelper<FloatWorkload, Uint8Workload>(descriptor, info, std::forward<Args>(args)...);
92     }
93     catch (const cl::Error& clError)
94     {
95         throw WrapClError(clError, CHECK_LOCATION());
96     }
97 }
98
99 template <typename Workload, typename QueueDescriptorType, typename... Args>
100 std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
101                                                            const WorkloadInfo& info,
102                                                            Args&&... args)
103 {
104     try
105     {
106         return std::make_unique<Workload>(descriptor, info, std::forward<Args>(args)...);
107     }
108     catch (const cl::Error& clError)
109     {
110         throw WrapClError(clError, CHECK_LOCATION());
111     }
112 }
113
114 void ClWorkloadFactory::InitializeCLCompileContext()
115 {
116     // Initialize our m_CLCompileContext using default device and context
117     auto context = arm_compute::CLKernelLibrary::get().context();
118     auto device  = arm_compute::CLKernelLibrary::get().get_device();
119     m_CLCompileContext = arm_compute::CLCompileContext(context, device);
120
121     if (m_ModelContextPtr)
122     {
123         // Load saved programs if the user has set a filepath
124         auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
125         auto filePath = modelOptions->GetCachedNetworkFilePath();
126         if (filePath != ""
127             && fs::exists(filePath)
128             && fs::is_regular_file(filePath)
129             && !(modelOptions->SaveCachedNetwork()))
130         {
131             // Deserialize binary file and load into m_CLCompileContext
132             ClContextDeserializer deserializer;
133             deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
134         }
135     }
136 }
137
138 ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager)
139     : m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{})
140 {
141     InitializeCLCompileContext();
142 }
143
144 ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager,
145                                      const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr)
146     : m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr)
147 {
148     InitializeCLCompileContext();
149 }
150
151 std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
152                                                                      const bool IsMemoryManaged) const
153 {
154     IgnoreUnused(IsMemoryManaged);
155     std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
156     tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
157
158     return tensorHandle;
159 }
160
161 std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
162                                                                      DataLayout dataLayout,
163                                                                      const bool IsMemoryManaged) const
164 {
165     IgnoreUnused(IsMemoryManaged);
166     std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
167     tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
168
169     return tensorHandle;
170 }
171
172 std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorHandle& parent,
173                                                                         TensorShape const& subTensorShape,
174                                                                         unsigned int const* subTensorOrigin) const
175 {
176     arm_compute::Coordinates coords;
177     arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
178
179     coords.set_num_dimensions(subTensorShape.GetNumDimensions());
180     for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); i++)
181     {
182         // Arm compute indexes tensor coords in reverse order.
183         unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
184         coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
185     }
186
187     const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
188     if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
189     {
190         return nullptr;
191     }
192
193     return std::make_unique<ClSubTensorHandle>(
194         PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
195 }
196
197 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
198                                                         const WorkloadInfo& info) const
199 {
200     IgnoreUnused(descriptor);
201
202     ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
203     elementwiseUnaryDescriptor.m_Parameters = ElementwiseUnaryDescriptor(UnaryOperation::Abs);
204
205     return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
206 }
207
208 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
209                                                                const WorkloadInfo& info) const
210 {
211     return MakeWorkload<ClActivationWorkload>(descriptor, info, m_CLCompileContext);
212 }
213
214 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
215                                                              const WorkloadInfo& info) const
216 {
217     return MakeWorkload<ClAdditionWorkload>(descriptor, info, m_CLCompileContext);
218 }
219
220 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
221                                                               const WorkloadInfo& info) const
222 {
223     return std::make_unique<ClArgMinMaxWorkload>(descriptor, info, m_CLCompileContext);
224 }
225
226 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchNormalization(
227     const BatchNormalizationQueueDescriptor& descriptor,
228     const WorkloadInfo& info) const
229 {
230     return MakeWorkload<ClBatchNormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
231 }
232
233 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
234                                                                    const WorkloadInfo& info) const
235 {
236     return MakeWorkload<ClBatchToSpaceNdWorkload>(descriptor, info, m_CLCompileContext);
237 }
238
239 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
240                                                                const WorkloadInfo& info) const
241 {
242     return MakeWorkload<ClComparisonWorkload>(descriptor, info, m_CLCompileContext);
243 }
244
245 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
246                                                            const WorkloadInfo& info) const
247 {
248     return MakeWorkload<ClConcatWorkload>(descriptor, info, m_CLCompileContext);
249 }
250
251 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
252                                                              const WorkloadInfo& info) const
253 {
254     return MakeWorkload<ClConstantWorkload>(descriptor, info, m_CLCompileContext);
255 }
256
257 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp16ToFp32(
258     const ConvertFp16ToFp32QueueDescriptor& descriptor,
259     const WorkloadInfo& info) const
260 {
261     return MakeWorkload<ClConvertFp16ToFp32Workload>(descriptor, info, m_CLCompileContext);
262 }
263
264 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp32ToFp16(
265     const ConvertFp32ToFp16QueueDescriptor& descriptor,
266     const WorkloadInfo& info) const
267 {
268     return MakeWorkload<ClConvertFp32ToFp16Workload>(descriptor, info, m_CLCompileContext);
269 }
270
271 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
272                                                                   const WorkloadInfo& info) const
273 {
274     bool isFastMathEnabled = false;
275     if (m_ModelContextPtr)
276     {
277         if (m_ModelContextPtr.get() != nullptr)
278         {
279             auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
280             if (modelOptions)
281             {
282                 isFastMathEnabled = modelOptions->IsFastMathEnabled();
283             }
284         }
285     }
286     return MakeWorkload<ClConvolution2dWorkload>(descriptor,
287                                                  info,
288                                                  m_MemoryManager->GetIntraLayerManager(),
289                                                  m_CLCompileContext,
290                                                  isFastMathEnabled);
291 }
292
293 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
294                                                           const WorkloadInfo& info) const
295 {
296     return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
297 }
298
299 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
300                                                                  const WorkloadInfo& info) const
301 {
302     return MakeWorkload<ClDepthToSpaceWorkload>(descriptor, info, m_CLCompileContext);
303 }
304
305 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthwiseConvolution2d(
306     const DepthwiseConvolution2dQueueDescriptor& descriptor,
307     const WorkloadInfo& info) const
308 {
309     return MakeWorkload<ClDepthwiseConvolutionWorkload>(descriptor, info, m_CLCompileContext);
310 }
311
312 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor,
313                                                                const WorkloadInfo& info) const
314 {
315     return MakeWorkload<ClDequantizeWorkload>(descriptor, info, m_CLCompileContext);
316 }
317
318 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDetectionPostProcess(
319     const DetectionPostProcessQueueDescriptor& descriptor,
320     const WorkloadInfo& info) const
321 {
322     return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
323 }
324
325 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
326                                                              const WorkloadInfo& info) const
327 {
328     return MakeWorkload<ClDivisionFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
329 }
330
331 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
332                                                                      const WorkloadInfo& info) const
333 {
334     switch(descriptor.m_Parameters.m_Operation)
335     {
336         case UnaryOperation::Abs:
337         {
338             AbsQueueDescriptor absQueueDescriptor;
339             absQueueDescriptor.m_Inputs  = descriptor.m_Inputs;
340             absQueueDescriptor.m_Outputs = descriptor.m_Outputs;
341
342             return  std::make_unique<ClAbsWorkload>(absQueueDescriptor, info, m_CLCompileContext);
343         }
344         case UnaryOperation::Exp:
345             return std::make_unique<ClExpWorkload>(descriptor, info, m_CLCompileContext);
346         case UnaryOperation::Neg:
347             return std::make_unique<ClNegWorkload>(descriptor, info, m_CLCompileContext);
348         case UnaryOperation::Rsqrt:
349         {
350             RsqrtQueueDescriptor rsqrtQueueDescriptor;
351             rsqrtQueueDescriptor.m_Inputs  = descriptor.m_Inputs;
352             rsqrtQueueDescriptor.m_Outputs = descriptor.m_Outputs;
353
354             return std::make_unique<ClRsqrtWorkload>(rsqrtQueueDescriptor, info, m_CLCompileContext);
355         }
356         case UnaryOperation::LogicalNot:
357             return std::make_unique<ClLogicalNotWorkload>(descriptor, info, m_CLCompileContext);
358         default:
359             return nullptr;
360     }
361 }
362
363 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
364                                                           const WorkloadInfo& info) const
365 {
366     IgnoreUnused(descriptor);
367
368     ComparisonQueueDescriptor comparisonDescriptor;
369     comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Equal);
370
371     return CreateComparison(comparisonDescriptor, info);
372 }
373
374 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFill(const FillQueueDescriptor& descriptor,
375                                                          const WorkloadInfo& info) const
376 {
377     return std::make_unique<ClFillWorkload>(descriptor, info, m_CLCompileContext);
378 }
379
380 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
381                                                           const WorkloadInfo& info) const
382 {
383     return MakeWorkload<ClFloorFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
384 }
385
386 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
387                                                                    const WorkloadInfo& info) const
388 {
389     return MakeWorkload<ClFullyConnectedWorkload>(descriptor,
390                                                   info,
391                                                   m_MemoryManager->GetIntraLayerManager(),
392                                                   m_CLCompileContext);
393 }
394
395 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
396                                                            const WorkloadInfo& info) const
397 {
398     return MakeWorkload<ClGatherWorkload>(descriptor, info, m_CLCompileContext);
399 }
400
401 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
402                                                             const WorkloadInfo& info) const
403 {
404     IgnoreUnused(descriptor);
405
406     ComparisonQueueDescriptor comparisonDescriptor;
407     comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Greater);
408
409     return CreateComparison(comparisonDescriptor, info);
410 }
411
412 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
413                                                           const WorkloadInfo& info) const
414 {
415     return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
416 }
417
418 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInstanceNormalization(
419     const InstanceNormalizationQueueDescriptor& descriptor,
420     const WorkloadInfo& info) const
421 {
422     return MakeWorkload<ClInstanceNormalizationWorkload>(descriptor, info, m_CLCompileContext);
423 }
424
425 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
426                                                                     const WorkloadInfo& info) const
427 {
428     return MakeWorkload<ClL2NormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
429 }
430
431 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
432                                                                   const WorkloadInfo& info) const
433 {
434     switch(descriptor.m_Parameters.m_Operation)
435     {
436         case LogicalBinaryOperation::LogicalAnd:
437             return std::make_unique<ClLogicalAndWorkload>(descriptor, info, m_CLCompileContext);
438         case LogicalBinaryOperation::LogicalOr:
439             return std::make_unique<ClLogicalOrWorkload>(descriptor, info, m_CLCompileContext);
440         default:
441             return nullptr;
442     }
443 }
444
445 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
446                                                                const WorkloadInfo& info) const
447 {
448     return MakeWorkload<ClLogSoftmaxWorkload>(descriptor,
449                                               info,
450                                               m_MemoryManager->GetIntraLayerManager(),
451                                               m_CLCompileContext);
452 }
453
454 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
455                                                          const WorkloadInfo& info) const
456 {
457     return MakeWorkload<ClLstmFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
458 }
459
460 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
461                                                             const WorkloadInfo& info) const
462 {
463     return MakeWorkload<ClMaximumWorkload>(descriptor, info, m_CLCompileContext);
464 }
465
466 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
467                                                          const WorkloadInfo& info) const
468 {
469     return MakeWorkload<ClMeanWorkload>(descriptor, info, m_CLCompileContext);
470 }
471
472 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
473                                                             const WorkloadInfo& info) const
474 {
475     if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
476     {
477         throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemCopy workload");
478     }
479
480     return MakeWorkload<CopyMemGenericWorkload>(descriptor, info);
481 }
482
483 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
484                                                               const WorkloadInfo& info) const
485 {
486     if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
487     {
488         throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemImport workload");
489     }
490
491     return std::make_unique<ImportMemGenericWorkload>(descriptor, info);
492 }
493
494 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
495                                                            const WorkloadInfo& info) const
496 {
497     return CreateConcat(descriptor, info);
498 }
499
500 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
501                                                             const WorkloadInfo& info) const
502 {
503     return MakeWorkload<ClMinimumWorkload>(descriptor, info, m_CLCompileContext);
504 }
505
506 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
507                                                                    const WorkloadInfo& info) const
508 {
509     return MakeWorkload<ClMultiplicationWorkload>(descriptor, info, m_CLCompileContext);
510 }
511
512 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
513                                                                   const WorkloadInfo& info) const
514 {
515     return MakeWorkload<ClNormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
516 }
517
518 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
519                                                            const WorkloadInfo& info) const
520 {
521     return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
522 }
523
524 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
525                                                         const WorkloadInfo& info) const
526 {
527     return MakeWorkload<ClPadWorkload>(descriptor, info, m_CLCompileContext);
528 }
529
530 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
531                                                             const WorkloadInfo& info) const
532 {
533     return MakeWorkload<ClPermuteWorkload>(descriptor, info, m_CLCompileContext);
534 }
535
536 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
537                                                               const WorkloadInfo& info) const
538 {
539     return MakeWorkload<ClPooling2dWorkload>(descriptor, info, m_CLCompileContext);
540 }
541
542 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
543                                                                 const WorkloadInfo& info) const
544 {
545     return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
546 }
547
548 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
549                                                           const WorkloadInfo &info) const
550 {
551     return MakeWorkload<ClPreluWorkload>(descriptor, info, m_CLCompileContext);
552 }
553
554 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& descriptor,
555                                                           const WorkloadInfo& info) const
556 {
557     return std::make_unique<ClQLstmWorkload>(descriptor, info, m_CLCompileContext);
558 }
559
560 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
561                                                              const WorkloadInfo& info) const
562 {
563     return MakeWorkload<ClQuantizeWorkload>(descriptor, info, m_CLCompileContext);
564 }
565
566 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
567                                                                   const WorkloadInfo& info) const
568 {
569     return MakeWorkload<ClQuantizedLstmWorkload>(descriptor, info, m_CLCompileContext);
570 }
571
572 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRank(const RankQueueDescriptor& descriptor,
573                                                          const WorkloadInfo& info) const
574 {
575     return std::make_unique<ClRankWorkload>(descriptor, info);
576 }
577
578 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
579                                                             const WorkloadInfo& info) const
580 {
581     return MakeWorkload<ClReshapeWorkload>(descriptor, info, m_CLCompileContext);
582 }
583
584 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
585                                                            const WorkloadInfo& info) const
586 {
587     return MakeWorkload<ClResizeWorkload>(descriptor, info, m_CLCompileContext);
588 }
589
590 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
591                                                                    const WorkloadInfo& info) const
592 {
593     ResizeQueueDescriptor resizeDescriptor;
594     resizeDescriptor.m_Inputs  = descriptor.m_Inputs;
595     resizeDescriptor.m_Outputs = descriptor.m_Outputs;
596
597     resizeDescriptor.m_Parameters.m_Method       = ResizeMethod::Bilinear;
598     resizeDescriptor.m_Parameters.m_DataLayout   = descriptor.m_Parameters.m_DataLayout;
599     resizeDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight;
600     resizeDescriptor.m_Parameters.m_TargetWidth  = descriptor.m_Parameters.m_TargetWidth;
601
602     return CreateResize(resizeDescriptor, info);
603 }
604
605 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
606                                                           const WorkloadInfo& info) const
607 {
608     IgnoreUnused(descriptor);
609
610     ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
611     elementwiseUnaryDescriptor.m_Parameters = ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt);
612
613     return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
614 }
615
616 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor,
617                                                           const WorkloadInfo& info) const
618 {
619     return MakeWorkload<ClSliceWorkload>(descriptor, info, m_CLCompileContext);
620 }
621
622 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
623                                                             const WorkloadInfo& info) const
624 {
625     return std::make_unique<ClSoftmaxWorkload>(descriptor,
626                                                info,
627                                                m_MemoryManager->GetIntraLayerManager(),
628                                                m_CLCompileContext);
629 }
630
631 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
632                                                                    const WorkloadInfo& info) const
633 {
634     return MakeWorkload<ClSpaceToBatchNdWorkload>(descriptor, info, m_CLCompileContext);
635 }
636
637 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
638                                                                  const WorkloadInfo& info) const
639 {
640     return MakeWorkload<ClSpaceToDepthWorkload>(descriptor, info, m_CLCompileContext);
641 }
642
643 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
644                                                              const WorkloadInfo& info) const
645 {
646     return MakeWorkload<ClSplitterWorkload>(descriptor, info, m_CLCompileContext);
647 }
648
649 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
650                                                           const WorkloadInfo& info) const
651 {
652     return MakeWorkload<ClStackWorkload>(descriptor, info, m_CLCompileContext);
653 }
654
655 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
656                                                                  const WorkloadInfo& info) const
657 {
658     return MakeWorkload<ClStridedSliceWorkload>(descriptor, info, m_CLCompileContext);
659 }
660
661 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
662                                                                 const WorkloadInfo& info) const
663 {
664     return MakeWorkload<ClSubtractionWorkload>(descriptor, info, m_CLCompileContext);
665 }
666
667 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,
668                                                               const WorkloadInfo& info) const
669 {
670     return MakeWorkload<ClTransposeWorkload>(descriptor, info, m_CLCompileContext);
671 }
672
673 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTransposeConvolution2d(
674     const TransposeConvolution2dQueueDescriptor& descriptor,
675     const WorkloadInfo& info) const
676 {
677     return MakeWorkload<ClTransposeConvolution2dWorkload>(descriptor,
678                                                           info,
679                                                           m_MemoryManager->GetIntraLayerManager(),
680                                                           m_CLCompileContext);
681 }
682
683 } // namespace armnn