//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <Layer.hpp>
#include <LayersFwd.hpp>

#include <armnn/Types.hpp>
#include <armnn/LayerSupport.hpp>
#include <armnn/ILayerSupport.hpp>
#include <armnn/BackendRegistry.hpp>

#include <backendsCommon/WorkloadFactory.hpp>
#include <armnn/backends/IBackendInternal.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>

#include <backendsCommon/test/WorkloadTestUtils.hpp>

#include <boost/cast.hpp>
#include <boost/iterator/transform_iterator.hpp>

#include <cstring>
#include <sstream>

namespace armnn
{

namespace
{

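// Returns a copy of the given TensorInfo with its data type replaced by the requested
// override, or the original info unchanged when no override is supplied.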
const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
{
    if (!type)
    {
        return info;
    }

    return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
}

} // anonymous namespace

bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
                                        const IConnectableLayer& connectableLayer,
                                        Optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    Optional<std::string&> reason = outReasonIfUnsupported;
    bool result = false;
    const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));

    auto const& backendRegistry = BackendRegistryInstance();
    if (!backendRegistry.IsBackendRegistered(backendId))
    {
        std::stringstream ss;
        ss << connectableLayer.GetName() << " is not supported on " << backendId
           << " because this backend is not registered.";

        outReasonIfUnsupported = ss.str();
        return false;
    }

    auto backendFactory = backendRegistry.GetFactory(backendId);
    auto backendObject = backendFactory();
    auto layerSupportObject = backendObject->GetLayerSupport();

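    // Dispatch on the layer type: each case gathers the layer's tensor infos (with any
    // data type override applied) and forwards them, together with the layer's
    // parameters, to the backend's ILayerSupport query.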
    switch(layer.GetType())
    {
        case LayerType::Abs:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAbsSupported(OverrideDataType(input, dataType),
                                                        OverrideDataType(output, dataType),
                                                        reason);
            break;
        }
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsActivationSupported(
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsAdditionSupported(
                                        OverrideDataType(input0, dataType),
                                        OverrideDataType(input1, dataType),
                                        OverrideDataType(output, dataType),
                                        reason);
            break;
        }
        case LayerType::ArgMinMax:
        {
            auto cLayer = boost::polymorphic_downcast<const ArgMinMaxLayer*>(&layer);
            const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
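            // ArgMinMax produces indices, so the output is always validated as Signed32,
            // ignoring any requested data type override.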
            result = layerSupportObject->IsArgMinMaxSupported(
                    OverrideDataType(input, dataType),
                    OverrideDataType(output, DataType::Signed32),
                    descriptor,
                    reason);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = layerSupportObject->IsBatchNormalizationSupported(
                                                   OverrideDataType(input, dataType),
                                                   OverrideDataType(output, dataType),
                                                   OverrideDataType(mean, dataType),
                                                   OverrideDataType(var, dataType),
                                                   OverrideDataType(beta, dataType),
                                                   OverrideDataType(gamma, dataType),
                                                   cLayer->GetParameters(),
                                                   reason);
            break;
        }
        case LayerType::BatchToSpaceNd:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);

            result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::Comparison:
        {
            auto cLayer = boost::polymorphic_downcast<const ComparisonLayer*>(&layer);

            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

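            // Comparison layers always produce a Boolean output.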
            result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
                                                               OverrideDataType(input1, dataType),
                                                               OverrideDataType(output, DataType::Boolean),
                                                               cLayer->GetParameters(),
                                                               reason);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);

            const TensorInfo input  = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
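            // (the bias data type is derived from the weights/override type via
            // GetBiasTypeFromWeightsType, e.g. Signed32 biases for quantized weights)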
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsConvolution2dSupported(
                                              input,
                                              output,
                                              descriptor,
                                              OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                              biases,
                                              reason);
            break;
        }
        case LayerType::Debug:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::DepthToSpace:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthToSpaceLayer*>(&layer);

            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDepthToSpaceSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();

            // Construct optional biases object based on the value of m_BiasEnabled
            Optional<TensorInfo> biases;
            if (descriptor.m_BiasEnabled)
            {
                biases =
                    OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
            }

            result = layerSupportObject->IsDepthwiseConvolutionSupported(
                                                     input,
                                                     output,
                                                     descriptor,
                                                     OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                                     biases,
                                                     reason);
            break;
        }
        case LayerType::Dequantize:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsDequantizeSupported(input,
                                                               OverrideDataType(output, dataType),
                                                               reason);
            break;
        }
        case LayerType::DetectionPostProcess:
        {
            auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
            const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();

            const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
            const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
            const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();

            const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
            result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings,
                                                                         scores,
                                                                         anchors,
                                                                         detectionBoxes,
                                                                         detectionClasses,
                                                                         detectionScores,
                                                                         numDetections,
                                                                         descriptor,
                                                                         reason);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
                                                                     cLayer->GetParameters(),
                                                                     reason);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

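            // The layer support interface always expects a bias TensorInfo, so when biases
            // are disabled a static dummy info matching the input data type is passed instead.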
            TensorInfo biasInfo;
            const TensorInfo* biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled, pass a dummy TensorInfo for the validation.
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    case DataType::QuantisedSymm16:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected bias type");
                    }
                }
            }

            result = layerSupportObject->IsFullyConnectedSupported(
                                               OverrideDataType(input, dataType),
                                               OverrideDataType(output, dataType),
                                               OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                               *biasInfoPtr,
                                               descriptor,
                                               reason);
            break;
        }
        case LayerType::Gather:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
                                                           input1,
                                                           OverrideDataType(output, dataType),
                                                           reason);
            break;
        }
        case LayerType::Input:
        {
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
            break;
        }
        case LayerType::InstanceNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const InstanceNormalizationLayer*>(&layer);
            const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsInstanceNormalizationSupported(
                OverrideDataType(input, dataType),
                OverrideDataType(output, dataType),
                descriptor,
                reason);
            break;
        }
        case LayerType::L2Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
            const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();

            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsL2NormalizationSupported(
                                                OverrideDataType(input, dataType),
                                                OverrideDataType(output, dataType),
                                                descriptor,
                                                reason);
            break;
        }
        case LayerType::LogSoftmax:
        {
            auto cLayer = boost::polymorphic_downcast<const LogSoftmaxLayer*>(&layer);

            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType),
                                                               OverrideDataType(output, dataType),
                                                               cLayer->GetParameters(),
                                                               reason);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs.
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

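            // Collect the mandatory weight and bias infos; the optional parameter groups
            // (CIFG, projection, peephole, layer normalization) are filled in further down,
            // only when the descriptor enables them.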
            LstmInputParamsInfo paramsInfo;

            paramsInfo.m_InputToForgetWeights     = &inputToForgetWeights;
            paramsInfo.m_InputToCellWeights       = &inputToCellWeights;
            paramsInfo.m_InputToOutputWeights     = &inputToOutputWeights;
            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
            paramsInfo.m_RecurrentToCellWeights   = &recurrentToCellWeights;
            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
            paramsInfo.m_ForgetGateBias           = &forgetGateBias;
            paramsInfo.m_CellBias                 = &cellBias;
            paramsInfo.m_OutputGateBias           = &outputGateBias;

            // Optional parameters
            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;
            TensorInfo optInputLayerNormWeights;
            TensorInfo optForgetLayerNormWeights;
            TensorInfo optCellLayerNormWeights;
            TensorInfo optOutputLayerNormWeights;

            if (!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                paramsInfo.m_InputGateBias = &optInputGateBias;
            }

            if (descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    paramsInfo.m_ProjectionBias = &optProjectionBias;
                }
            }

            if (descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
            }

            if (descriptor.m_LayerNormEnabled)
            {
                if (!descriptor.m_CifgEnabled)
                {
                    optInputLayerNormWeights = OverrideDataType(
                            cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
                    paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
                }

                optForgetLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;

                optCellLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;

                optOutputLayerNormWeights = OverrideDataType(
                        cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
            }

            result = layerSupportObject->IsLstmSupported(
                                     input,
                                     outputStateIn,
                                     cellStateIn,
                                     scratchBuffer,
                                     outputStateOut,
                                     cellStateOut,
                                     output,
                                     descriptor,
                                     paramsInfo,
                                     reason);
            break;
        }
        case LayerType::Maximum:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
                                                            OverrideDataType(input1, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::MemCopy:
        {
            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            reason);
            break;
        }
        case LayerType::MemImport:
        {
            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              reason);
            break;
        }
        case LayerType::Merge:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
                                                          OverrideDataType(input1, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::Concat:
        {
            auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

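            // The transform iterators above return TensorInfo copies, so the infos are
            // materialised first and stable pointers are then taken into 'inputs'.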
            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsMultiplicationSupported(
                                               OverrideDataType(input0, dataType),
                                               OverrideDataType(input1, dataType),
                                               OverrideDataType(output, dataType),
                                               reason);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
                                                                  OverrideDataType(output, dataType),
                                                                  cLayer->GetParameters(),
                                                                  reason);
            break;
        }
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Pad:
        {
            auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPadSupported(
                                    OverrideDataType(input, dataType),
                                    OverrideDataType(output, dataType),
                                    cLayer->GetParameters(),
                                    reason);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
                                                              OverrideDataType(output, dataType),
                                                              cLayer->GetParameters(),
                                                              reason);
            break;
        }
        case LayerType::PreCompiled:
        {
            auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
                                                                cLayer->GetParameters(),
                                                                reason);
            break;
        }
        case LayerType::Quantize:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsQuantizeSupported(input, output, reason);
            break;
        }
        case LayerType::QuantizedLstm:
        {
            auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);

            // Inputs
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();

            // Outputs
            const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();

            // QuantizedLstm parameters
            QuantizedLstmInputParamsInfo paramsInfo;

            paramsInfo.m_InputToInputWeights      =
                    &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
            paramsInfo.m_InputToForgetWeights     =
                    &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
            paramsInfo.m_InputToCellWeights       =
                    &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
            paramsInfo.m_InputToOutputWeights     =
                    &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();

            paramsInfo.m_RecurrentToInputWeights  =
                    &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToForgetWeights =
                    &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToCellWeights   =
                    &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
            paramsInfo.m_RecurrentToOutputWeights =
                    &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();

            paramsInfo.m_InputGateBias            =
                    &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
            paramsInfo.m_ForgetGateBias           =
                    &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
            paramsInfo.m_CellBias                 =
                    &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
            paramsInfo.m_OutputGateBias           =
                    &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();

            result = layerSupportObject->IsQuantizedLstmSupported(input,
                                                                  previousCellStateIn,
                                                                  previousOutputIn,
                                                                  cellStateOut,
                                                                  output,
                                                                  paramsInfo,
                                                                  reason);
            break;
        }
        case LayerType::Division:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsDivisionSupported(
                                         OverrideDataType(input0, dataType),
                                         OverrideDataType(input1, dataType),
                                         OverrideDataType(output, dataType),
                                         reason);
            break;
        }
        case LayerType::Reshape:
        {
            auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::Resize:
        {
            auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
                                                           OverrideDataType(output, dataType),
                                                           cLayer->GetParameters(),
                                                           reason);
            break;
        }
        case LayerType::Rsqrt:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          reason);
            break;
        }
        case LayerType::Slice:
        {
            auto cLayer = boost::polymorphic_downcast<const SliceLayer*>(&layer);

            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
                                                          OverrideDataType(output, dataType),
                                                          cLayer->GetParameters(),
                                                          reason);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
                                                            OverrideDataType(output, dataType),
                                                            cLayer->GetParameters(),
                                                            reason);
            break;
        }
        case LayerType::SpaceToBatchNd:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
                                                                   OverrideDataType(output, dataType),
                                                                   cLayer->GetParameters(),
                                                                   reason);
            break;
        }
        case LayerType::SpaceToDepth:
        {
            auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);

            const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
                                                                 OverrideDataType(output, dataType),
                                                                 cLayer->GetParameters(),
                                                                 reason);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();

            // Get vector of all outputs.
            auto getTensorInfo = [&dataType](const OutputSlot& slot)
            {
                return OverrideDataType(slot.GetTensorInfo(), dataType);
            };
            auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> outputs(beginI, endI);

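            // IsSplitterSupported takes the output infos as reference_wrappers rather than
            // raw pointers.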
            const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());

            result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
                                                             outputPtrs,
                                                             cLayer->GetParameters(),
                                                             reason);
            break;
        }
        case LayerType::Stack:
        {
            auto cLayer = boost::polymorphic_downcast<const StackLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();

            result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);

            break;
        }
        case LayerType::StandIn:
        {
            auto cLayer = boost::polymorphic_downcast<const StandInLayer*>(&layer);

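            // StandIn layers are placeholders for operators unknown to ArmNN, so every
            // input and output tensor info (with any data type override applied) is
            // forwarded to the backend for it to decide.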
920             // Get vector of all inputs.
921             auto getTensorInfoIn = [&dataType](const InputSlot& slot)
922                 {
923                     return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
924                 };
925             auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
926                 {
927                     return OverrideDataType(slot.GetTensorInfo(), dataType);
928                 };
929             auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfoIn);
930             auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfoIn);
931             std::vector<TensorInfo> inputs(beginI, endI);
932
933             auto beginO = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
934             auto endO = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfoOut);
935             std::vector<TensorInfo> outputs(beginO, endO);
936
937
938             auto getTensorInfoPtr = [](const TensorInfo& info)
939                 {
940                     return &info;
941                 };
942             auto beginPtrI = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
943             auto endPtrI = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
944             std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
945
946             auto beginPtrO = boost::make_transform_iterator(outputs.begin(), getTensorInfoPtr);
947             auto endPtrO = boost::make_transform_iterator(outputs.end(), getTensorInfoPtr);
948             std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
949
950
951             result = layerSupportObject->IsStandInSupported(inputPtrs,
952                                                             outputPtrs,
953                                                             cLayer->GetParameters(),
954                                                             reason);
955             break;
956         }
957         case LayerType::StridedSlice:
958         {
959             auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
960             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
961             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
962             result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
963                                                                  OverrideDataType(output, dataType),
964                                                                  cLayer->GetParameters(),
965                                                                  reason);
966             break;
967         }
968         case LayerType::Subtraction:
969         {
970             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
971             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
972             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
973             result = layerSupportObject->IsSubtractionSupported(
974                                             OverrideDataType(input0, dataType),
975                                             OverrideDataType(input1, dataType),
976                                             OverrideDataType(output, dataType),
977                                             reason);
978             break;
979         }
980         case LayerType::Switch:
981         {
982             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
983             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
984             const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
985             const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
986             result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
987                                                            OverrideDataType(input1, dataType),
988                                                            OverrideDataType(output0, dataType),
989                                                            OverrideDataType(output1, dataType),
990                                                            reason);
991             break;
992         }
993         case LayerType::Mean:
994         {
995             auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
996             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
997             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
998             result = layerSupportObject->IsMeanSupported(
999                                      OverrideDataType(input, dataType),
1000                                      OverrideDataType(output, dataType),
1001                                      cLayer->GetParameters(),
1002                                      reason);
1003             break;
1004         }
1005         case LayerType::Minimum:
1006         {
1007             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1008             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1009             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1010             result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
1011                                                             OverrideDataType(input1, dataType),
1012                                                             OverrideDataType(output, dataType),
1013                                                             reason);
1014             break;
1015         }
1016         case LayerType::Prelu:
1017         {
1018             const TensorInfo& input  = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1019             const TensorInfo& alpha  = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1020             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1021             result = layerSupportObject->IsPreluSupported(OverrideDataType(input,  dataType),
1022                                                           OverrideDataType(alpha,  dataType),
1023                                                           OverrideDataType(output, dataType),
1024                                                           reason);
1025             break;
1026         }
1027         case LayerType::TransposeConvolution2d:
1028         {
1029             auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
1030
1031             const TensorInfo input  = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1032                                                        dataType);
1033             const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1034
1035             const TransposeConvolution2dDescriptor& descriptor  = cLayer->GetParameters();
1036
1037             Optional<TensorInfo> biases;
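            // When biases are enabled, the expected bias data type is derived from the
            // weights type (for example, quantized weights imply Signed32 biases).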
1038             if (descriptor.m_BiasEnabled)
1039             {
1040                 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
1041                 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1042                                           GetBiasTypeFromWeightsType(dataType));
1043             }
1044
1045             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
1046             const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1047
1048             result = layerSupportObject->IsTransposeConvolution2dSupported(input,
1049                                                                            output,
1050                                                                            descriptor,
1051                                                                            weights,
1052                                                                            biases,
1053                                                                            reason);
1054
1055             break;
1056         }
1057         default:
1058         {
1059             BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
1060             reason.value() = "Unrecognised layer type";
1061             result = false;
1062             break;
1063         }
1064     }
1065     return result;
1066 }
1067
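// Illustrative usage (not part of this file): querying support before assigning a layer
// to a backend. 'layer' is assumed to be an IConnectableLayer from an existing graph.
//
//     std::string reason;
//     bool supported = IWorkloadFactory::IsLayerSupported(BackendId("CpuRef"),
//                                                         layer,
//                                                         DataType::Float32,
//                                                         reason);
//     if (!supported)
//     {
//         std::cout << reason << std::endl;
//     }
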
1068 bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1069                                         Optional<DataType> dataType,
1070                                         std::string& outReasonIfUnsupported)
1071 {
1072     auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
1073     return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1074 }
1075
1076 // Default Implementations
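// Concrete backend factories override only the Create* functions they implement; each
// default below returns a null std::unique_ptr, signalling that this factory provides
// no workload for that layer type.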
1077 std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1078                                                        const WorkloadInfo& /*info*/) const
1079 {
1080     return std::unique_ptr<IWorkload>();
1081 }
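
// Illustrative caller-side check (assumption, not from this file):
//
//     auto workload = factory.CreateAbs(absDescriptor, workloadInfo);
//     if (workload == nullptr)
//     {
//         // this factory does not implement Abs
//     }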
1082
1083 std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1084                                                               const WorkloadInfo& /*info*/) const
1085 {
1086     return std::unique_ptr<IWorkload>();
1087 }
1088
1089 std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1090                                                             const WorkloadInfo& /*info*/) const
1091 {
1092     return std::unique_ptr<IWorkload>();
1093 }
1094
1095 std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1096                                                              const WorkloadInfo& /*info*/) const
1097 {
1098     return std::unique_ptr<IWorkload>();
1099 }
1100
1101 std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
1102     const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
1103 {
1104     return std::unique_ptr<IWorkload>();
1105 }
1106
1107 std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1108                                                                   const WorkloadInfo& /*info*/) const
1109 {
1110     return std::unique_ptr<IWorkload>();
1111 }
1112
1113 std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1114                                                               const WorkloadInfo& /*info*/) const
1115 {
1116     return std::unique_ptr<IWorkload>();
1117 }
1118
1119 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1120                                                           const WorkloadInfo& /*info*/) const
1121 {
1122     return std::unique_ptr<IWorkload>();
1123 }
1124
1125 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1126                                                             const WorkloadInfo& /*info*/) const
1127 {
1128     return std::unique_ptr<IWorkload>();
1129 }
1130
1131 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1132                                                                      const WorkloadInfo& /*info*/) const
1133 {
1134     return std::unique_ptr<IWorkload>();
1135 }
1136
1137 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1138                                                                      const WorkloadInfo& /*info*/) const
1139 {
1140     return std::unique_ptr<IWorkload>();
1141 }
1142
1143 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1144                                                                  const WorkloadInfo& /*info*/) const
1145 {
1146     return std::unique_ptr<IWorkload>();
1147 }
1148
1149 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1150                                                          const WorkloadInfo& /*info*/) const
1151 {
1152     return std::unique_ptr<IWorkload>();
1153 }
1154
1155 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1156                                                                 const WorkloadInfo& /*info*/) const
1157 {
1158     return std::unique_ptr<IWorkload>();
1159 }
1160
1161 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
1162     const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
1163 {
1164     return std::unique_ptr<IWorkload>();
1165 }
1166
1167 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
1168     const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
1169 {
1170     return std::unique_ptr<IWorkload>();
1171 }
1172
1173 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
1174     const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
1175 {
1176     return std::unique_ptr<IWorkload>();
1177 }
1178
1179 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1180                                                             const WorkloadInfo& /*info*/) const
1181 {
1182     return std::unique_ptr<IWorkload>();
1183 }
1184
1185 std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1186                                                          const WorkloadInfo& /*info*/) const
1187 {
1188     return std::unique_ptr<IWorkload>();
1189 }
1190
1191 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1192                                                                     const WorkloadInfo& /*info*/) const
1193 {
1194     return std::unique_ptr<IWorkload>();
1195 }
1196
1197 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1198                                                          const WorkloadInfo& /*info*/) const
1199 {
1200     return std::unique_ptr<IWorkload>();
1201 }
1202
1203 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1204                                                                   const WorkloadInfo& /*info*/) const
1205 {
1206     return std::unique_ptr<IWorkload>();
1207 }
1208
1209 std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1210                                                           const WorkloadInfo& /*info*/) const
1211 {
1212     return std::unique_ptr<IWorkload>();
1213 }
1214
1215 std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1216                                                            const WorkloadInfo& /*info*/) const
1217 {
1218     return std::unique_ptr<IWorkload>();
1219 }
1220
1221 std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
1222     const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1223     const WorkloadInfo& /*info*/) const
1224 {
1225     return std::unique_ptr<IWorkload>();
1226 }
1227
1228 std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1229                                                                    const WorkloadInfo& /*info*/) const
1230 {
1231     return std::unique_ptr<IWorkload>();
1232 }
1233
1234 std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1235                                                               const WorkloadInfo& /*info*/) const
1236 {
1237     return std::unique_ptr<IWorkload>();
1238 }
1239
1240 std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1241                                                         const WorkloadInfo& /*info*/) const
1242 {
1243     return std::unique_ptr<IWorkload>();
1244 }
1245
1246 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1247                                                            const WorkloadInfo& /*info*/) const
1248 {
1249     return std::unique_ptr<IWorkload>();
1250 }
1251
1252 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1253                                                         const WorkloadInfo& /*info*/) const
1254 {
1255     return std::unique_ptr<IWorkload>();
1256 }
1257
1258 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1259                                                            const WorkloadInfo& /*info*/) const
1260 {
1261     return std::unique_ptr<IWorkload>();
1262 }
1263
1264 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1265                                                              const WorkloadInfo& /*info*/) const
1266 {
1267     return std::unique_ptr<IWorkload>();
1268 }
1269
1270 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1271                                                          const WorkloadInfo& /*info*/) const
1272 {
1273     return std::unique_ptr<IWorkload>();
1274 }
1275
1276 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1277                                                           const WorkloadInfo& /*info*/) const
1278 {
1279     return std::unique_ptr<IWorkload>();
1280 }
1281
1282 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1283                                                            const WorkloadInfo& /*info*/) const
1284 {
1285     return std::unique_ptr<IWorkload>();
1286 }
1287
1288 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1289                                                                   const WorkloadInfo& /*info*/) const
1290 {
1291     return std::unique_ptr<IWorkload>();
1292 }
1293
1294 std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1295                                                                  const WorkloadInfo& /*info*/) const
1296 {
1297     return std::unique_ptr<IWorkload>();
1298 }
1299
1300 std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1301                                                           const WorkloadInfo& /*info*/) const
1302 {
1303     return std::unique_ptr<IWorkload>();
1304 }
1305
1306 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1307                                                        const WorkloadInfo& /*info*/) const
1308 {
1309     return std::unique_ptr<IWorkload>();
1310 }
1311
1312 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
1313                                                            const WorkloadInfo& /*info*/) const
1314 {
1315     return std::unique_ptr<IWorkload>();
1316 }
1317
1318 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1319                                                              const WorkloadInfo& /*info*/) const
1320 {
1321     return std::unique_ptr<IWorkload>();
1322 }
1323
1324 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1325                                                                const WorkloadInfo& /*info*/) const
1326 {
1327     return std::unique_ptr<IWorkload>();
1328 }
1329
1330 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor& /*descriptor*/,
1331                                                          const WorkloadInfo& /*info*/) const
1332 {
1333     return std::unique_ptr<IWorkload>();
1334 }
1335
1336 std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1337                                                             const WorkloadInfo& /*info*/) const
1338 {
1339     return std::unique_ptr<IWorkload>();
1340 }
1341
1342 std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1343                                                                  const WorkloadInfo& /*info*/) const
1344 {
1345     return std::unique_ptr<IWorkload>();
1346 }
1347
1348 std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1349                                                            const WorkloadInfo& /*info*/) const
1350 {
1351     return std::unique_ptr<IWorkload>();
1352 }
1353
1354 std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1355                                                                   const WorkloadInfo& /*info*/) const
1356 {
1357     return std::unique_ptr<IWorkload>();
1358 }
1359
1360 std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1361                                                          const WorkloadInfo& /*info*/) const
1362 {
1363     return std::unique_ptr<IWorkload>();
1364 }
1365
1366 std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1367                                                          const WorkloadInfo& /*info*/) const
1368 {
1369     return std::unique_ptr<IWorkload>();
1370 }
1371
1372 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1373                                                          const WorkloadInfo& /*info*/) const
1374 {
1375     return std::unique_ptr<IWorkload>();
1376 }
1377
1378 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1379                                                            const WorkloadInfo& /*info*/) const
1380 {
1381     return std::unique_ptr<IWorkload>();
1382 }
1383
1384 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1385                                                             const WorkloadInfo& /*info*/) const
1386 {
1387     return std::unique_ptr<IWorkload>();
1388 }
1389
1390 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1391                                                                   const WorkloadInfo& /*info*/) const
1392 {
1393     return std::unique_ptr<IWorkload>();
1394 }
1395
1396 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1397                                                                 const WorkloadInfo& /*info*/) const
1398 {
1399     return std::unique_ptr<IWorkload>();
1400 }
1401
1402 std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1403                                                          const WorkloadInfo& /*info*/) const
1404 {
1405     return std::unique_ptr<IWorkload>();
1406 }
1407
1408 std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1409                                                                 const WorkloadInfo& /*info*/) const
1410 {
1411     return std::unique_ptr<IWorkload>();
1412 }
1413
1414 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1415                                                                const WorkloadInfo& /*info*/) const
1416 {
1417     return std::unique_ptr<IWorkload>();
1418 }
1419
1420 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1421                                                           const WorkloadInfo& /*info*/) const
1422 {
1423     return std::unique_ptr<IWorkload>();
1424 }
1425
1426 std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1427     const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1428     const WorkloadInfo& /*info*/) const
1429 {
1430     return std::unique_ptr<IWorkload>();
1431 }
1432
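// Illustrative sketch only: how a hypothetical backend factory ('MyWorkloadFactory' with
// a hypothetical 'MyAdditionWorkload') might override one of the defaults above. A real
// factory must also implement the remaining pure-virtual members of IWorkloadFactory.
//
//     class MyWorkloadFactory : public IWorkloadFactory
//     {
//     public:
//         std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor,
//                                                   const WorkloadInfo& info) const override
//         {
//             // Construct the backend-specific workload instead of the null default.
//             return std::make_unique<MyAdditionWorkload>(descriptor, info);
//         }
//     };
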
1433 } // namespace armnn