//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"

#include "Profiling.hpp"

#include <boost/assert.hpp>

#include <algorithm>
#include <iterator>
#include <vector>

namespace armnn
{
/// Workload interface to enqueue a layer computation.
class IWorkload
{
public:
    virtual ~IWorkload() {}

    virtual void Execute() const = 0;
};
// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload.
class NullWorkload : public IWorkload
{
    NullWorkload() = delete;
};
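// (Sketch of the intent: a factory helper such as the MakeWorkload<> template
// referenced above can map an unsupported data-type combination to NullWorkload
// and signal "unsupported" without ever constructing one, for example by
// returning a null pointer; the deleted default constructor guards against
// accidental instantiation.)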
// BaseWorkload: holds the layer's queue descriptor, validated against the
// WorkloadInfo at construction time.
template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:

    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor)
    {
        m_Data.Validate(info);
    }

    const QueueDescriptor& GetData() const { return m_Data; }

protected:
    const QueueDescriptor m_Data;
};
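// A minimal sketch of a concrete workload built on BaseWorkload, for
// illustration only; SampleActivationWorkload and RunActivationKernel are
// hypothetical names:
//
//     class SampleActivationWorkload : public BaseWorkload<ActivationQueueDescriptor>
//     {
//     public:
//         using BaseWorkload<ActivationQueueDescriptor>::BaseWorkload;
//
//         void Execute() const override
//         {
//             // m_Data holds the validated descriptor (inputs, outputs, parameters).
//             RunActivationKernel(m_Data.m_Inputs, m_Data.m_Outputs, m_Data.m_Parameters);
//         }
//     };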
// TypedWorkload: a BaseWorkload that additionally asserts, at construction
// time, that every input and output tensor uses one and the same data type,
// drawn from the DataTypes given as template arguments.
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        std::vector<armnn::DataType> dataTypes = {DataTypes...};
        armnn::DataType expectedInputType;

        if (!info.m_InputTensorInfos.empty())
        {
            expectedInputType = info.m_InputTensorInfos.front().GetDataType();

            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
            {
                BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
                                         info.m_InputTensorInfos.end(),
                                         [&](auto it)
                                         {
                                             return it.GetDataType() == expectedInputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }

        armnn::DataType expectedOutputType;

        if (!info.m_OutputTensorInfos.empty())
        {
            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();

            if (!info.m_InputTensorInfos.empty())
            {
                if (expectedOutputType != expectedInputType)
                {
                    BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
                }
            }
            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
            {
                BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
                                         info.m_OutputTensorInfos.end(),
                                         [&](auto it)
                                         {
                                             return it.GetDataType() == expectedOutputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
    }
};
// MultiTypedWorkload: a BaseWorkload that asserts every input tensor uses
// InputDataType and every output tensor uses OutputDataType, for workloads
// whose input and output data types legitimately differ.
template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [&](auto it)
                                     {
                                         return it.GetDataType() == InputDataType;
                                     }),
                         "Trying to create workload with incorrect type");
        BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it)
                                     {
                                         return it.GetDataType() == OutputDataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};
template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
                                    armnn::DataType::Float16,
                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;

template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float16,
                                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float32,
                                                    armnn::DataType::Float16>;

} //namespace armnn
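// For instance, a backend kernel supporting both float precisions would derive
// from FloatWorkload<AdditionQueueDescriptor>, while a Float32-only
// implementation would derive from Float32Workload<AdditionQueueDescriptor>
// instead (AdditionQueueDescriptor used here purely as an illustration).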