//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "ClWorkloadFactory.hpp"
#include "ClBackendId.hpp"
#include "ClBackendModelContext.hpp"
#include "ClContextDeserializer.hpp"
#include "ClContextSerializer.hpp"

#include <armnn/Exceptions.hpp>
#include <armnn/Utils.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/MakeWorkloadHelper.hpp>
#include <backendsCommon/MemCopyWorkload.hpp>
#include <backendsCommon/MemImportWorkload.hpp>

#include <cl/ClTensorHandle.hpp>
#include <cl/workloads/ClWorkloads.hpp>
#include <cl/workloads/ClWorkloadUtils.hpp>

#include <arm_compute/core/CL/CLKernelLibrary.h>
#include <arm_compute/runtime/CL/CLBufferAllocator.h>
#include <arm_compute/runtime/CL/CLScheduler.h>

#include <Filesystem.hpp>

#include <fstream>
40 static const BackendId s_Id{ClBackendId()};
43 bool ClWorkloadFactory::IsLayerSupported(const Layer& layer,
44 Optional<DataType> dataType,
45 std::string& outReasonIfUnsupported)
47 return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
50 bool ClWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer,
51 Optional<DataType> dataType,
52 std::string& outReasonIfUnsupported,
53 const ModelOptions& modelOptions)
55 return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
58 const BackendId& ClWorkloadFactory::GetBackendId() const
63 void ClWorkloadFactory::AfterWorkloadsCreated()
67 auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
68 if (modelOptions->SaveCachedNetwork())
70 // Save map to a filepath provided in ModelOptions
71 auto filePath = modelOptions->GetCachedNetworkFilePath();
72 if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
74 // Serialize ClContext to the file specified
75 ClContextSerializer serializer;
76 serializer.Serialize(m_CLCompileContext);
77 std::ofstream file(filePath, std::ios::out | std::ios::binary);
78 serializer.SaveSerializedToStream(file);
84 template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
85 std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
86 const WorkloadInfo& info,
91 return MakeWorkloadHelper<FloatWorkload, Uint8Workload>(descriptor, info, std::forward<Args>(args)...);
93 catch (const cl::Error& clError)
95 throw WrapClError(clError, CHECK_LOCATION());
99 template <typename Workload, typename QueueDescriptorType, typename... Args>
100 std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
101 const WorkloadInfo& info,
106 return std::make_unique<Workload>(descriptor, info, std::forward<Args>(args)...);
108 catch (const cl::Error& clError)
110 throw WrapClError(clError, CHECK_LOCATION());
114 void ClWorkloadFactory::InitializeCLCompileContext()
116 // Initialize our m_CLCompileContext using default device and context
117 auto context = arm_compute::CLKernelLibrary::get().context();
118 auto device = arm_compute::CLKernelLibrary::get().get_device();
119 m_CLCompileContext = arm_compute::CLCompileContext(context, device);
121 if (m_ModelContextPtr)
123 // Load saved programs if the user has set a filepath
124 auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
125 auto filePath = modelOptions->GetCachedNetworkFilePath();
127 && fs::exists(filePath)
128 && fs::is_regular_file(filePath)
129 && !(modelOptions->SaveCachedNetwork()))
131 // Deserialize binary file and load into m_CLCompileContext
132 ClContextDeserializer deserializer;
133 deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
138 ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager)
139 : m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{})
141 InitializeCLCompileContext();
144 ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager,
145 const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr)
146 : m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr)
148 InitializeCLCompileContext();
151 std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
152 const bool IsMemoryManaged) const
154 IgnoreUnused(IsMemoryManaged);
155 std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
156 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
161 std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
162 DataLayout dataLayout,
163 const bool IsMemoryManaged) const
165 IgnoreUnused(IsMemoryManaged);
166 std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
167 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
172 std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorHandle& parent,
173 TensorShape const& subTensorShape,
174 unsigned int const* subTensorOrigin) const
176 arm_compute::Coordinates coords;
177 arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
179 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
180 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); i++)
182 // Arm compute indexes tensor coords in reverse order.
183 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
184 coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
187 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
188 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
193 return std::make_unique<ClSubTensorHandle>(
194 PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
197 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
198 const WorkloadInfo& info) const
200 IgnoreUnused(descriptor);
202 ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
203 elementwiseUnaryDescriptor.m_Parameters = ElementwiseUnaryDescriptor(UnaryOperation::Abs);
205 return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
208 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
209 const WorkloadInfo& info) const
211 return MakeWorkload<ClActivationWorkload>(descriptor, info, m_CLCompileContext);
214 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
215 const WorkloadInfo& info) const
217 return MakeWorkload<ClAdditionWorkload>(descriptor, info, m_CLCompileContext);
220 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
221 const WorkloadInfo& info) const
223 return std::make_unique<ClArgMinMaxWorkload>(descriptor, info, m_CLCompileContext);
226 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchNormalization(
227 const BatchNormalizationQueueDescriptor& descriptor,
228 const WorkloadInfo& info) const
230 return MakeWorkload<ClBatchNormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
233 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
234 const WorkloadInfo& info) const
236 return MakeWorkload<ClBatchToSpaceNdWorkload>(descriptor, info, m_CLCompileContext);
239 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
240 const WorkloadInfo& info) const
242 return MakeWorkload<ClComparisonWorkload>(descriptor, info, m_CLCompileContext);
245 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
246 const WorkloadInfo& info) const
248 return MakeWorkload<ClConcatWorkload>(descriptor, info, m_CLCompileContext);
251 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
252 const WorkloadInfo& info) const
254 return MakeWorkload<ClConstantWorkload>(descriptor, info, m_CLCompileContext);
257 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp16ToFp32(
258 const ConvertFp16ToFp32QueueDescriptor& descriptor,
259 const WorkloadInfo& info) const
261 return MakeWorkload<ClConvertFp16ToFp32Workload>(descriptor, info, m_CLCompileContext);
264 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp32ToFp16(
265 const ConvertFp32ToFp16QueueDescriptor& descriptor,
266 const WorkloadInfo& info) const
268 return MakeWorkload<ClConvertFp32ToFp16Workload>(descriptor, info, m_CLCompileContext);
271 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
272 const WorkloadInfo& info) const
274 bool isFastMathEnabled = false;
275 if (m_ModelContextPtr)
277 if (m_ModelContextPtr.get() != nullptr)
279 auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
282 isFastMathEnabled = modelOptions->IsFastMathEnabled();
286 return MakeWorkload<ClConvolution2dWorkload>(descriptor,
288 m_MemoryManager->GetIntraLayerManager(),
293 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
294 const WorkloadInfo& info) const
296 return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
299 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
300 const WorkloadInfo& info) const
302 return MakeWorkload<ClDepthToSpaceWorkload>(descriptor, info, m_CLCompileContext);
305 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthwiseConvolution2d(
306 const DepthwiseConvolution2dQueueDescriptor& descriptor,
307 const WorkloadInfo& info) const
309 return MakeWorkload<ClDepthwiseConvolutionWorkload>(descriptor, info, m_CLCompileContext);
312 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor,
313 const WorkloadInfo& info) const
315 return MakeWorkload<ClDequantizeWorkload>(descriptor, info, m_CLCompileContext);
318 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDetectionPostProcess(
319 const DetectionPostProcessQueueDescriptor& descriptor,
320 const WorkloadInfo& info) const
322 return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
325 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
326 const WorkloadInfo& info) const
328 return MakeWorkload<ClDivisionFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
331 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
332 const WorkloadInfo& info) const
334 switch(descriptor.m_Parameters.m_Operation)
336 case UnaryOperation::Abs:
338 AbsQueueDescriptor absQueueDescriptor;
339 absQueueDescriptor.m_Inputs = descriptor.m_Inputs;
340 absQueueDescriptor.m_Outputs = descriptor.m_Outputs;
342 return std::make_unique<ClAbsWorkload>(absQueueDescriptor, info, m_CLCompileContext);
344 case UnaryOperation::Exp:
345 return std::make_unique<ClExpWorkload>(descriptor, info, m_CLCompileContext);
346 case UnaryOperation::Neg:
347 return std::make_unique<ClNegWorkload>(descriptor, info, m_CLCompileContext);
348 case UnaryOperation::Rsqrt:
350 RsqrtQueueDescriptor rsqrtQueueDescriptor;
351 rsqrtQueueDescriptor.m_Inputs = descriptor.m_Inputs;
352 rsqrtQueueDescriptor.m_Outputs = descriptor.m_Outputs;
354 return std::make_unique<ClRsqrtWorkload>(rsqrtQueueDescriptor, info, m_CLCompileContext);
356 case UnaryOperation::LogicalNot:
357 return std::make_unique<ClLogicalNotWorkload>(descriptor, info, m_CLCompileContext);
363 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
364 const WorkloadInfo& info) const
366 IgnoreUnused(descriptor);
368 ComparisonQueueDescriptor comparisonDescriptor;
369 comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Equal);
371 return CreateComparison(comparisonDescriptor, info);
374 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFill(const FillQueueDescriptor& descriptor,
375 const WorkloadInfo& info) const
377 return std::make_unique<ClFillWorkload>(descriptor, info, m_CLCompileContext);
380 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
381 const WorkloadInfo& info) const
383 return MakeWorkload<ClFloorFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
386 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
387 const WorkloadInfo& info) const
389 return MakeWorkload<ClFullyConnectedWorkload>(descriptor,
391 m_MemoryManager->GetIntraLayerManager(),
395 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
396 const WorkloadInfo& info) const
398 return MakeWorkload<ClGatherWorkload>(descriptor, info, m_CLCompileContext);
401 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
402 const WorkloadInfo& info) const
404 IgnoreUnused(descriptor);
406 ComparisonQueueDescriptor comparisonDescriptor;
407 comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Greater);
409 return CreateComparison(comparisonDescriptor, info);
412 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
413 const WorkloadInfo& info) const
415 return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
418 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInstanceNormalization(
419 const InstanceNormalizationQueueDescriptor& descriptor,
420 const WorkloadInfo& info) const
422 return MakeWorkload<ClInstanceNormalizationWorkload>(descriptor, info, m_CLCompileContext);
425 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
426 const WorkloadInfo& info) const
428 return MakeWorkload<ClL2NormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
431 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
432 const WorkloadInfo& info) const
434 switch(descriptor.m_Parameters.m_Operation)
436 case LogicalBinaryOperation::LogicalAnd:
437 return std::make_unique<ClLogicalAndWorkload>(descriptor, info, m_CLCompileContext);
438 case LogicalBinaryOperation::LogicalOr:
439 return std::make_unique<ClLogicalOrWorkload>(descriptor, info, m_CLCompileContext);
445 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
446 const WorkloadInfo& info) const
448 return MakeWorkload<ClLogSoftmaxWorkload>(descriptor,
450 m_MemoryManager->GetIntraLayerManager(),
454 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
455 const WorkloadInfo& info) const
457 return MakeWorkload<ClLstmFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
460 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
461 const WorkloadInfo& info) const
463 return MakeWorkload<ClMaximumWorkload>(descriptor, info, m_CLCompileContext);
466 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
467 const WorkloadInfo& info) const
469 return MakeWorkload<ClMeanWorkload>(descriptor, info, m_CLCompileContext);
472 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
473 const WorkloadInfo& info) const
475 if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
477 throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemCopy workload");
480 return MakeWorkload<CopyMemGenericWorkload>(descriptor, info);
483 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
484 const WorkloadInfo& info) const
486 if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
488 throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemImport workload");
491 return std::make_unique<ImportMemGenericWorkload>(descriptor, info);
494 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
495 const WorkloadInfo& info) const
497 return CreateConcat(descriptor, info);
500 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
501 const WorkloadInfo& info) const
503 return MakeWorkload<ClMinimumWorkload>(descriptor, info, m_CLCompileContext);
506 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
507 const WorkloadInfo& info) const
509 return MakeWorkload<ClMultiplicationWorkload>(descriptor, info, m_CLCompileContext);
512 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
513 const WorkloadInfo& info) const
515 return MakeWorkload<ClNormalizationFloatWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
518 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
519 const WorkloadInfo& info) const
521 return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
524 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
525 const WorkloadInfo& info) const
527 return MakeWorkload<ClPadWorkload>(descriptor, info, m_CLCompileContext);
530 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
531 const WorkloadInfo& info) const
533 return MakeWorkload<ClPermuteWorkload>(descriptor, info, m_CLCompileContext);
536 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
537 const WorkloadInfo& info) const
539 return MakeWorkload<ClPooling2dWorkload>(descriptor, info, m_CLCompileContext);
542 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
543 const WorkloadInfo& info) const
545 return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info, m_CLCompileContext);
548 std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
549 const WorkloadInfo &info) const
551 return MakeWorkload<ClPreluWorkload>(descriptor, info, m_CLCompileContext);
554 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& descriptor,
555 const WorkloadInfo& info) const
557 return std::make_unique<ClQLstmWorkload>(descriptor, info, m_CLCompileContext);
560 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
561 const WorkloadInfo& info) const
563 return MakeWorkload<ClQuantizeWorkload>(descriptor, info, m_CLCompileContext);
566 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
567 const WorkloadInfo& info) const
569 return MakeWorkload<ClQuantizedLstmWorkload>(descriptor, info, m_CLCompileContext);
572 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRank(const RankQueueDescriptor& descriptor,
573 const WorkloadInfo& info) const
575 return std::make_unique<ClRankWorkload>(descriptor, info);
578 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
579 const WorkloadInfo& info) const
581 return MakeWorkload<ClReshapeWorkload>(descriptor, info, m_CLCompileContext);
584 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
585 const WorkloadInfo& info) const
587 return MakeWorkload<ClResizeWorkload>(descriptor, info, m_CLCompileContext);
590 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
591 const WorkloadInfo& info) const
593 ResizeQueueDescriptor resizeDescriptor;
594 resizeDescriptor.m_Inputs = descriptor.m_Inputs;
595 resizeDescriptor.m_Outputs = descriptor.m_Outputs;
597 resizeDescriptor.m_Parameters.m_Method = ResizeMethod::Bilinear;
598 resizeDescriptor.m_Parameters.m_DataLayout = descriptor.m_Parameters.m_DataLayout;
599 resizeDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight;
600 resizeDescriptor.m_Parameters.m_TargetWidth = descriptor.m_Parameters.m_TargetWidth;
602 return CreateResize(resizeDescriptor, info);
605 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
606 const WorkloadInfo& info) const
608 IgnoreUnused(descriptor);
610 ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
611 elementwiseUnaryDescriptor.m_Parameters = ElementwiseUnaryDescriptor(UnaryOperation::Rsqrt);
613 return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
616 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor,
617 const WorkloadInfo& info) const
619 return MakeWorkload<ClSliceWorkload>(descriptor, info, m_CLCompileContext);
622 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
623 const WorkloadInfo& info) const
625 return std::make_unique<ClSoftmaxWorkload>(descriptor,
627 m_MemoryManager->GetIntraLayerManager(),
631 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
632 const WorkloadInfo& info) const
634 return MakeWorkload<ClSpaceToBatchNdWorkload>(descriptor, info, m_CLCompileContext);
637 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
638 const WorkloadInfo& info) const
640 return MakeWorkload<ClSpaceToDepthWorkload>(descriptor, info, m_CLCompileContext);
643 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
644 const WorkloadInfo& info) const
646 return MakeWorkload<ClSplitterWorkload>(descriptor, info, m_CLCompileContext);
649 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
650 const WorkloadInfo& info) const
652 return MakeWorkload<ClStackWorkload>(descriptor, info, m_CLCompileContext);
655 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
656 const WorkloadInfo& info) const
658 return MakeWorkload<ClStridedSliceWorkload>(descriptor, info, m_CLCompileContext);
661 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
662 const WorkloadInfo& info) const
664 return MakeWorkload<ClSubtractionWorkload>(descriptor, info, m_CLCompileContext);
667 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,
668 const WorkloadInfo& info) const
670 return MakeWorkload<ClTransposeWorkload>(descriptor, info, m_CLCompileContext);
673 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTransposeConvolution2d(
674 const TransposeConvolution2dQueueDescriptor& descriptor,
675 const WorkloadInfo& info) const
677 return MakeWorkload<ClTransposeConvolution2dWorkload>(descriptor,
679 m_MemoryManager->GetIntraLayerManager(),