IVGCVSW-2043 - Merger using ACL for innermost concat axis
[platform/upstream/armnn.git] / src / backends / backendsCommon / WorkloadFactory.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "CpuTensorHandle.hpp"
7
8 #include <Layer.hpp>
9 #include <LayersFwd.hpp>
10
11 #include <armnn/Types.hpp>
12 #include <armnn/LayerSupport.hpp>
13 #include <armnn/ILayerSupport.hpp>
14
15 #include <backendsCommon/BackendRegistry.hpp>
16 #include <backendsCommon/WorkloadFactory.hpp>
17 #include <backendsCommon/IBackendInternal.hpp>
18
19 #include <boost/cast.hpp>
20 #include <boost/iterator/transform_iterator.hpp>
21
22 #include <cstring>
23 #include <sstream>
24
25 namespace armnn
26 {
27
28 namespace
29 {
30
31 const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
32 {
33     if (!type)
34     {
35         return info;
36     }
37
38     return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
39 }
40
41 Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
42 {
43     if (!weightsType)
44     {
45         return weightsType;
46     }
47
48     switch(weightsType.value())
49     {
50         case DataType::Float16:
51         case DataType::Float32:
52             return weightsType;
53         case DataType::QuantisedAsymm8:
54             return DataType::Signed32;
55         default:
56             BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
57     }
58     return EmptyOptional();
59 }
60
61 } // anonymous namespace
62
63 bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
64                                         const IConnectableLayer& connectableLayer,
65                                         Optional<DataType> dataType,
66                                         std::string& outReasonIfUnsupported)
67 {
68     Optional<std::string&> reason = outReasonIfUnsupported;
69     bool result;
70     const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
71
72     auto const& backendRegistry = BackendRegistryInstance();
73     if (!backendRegistry.IsBackendRegistered(backendId))
74     {
75         std::stringstream ss;
76         ss << connectableLayer.GetName() << " is not supported on " << backendId
77            << " because this backend is not registered.";
78
79         outReasonIfUnsupported = ss.str();
80         return false;
81     }
82
83     auto backendFactory = backendRegistry.GetFactory(backendId);
84     auto backendObject = backendFactory();
85     auto layerSupportObject = backendObject->GetLayerSupport();
86
87     switch(layer.GetType())
88     {
89         case LayerType::Activation:
90         {
91             auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
92             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
93             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
94             result = layerSupportObject->IsActivationSupported(
95                                            OverrideDataType(input, dataType),
96                                            OverrideDataType(output, dataType),
97                                            cLayer->GetParameters(),
98                                            reason);
99             break;
100         }
101         case LayerType::Addition:
102         {
103             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
104             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
105             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
106             result = layerSupportObject->IsAdditionSupported(
107                                         OverrideDataType(input0, dataType),
108                                         OverrideDataType(input1, dataType),
109                                         OverrideDataType(output, dataType),
110                                         reason);
111             break;
112         }
113         case LayerType::BatchNormalization:
114         {
115             auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
116             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
117             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118             const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119             const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120             const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121             const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
122             result = layerSupportObject->IsBatchNormalizationSupported(
123                                                    OverrideDataType(input, dataType),
124                                                    OverrideDataType(output, dataType),
125                                                    OverrideDataType(mean, dataType),
126                                                    OverrideDataType(var, dataType),
127                                                    OverrideDataType(beta, dataType),
128                                                    OverrideDataType(gamma, dataType),
129                                                    cLayer->GetParameters(),
130                                                    reason);
131             break;
132         }
133         case LayerType::BatchToSpaceNd:
134         {
135             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
137             auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
138
139             result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140                                                                    OverrideDataType(output, dataType),
141                                                                    cLayer->GetParameters(),
142                                                                    reason);
143             break;
144         }
145         case LayerType::Constant:
146         {
147             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
148             result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
149             break;
150         }
151         case LayerType::ConvertFp16ToFp32:
152         {
153             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
154             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
155             result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
156             break;
157         }
158         case LayerType::ConvertFp32ToFp16:
159         {
160             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
161             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
162             result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
163             break;
164         }
165         case LayerType::Convolution2d:
166         {
167             auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
168
169             const TensorInfo input  = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
170                                                        dataType);
171             const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
172             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
173
174             const Convolution2dDescriptor& descriptor  = cLayer->GetParameters();
175
176             // Construct optional biases object based on the value of m_BiasEnabled
177             Optional<TensorInfo> biases;
178             if (descriptor.m_BiasEnabled)
179             {
180                 biases =
181                     OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
182             }
183
184             result = layerSupportObject->IsConvolution2dSupported(
185                                               input,
186                                               output,
187                                               descriptor,
188                                               OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
189                                               biases,
190                                               reason);
191             break;
192         }
193         case LayerType::MemCopy:
194         {
195             // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
196             // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
197             result = backendId == Compute::CpuRef || backendId == Compute::Undefined
198                 || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
199             reason.value() = "Unsupported backend type";
200             break;
201         }
202         case LayerType::DepthwiseConvolution2d:
203         {
204             auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
205             const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
206                                                        dataType);
207             const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
208             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
209
210             const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
211
212             // Construct optional biases object based on the value of m_BiasEnabled
213             Optional<TensorInfo> biases;
214             if (descriptor.m_BiasEnabled)
215             {
216                 biases =
217                     OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
218             }
219
220             result = layerSupportObject->IsDepthwiseConvolutionSupported(
221                                                      input,
222                                                      output,
223                                                      descriptor,
224                                                      OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
225                                                      biases,
226                                                      reason);
227             break;
228         }
229         case LayerType::FakeQuantization:
230         {
231             auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
232             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
233             result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
234                                                                      cLayer->GetParameters(),
235                                                                      reason);
236             break;
237         }
238         case LayerType::Floor:
239         {
240             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
241             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
242             result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
243                                                           OverrideDataType(output, dataType),
244                                                           reason);
245             break;
246         }
247         case LayerType::FullyConnected:
248         {
249             auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
250             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
251             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
252             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
253
254             TensorInfo biasInfo;
255             const TensorInfo * biasInfoPtr = nullptr;
256             static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
257             static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
258             static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
259
260             const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
261             if (descriptor.m_BiasEnabled)
262             {
263                 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
264                 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
265                 biasInfoPtr = &biasInfo;
266             }
267             else
268             {
269                 // If biases are not enabled pass a dummy tensorinfo for the validation
270                 switch(input.GetDataType())
271                 {
272                     case DataType::Float16:
273                     {
274                         biasInfoPtr = &dummyFloat16Bias;
275                         break;
276                     }
277                     case DataType::Float32:
278                     {
279                         biasInfoPtr = &dummyFloat32Bias;
280                         break;
281                     }
282                     case DataType::QuantisedAsymm8:
283                     {
284                         biasInfoPtr = &dummyQA8Bias;
285                         break;
286                     }
287                     default:
288                     {
289                         BOOST_ASSERT_MSG(false, "Unexpected bias type");
290                     }
291                 }
292             }
293
294             result = layerSupportObject->IsFullyConnectedSupported(
295                                                OverrideDataType(input, dataType),
296                                                OverrideDataType(output, dataType),
297                                                OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
298                                                *biasInfoPtr,
299                                                descriptor,
300                                                reason);
301             break;
302         }
303         case LayerType::Input:
304         {
305             const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
306             result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
307             break;
308         }
309         case LayerType::L2Normalization:
310         {
311             auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
312             const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
313
314             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
315             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
316
317             result = layerSupportObject->IsL2NormalizationSupported(
318                                                 OverrideDataType(input, dataType),
319                                                 OverrideDataType(output, dataType),
320                                                 descriptor,
321                                                 reason);
322             break;
323         }
324         case LayerType::Lstm:
325         {
326             auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
327             const LstmDescriptor& descriptor = cLayer->GetParameters();
328
329             // All inputs.
330             const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
331                                                        dataType);
332             const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
333                                                                dataType);
334             const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
335                                                              dataType);
336             // All outputs
337             const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
338             const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
339             const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
340             const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
341
342             // Basic parameters
343             const TensorInfo& inputToForgetWeights
344                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
345             const TensorInfo& inputToCellWeights
346                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
347             const TensorInfo& inputToOutputWeights
348                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
349             const TensorInfo& recurrentToForgetWeights
350                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
351             const TensorInfo& recurrentToCellWeights
352                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
353             const TensorInfo& recurrentToOutputWeights
354                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
355             const TensorInfo& forgetGateBias
356                     = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
357             const TensorInfo& cellBias
358                     = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
359             const TensorInfo& outputGateBias
360                     = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
361
362             // Optional parameters
363             const TensorInfo* inputToInputWeights = nullptr;
364             const TensorInfo* recurrentToInputWeights = nullptr;
365             const TensorInfo* cellToInputWeights = nullptr;
366             const TensorInfo* inputGateBias = nullptr;
367             const TensorInfo* projectionWeights = nullptr;
368             const TensorInfo* projectionBias = nullptr;
369             const TensorInfo* cellToForgetWeights = nullptr;
370             const TensorInfo* cellToOutputWeights = nullptr;
371
372             TensorInfo optInputToInputWeights;
373             TensorInfo optRecurrentToInputWeights;
374             TensorInfo optCellToInputWeights;
375             TensorInfo optInputGateBias;
376             TensorInfo optProjectionWeights;
377             TensorInfo optProjectionBias;
378             TensorInfo optCellToForgetWeights;
379             TensorInfo optCellToOutputWeights;
380
381             if(!descriptor.m_CifgEnabled)
382             {
383                 optInputToInputWeights =
384                     OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
385                 inputToInputWeights = &optInputToInputWeights;
386
387                 optRecurrentToInputWeights =
388                     OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
389                 recurrentToInputWeights = &optRecurrentToInputWeights;
390                 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
391                 {
392                     optCellToInputWeights =
393                         OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
394                     cellToInputWeights = &optCellToInputWeights;
395                 }
396                 optInputGateBias =
397                        OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
398                 inputGateBias = &optInputGateBias;
399             }
400
401             if(descriptor.m_ProjectionEnabled)
402             {
403                 optProjectionWeights =
404                     OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
405                 projectionWeights = &optProjectionWeights;
406                 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
407                 {
408                     optProjectionBias =
409                         OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
410                     projectionBias = &optProjectionBias;
411                 }
412             }
413
414             if(descriptor.m_PeepholeEnabled)
415             {
416                 optCellToForgetWeights =
417                     OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
418                 cellToForgetWeights = &optCellToForgetWeights;
419                 optCellToOutputWeights =
420                     OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
421                 cellToOutputWeights = &optCellToOutputWeights;
422             }
423
424             result = layerSupportObject->IsLstmSupported(
425                                      input,
426                                      outputStateIn,
427                                      cellStateIn,
428                                      scratchBuffer,
429                                      outputStateOut,
430                                      cellStateOut,
431                                      output,
432                                      descriptor,
433                                      inputToForgetWeights,
434                                      inputToCellWeights,
435                                      inputToOutputWeights,
436                                      recurrentToForgetWeights,
437                                      recurrentToCellWeights,
438                                      recurrentToOutputWeights,
439                                      forgetGateBias,
440                                      cellBias,
441                                      outputGateBias,
442                                      inputToInputWeights,
443                                      recurrentToInputWeights,
444                                      cellToInputWeights,
445                                      inputGateBias,
446                                      projectionWeights,
447                                      projectionBias,
448                                      cellToForgetWeights,
449                                      cellToOutputWeights,
450                                      reason);
451             break;
452         }
453         case LayerType::Merger:
454         {
455             auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
456
457             // Get vector of all inputs.
458             auto getTensorInfo = [&dataType](const InputSlot& slot)
459                 {
460                     return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
461                 };
462             auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
463             auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
464             std::vector<TensorInfo> inputs(beginI, endI);
465
466             auto getTensorInfoPtr = [](const TensorInfo& info)
467                 {
468                     return &info;
469                 };
470             auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
471             auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
472             std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
473
474             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
475
476             result = layerSupportObject->IsMergerSupported(inputPtrs, output, cLayer->GetParameters(), reason);
477             break;
478         }
479         case LayerType::Multiplication:
480         {
481             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
482             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
483             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
484             result = layerSupportObject->IsMultiplicationSupported(
485                                                OverrideDataType(input0, dataType),
486                                                OverrideDataType(input1, dataType),
487                                                OverrideDataType(output, dataType),
488                                                reason);
489             break;
490         }
491         case LayerType::Normalization:
492         {
493             auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
494             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
495             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
496             result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
497                                                                   OverrideDataType(output, dataType),
498                                                                   cLayer->GetParameters(),
499                                                                   reason);
500             break;
501         }
502         case LayerType::Output:
503         {
504             const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
505             result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
506             break;
507         }
508         case LayerType::Permute:
509         {
510             auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
511             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
512             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
513             result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
514                                                             OverrideDataType(output, dataType),
515                                                             cLayer->GetParameters(),
516                                                             reason);
517             break;
518         }
519         case LayerType::Pad:
520         {
521             auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
522             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
523             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
524             result = layerSupportObject->IsPadSupported(
525                                     OverrideDataType(input, dataType),
526                                     OverrideDataType(output, dataType),
527                                     cLayer->GetParameters(),
528                                     reason);
529             break;
530         }
531         case LayerType::Pooling2d:
532         {
533             auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
534             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
535             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
536             result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
537                                                               OverrideDataType(output, dataType),
538                                                               cLayer->GetParameters(),
539                                                               reason);
540             break;
541         }
542         case LayerType::Division:
543         {
544             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
545             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
546             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
547             result = layerSupportObject->IsDivisionSupported(
548                                          OverrideDataType(input0, dataType),
549                                          OverrideDataType(input1, dataType),
550                                          OverrideDataType(output, dataType),
551                                          reason);
552             break;
553         }
554         case LayerType::Reshape:
555         {
556             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
557             result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
558             break;
559         }
560         case LayerType::ResizeBilinear:
561         {
562             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
563             result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
564             break;
565         }
566         case LayerType::Softmax:
567         {
568             auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
569             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
570             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
571             result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
572                                                             OverrideDataType(output, dataType),
573                                                             cLayer->GetParameters(),
574                                                             reason);
575             break;
576         }
577         case LayerType::SpaceToBatchNd:
578         {
579             auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
580             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
581             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
582             result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
583                                                                    OverrideDataType(output, dataType),
584                                                                    cLayer->GetParameters(),
585                                                                    reason);
586             break;
587         }
588         case LayerType::Splitter:
589         {
590             auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
591             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
592             result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
593                                                              cLayer->GetParameters(),
594                                                              reason);
595             break;
596         }
597         case LayerType::StridedSlice:
598         {
599             auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
600             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
601             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
602             result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
603                                                                  OverrideDataType(output, dataType),
604                                                                  cLayer->GetParameters(),
605                                                                  reason);
606             break;
607         }
608         case LayerType::Subtraction:
609         {
610             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
611             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
612             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
613             result = layerSupportObject->IsSubtractionSupported(
614                                             OverrideDataType(input0, dataType),
615                                             OverrideDataType(input1, dataType),
616                                             OverrideDataType(output, dataType),
617                                             reason);
618             break;
619         }
620         case LayerType::Mean:
621         {
622             auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
623             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
624             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
625             result = layerSupportObject->IsMeanSupported(
626                                      OverrideDataType(input, dataType),
627                                      OverrideDataType(output, dataType),
628                                      cLayer->GetParameters(),
629                                      reason);
630             break;
631         }
632         default:
633         {
634             BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
635             reason.value() = "Unrecognised layer type";
636             result = false;
637             break;
638         }
639     }
640     return result;
641 }
642
643 bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
644                                         Optional<DataType> dataType,
645                                         std::string& outReasonIfUnsupported)
646 {
647     auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
648     return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
649 }
650
651 }