4cfffd4646c5a46381b8f35101ddc5ff15cc721f
[platform/upstream/armnn.git] / src / backends / Workload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "WorkloadData.hpp"
8 #include "WorkloadInfo.hpp"
9 #include <algorithm>
10 #include "Profiling.hpp"
11
12 namespace armnn
13 {
14
15 /// Workload interface to enqueue a layer computation.
16 class IWorkload
17 {
18 public:
19     virtual ~IWorkload() {}
20
21     virtual void Execute() const = 0;
22 };
23
24 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
25 // in the various workload factories.
26 // There should never be an instantiation of a NullWorkload.
27 class NullWorkload : public IWorkload
28 {
29     NullWorkload()=delete;
30 };
31
32 template <typename QueueDescriptor>
33 class BaseWorkload : public IWorkload
34 {
35 public:
36
37     BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
38         : m_Data(descriptor)
39     {
40         m_Data.Validate(info);
41     }
42
43     const QueueDescriptor& GetData() const { return m_Data; }
44
45 protected:
46     const QueueDescriptor m_Data;
47 };
48
49 // TypedWorkload used
50 template <typename QueueDescriptor, armnn::DataType... DataTypes>
51 class TypedWorkload : public BaseWorkload<QueueDescriptor>
52 {
53 public:
54
55     TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
56         : BaseWorkload<QueueDescriptor>(descriptor, info)
57     {
58         std::vector<armnn::DataType> dataTypes = {DataTypes...};
59         armnn::DataType expectedInputType;
60
61         if (!info.m_InputTensorInfos.empty())
62         {
63             expectedInputType = info.m_InputTensorInfos.front().GetDataType();
64
65             if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
66             {
67                 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
68             }
69             BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
70                                          info.m_InputTensorInfos.end(),
71                                          [&](auto it){
72                                              return it.GetDataType() == expectedInputType;
73                                          }),
74                              "Trying to create workload with incorrect type");
75         }
76         armnn::DataType expectedOutputType;
77
78         if (!info.m_OutputTensorInfos.empty())
79         {
80             expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
81
82             if (!info.m_InputTensorInfos.empty())
83             {
84                 if (expectedOutputType != expectedInputType)
85                 {
86                     BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
87                 }
88             }
89             else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
90             {
91                 BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type");
92             }
93             BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
94                                          info.m_OutputTensorInfos.end(),
95                                          [&](auto it){
96                                              return it.GetDataType() == expectedOutputType;
97                                          }),
98                              "Trying to create workload with incorrect type");
99         }
100     }
101 };
102
103 template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
104 class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
105 {
106 public:
107
108     MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
109         : BaseWorkload<QueueDescriptor>(descriptor, info)
110     {
111         BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
112                                      info.m_InputTensorInfos.end(),
113                                      [&](auto it){
114                                          return it.GetDataType() == InputDataType;
115                                      }),
116                          "Trying to create workload with incorrect type");
117         BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
118                                      info.m_OutputTensorInfos.end(),
119                                      [&](auto it){
120                                          return it.GetDataType() == OutputDataType;
121                                      }),
122                          "Trying to create workload with incorrect type");
123     }
124 };
125
126 template <typename QueueDescriptor>
127 using FloatWorkload = TypedWorkload<QueueDescriptor,
128                                     armnn::DataType::Float16,
129                                     armnn::DataType::Float32>;
130
131 template <typename QueueDescriptor>
132 using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
133
134 template <typename QueueDescriptor>
135 using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
136
137 template <typename QueueDescriptor>
138 using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
139                                                     armnn::DataType::Float16,
140                                                     armnn::DataType::Float32>;
141
142 template <typename QueueDescriptor>
143 using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
144                                                     armnn::DataType::Float32,
145                                                     armnn::DataType::Float16>;
146
147 } //namespace armnn