IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / backendsCommon / Workload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "WorkloadData.hpp"
8 #include "WorkloadInfo.hpp"
9
10 #include <Profiling.hpp>
11
12 #include <algorithm>
13
14 namespace armnn
15 {
16
17 /// Workload interface to enqueue a layer computation.
18 class IWorkload
19 {
20 public:
21     virtual ~IWorkload() {}
22
23     virtual void Execute() const = 0;
24 };
25
26 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
27 // in the various workload factories.
28 // There should never be an instantiation of a NullWorkload.
29 class NullWorkload : public IWorkload
30 {
31     NullWorkload()=delete;
32 };
33
34 template <typename QueueDescriptor>
35 class BaseWorkload : public IWorkload
36 {
37 public:
38
39     BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
40         : m_Data(descriptor)
41     {
42         m_Data.Validate(info);
43     }
44
45     const QueueDescriptor& GetData() const { return m_Data; }
46
47 protected:
48     const QueueDescriptor m_Data;
49 };
50
51 // TypedWorkload used
52 template <typename QueueDescriptor, armnn::DataType... DataTypes>
53 class TypedWorkload : public BaseWorkload<QueueDescriptor>
54 {
55 public:
56
57     TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
58         : BaseWorkload<QueueDescriptor>(descriptor, info)
59     {
60         std::vector<armnn::DataType> dataTypes = {DataTypes...};
61         armnn::DataType expectedInputType;
62
63         if (!info.m_InputTensorInfos.empty())
64         {
65             expectedInputType = info.m_InputTensorInfos.front().GetDataType();
66
67             if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
68             {
69                 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
70             }
71             BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
72                                          info.m_InputTensorInfos.end(),
73                                          [&](auto it){
74                                              return it.GetDataType() == expectedInputType;
75                                          }),
76                              "Trying to create workload with incorrect type");
77         }
78         armnn::DataType expectedOutputType;
79
80         if (!info.m_OutputTensorInfos.empty())
81         {
82             expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
83
84             if (!info.m_InputTensorInfos.empty())
85             {
86                 if (expectedOutputType != expectedInputType)
87                 {
88                     BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
89                 }
90             }
91             else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
92             {
93                 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
94             }
95             BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
96                                          info.m_OutputTensorInfos.end(),
97                                          [&](auto it){
98                                              return it.GetDataType() == expectedOutputType;
99                                          }),
100                              "Trying to create workload with incorrect type");
101         }
102     }
103 };
104
105 template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
106 class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
107 {
108 public:
109
110     MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
111         : BaseWorkload<QueueDescriptor>(descriptor, info)
112     {
113         BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
114                                      info.m_InputTensorInfos.end(),
115                                      [&](auto it){
116                                          return it.GetDataType() == InputDataType;
117                                      }),
118                          "Trying to create workload with incorrect type");
119         BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
120                                      info.m_OutputTensorInfos.end(),
121                                      [&](auto it){
122                                          return it.GetDataType() == OutputDataType;
123                                      }),
124                          "Trying to create workload with incorrect type");
125     }
126 };
127
128 template <typename QueueDescriptor>
129 using FloatWorkload = TypedWorkload<QueueDescriptor,
130                                     armnn::DataType::Float16,
131                                     armnn::DataType::Float32>;
132
133 template <typename QueueDescriptor>
134 using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
135
136 template <typename QueueDescriptor>
137 using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
138
139 template <typename QueueDescriptor>
140 using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
141                                                     armnn::DataType::Float16,
142                                                     armnn::DataType::Float32>;
143
144 template <typename QueueDescriptor>
145 using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
146                                                     armnn::DataType::Float32,
147                                                     armnn::DataType::Float16>;
148
149 } //namespace armnn