2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <backends/OutputHandler.hpp>
8 #include <backends/aclCommon/memory/BaseMemoryManager.hpp>
10 #include <boost/core/ignore_unused.hpp>
11 #include <boost/optional.hpp>
// Workload factory for the Neon (Compute::CpuAcc) backend.
17 class NeonWorkloadFactory : public IWorkloadFactory
// Default constructor; its definition is out of line (not visible in this header).
NeonWorkloadFactory();
22 virtual Compute GetCompute() const override { return Compute::CpuAcc; }
24 static bool IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
25 std::string& outReasonIfUnsupported);
27 virtual bool SupportsSubTensors() const override { return true; }
29 virtual std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
30 TensorShape const& subTensorShape,
31 unsigned int const* subTensorOrigin) const override;
33 virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
35 virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
36 DataLayout dataLayout) const override;
38 virtual std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
39 const WorkloadInfo& info) const override;
41 virtual std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor,
42 const WorkloadInfo& info) const override;
44 virtual std::unique_ptr<IWorkload> CreateActivation(const ActivationQueueDescriptor& descriptor,
45 const WorkloadInfo& info) const override;
47 virtual std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
48 const WorkloadInfo& info) const override;
50 virtual std::unique_ptr<IWorkload> CreateSplitter(const SplitterQueueDescriptor& descriptor,
51 const WorkloadInfo& info) const override;
53 virtual std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor,
54 const WorkloadInfo& info) const override;
56 virtual std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
57 const WorkloadInfo& info) const override;
59 virtual std::unique_ptr<IWorkload> CreatePermute(const PermuteQueueDescriptor& descriptor,
60 const WorkloadInfo& info) const override;
62 virtual std::unique_ptr<IWorkload> CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
63 const WorkloadInfo& info) const override;
65 virtual std::unique_ptr<IWorkload> CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
66 const WorkloadInfo& info) const override;
68 virtual std::unique_ptr<IWorkload> CreateDepthwiseConvolution2d(
69 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const override;
71 virtual std::unique_ptr<IWorkload> CreateNormalization(const NormalizationQueueDescriptor& descriptor,
72 const WorkloadInfo& info) const override;
74 virtual std::unique_ptr<IWorkload> CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
75 const WorkloadInfo& info) const override;
77 virtual std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor,
78 const WorkloadInfo& info) const override;
80 virtual std::unique_ptr<IWorkload> CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor,
81 const WorkloadInfo& info) const override;
83 virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
84 const WorkloadInfo& info) const override;
86 virtual std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
87 const WorkloadInfo& info) const override;
89 virtual std::unique_ptr<IWorkload> CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
90 const WorkloadInfo& info) const override;
92 virtual std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
93 const WorkloadInfo& info) const override;
95 virtual std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor,
96 const WorkloadInfo& info) const override;
98 virtual std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor,
99 const WorkloadInfo& info) const override;
101 virtual std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor,
102 const WorkloadInfo& info) const override;
104 virtual std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor,
105 const WorkloadInfo& info) const override;
107 virtual std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
108 const WorkloadInfo& info) const override;
110 virtual std::unique_ptr<IWorkload> CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
111 const WorkloadInfo& info) const override;
113 virtual std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
114 const WorkloadInfo& info) const override;
116 virtual std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
117 const WorkloadInfo& info) const override;
119 virtual std::unique_ptr<IWorkload> CreateMean(const MeanQueueDescriptor& descriptor,
120 const WorkloadInfo& Info) const override;
122 virtual std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor,
123 const WorkloadInfo& info) const override;
125 virtual void Finalize() override;
127 virtual void Release() override;
129 virtual void Acquire() override;
132 #ifdef ARMCOMPUTENEON_ENABLED
// Declared mutable so it may be modified from const member functions (every
// Create* factory method in this class is const).
// NOTE(review): whether those methods actually use it is not visible here.
mutable NeonMemoryManager m_MemoryManager;