//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"

#include <Profiling.hpp>

#include <boost/assert.hpp>

#include <algorithm>
#include <iterator>
#include <vector>

namespace armnn
{
/// Workload interface to enqueue a layer computation.
class IWorkload
{
public:
    virtual ~IWorkload() {}

    virtual void Execute() const = 0;
};
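// Example (illustrative sketch, not part of the Arm NN API): the smallest
// possible concrete workload. "ExampleNoOpWorkload" is a hypothetical name;
// real workloads derive from BaseWorkload<> or the typed aliases further down.
class ExampleNoOpWorkload : public IWorkload
{
public:
    // Execute() is what the runtime calls when the layer is scheduled; this
    // sketch deliberately does nothing.
    void Execute() const override {}
};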
// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload.
class NullWorkload : public IWorkload
{
    NullWorkload() = delete;
};
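// Illustrative note (sketch): the deleted default constructor makes any
// accidental selection of NullWorkload a compile-time error, e.g.
//
//     NullWorkload unsupported; // error: use of deleted function
//                               // 'NullWorkload::NullWorkload()'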
template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:

    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor)
    {
        m_Data.Validate(info);
    }

    const QueueDescriptor& GetData() const { return m_Data; }

protected:
    const QueueDescriptor m_Data;
};
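// Example (illustrative sketch, not part of the Arm NN API): a workload built
// directly on BaseWorkload<>. "ExampleActivationWorkload" is a hypothetical
// name; ActivationQueueDescriptor is assumed to be declared in WorkloadData.hpp.
class ExampleActivationWorkload : public BaseWorkload<ActivationQueueDescriptor>
{
public:
    using Base = BaseWorkload<ActivationQueueDescriptor>;
    using Base::Base; // inherit the validating (descriptor, info) constructor

    void Execute() const override
    {
        // m_Data is the descriptor that was validated at construction time; a
        // real implementation would read its inputs, outputs and parameters.
        const ActivationQueueDescriptor& data = this->m_Data;
        (void)data;
    }
};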
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:
    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        std::vector<armnn::DataType> dataTypes = {DataTypes...};
        armnn::DataType expectedInputType;

        if (!info.m_InputTensorInfos.empty())
        {
            expectedInputType = info.m_InputTensorInfos.front().GetDataType();

            // The first input must use one of the supported data types...
            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
            {
                BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            // ...and every remaining input must match it.
            BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
                                         info.m_InputTensorInfos.end(),
                                         [&](const TensorInfo& it)
                                         { return it.GetDataType() == expectedInputType; }),
                             "Trying to create workload with incorrect type");
        }

        armnn::DataType expectedOutputType;

        if (!info.m_OutputTensorInfos.empty())
        {
            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();

            // Outputs must match the input type when inputs exist; otherwise
            // they must themselves use one of the supported data types.
            if (!info.m_InputTensorInfos.empty())
            {
                if (expectedOutputType != expectedInputType)
                {
                    BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
                }
            }
            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
            {
                BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            // Every remaining output must match the first.
            BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
                                         info.m_OutputTensorInfos.end(),
                                         [&](const TensorInfo& it)
                                         { return it.GetDataType() == expectedOutputType; }),
                             "Trying to create workload with incorrect type");
        }
    }
};
template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:
    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        // Inputs and outputs are checked against two independent data types.
        BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [&](const TensorInfo& it)
                                     { return it.GetDataType() == InputDataType; }),
                         "Trying to create workload with incorrect type");
        BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](const TensorInfo& it)
                                     { return it.GetDataType() == OutputDataType; }),
                         "Trying to create workload with incorrect type");
    }
};
template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
                                    armnn::DataType::Float16,
                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;

template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float16,
                                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float32,
                                                    armnn::DataType::Float16>;

} // namespace armnn
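// Example usage (illustrative sketch, not part of the Arm NN API): a backend
// declaring a Float32-only addition workload through the aliases above.
// "ExampleAdditionFloat32Workload" is a hypothetical name, and
// AdditionQueueDescriptor is assumed to be declared in WorkloadData.hpp.
class ExampleAdditionFloat32Workload
    : public armnn::Float32Workload<armnn::AdditionQueueDescriptor>
{
public:
    using Base = armnn::Float32Workload<armnn::AdditionQueueDescriptor>;
    using Base::Base;

    void Execute() const override
    {
        // A real implementation would add the two input tensors element-wise.
    }
};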