27 template <
typename QueueDescriptor>
36 m_Data.Validate(info);
59 std::vector<armnn::DataType> dataTypes = {DataTypes...};
66 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
68 BOOST_ASSERT_MSG(
false,
"Trying to create workload with incorrect type");
73 return it.GetDataType() == expectedInputType;
75 "Trying to create workload with incorrect type");
85 if (expectedOutputType != expectedInputType)
87 BOOST_ASSERT_MSG(
false,
"Trying to create workload with incorrect type");
90 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
92 BOOST_ASSERT_MSG(
false,
"Trying to create workload with incorrect type");
97 return it.GetDataType() == expectedOutputType;
99 "Trying to create workload with incorrect type");
104 template <
typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
115 return it.GetDataType() == InputDataType;
117 "Trying to create workload with incorrect type");
122 return it.GetDataType() == OutputDataType;
124 "Trying to create workload with incorrect type");
129 template <
typename QueueDescriptor, armnn::DataType DataType>
140 "Trying to create workload with incorrect type");
146 return it.GetDataType() ==
DataType;
148 "Trying to create workload with incorrect type");
152 template <
typename QueueDescriptor>
157 template <
typename QueueDescriptor>
160 template <
typename QueueDescriptor>
163 template <
typename QueueDescriptor>
166 template <
typename QueueDescriptor>
169 template <
typename QueueDescriptor>
174 template <
typename QueueDescriptor>
179 template <
typename QueueDescriptor>
182 armnn::DataType::Float32>;
184 template <
typename QueueDescriptor>
187 armnn::DataType::Float16>;
189 template <
typename QueueDescriptor>
192 armnn::DataType::Float32>;
MultiTypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
Workload interface to enqueue a layer computation.
profiling::ProfilingGuid GetGuid() const final
const QueueDescriptor & GetData() const
std::vector< TensorInfo > m_OutputTensorInfos
const profiling::ProfilingGuid m_Guid
const QueueDescriptor m_Data
std::vector< TensorInfo > m_InputTensorInfos
TypedWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)
void PostAllocationConfigure() override
BaseWorkload(const QueueDescriptor &descriptor, const WorkloadInfo &info)