IVGCVSW-2093 Add SpaceToBatchNd layer and corresponding no-op factory implementations
[platform/upstream/armnn.git] / src / backends / backendsCommon / WorkloadFactory.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "CpuTensorHandle.hpp"
7
8 #include <Layer.hpp>
9 #include <LayersFwd.hpp>
10
11 #include <armnn/Types.hpp>
12 #include <armnn/LayerSupport.hpp>
13
14 #include <backendsCommon/LayerSupportRegistry.hpp>
15 #include <backendsCommon/WorkloadFactory.hpp>
16
17 #include <boost/cast.hpp>
18 #include <boost/iterator/transform_iterator.hpp>
19
20 #include <cstring>
21
22 namespace armnn
23 {
24
25 namespace
26 {
27
28 const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
29 {
30     if (!type)
31     {
32         return info;
33     }
34
35     return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
36 }
37
38 Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
39 {
40     if (!weightsType)
41     {
42         return weightsType;
43     }
44
45     switch(weightsType.value())
46     {
47         case DataType::Float16:
48         case DataType::Float32:
49             return weightsType;
50         case DataType::QuantisedAsymm8:
51             return DataType::Signed32;
52         default:
53             BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
54     }
55     return EmptyOptional();
56 }
57
58 } // anonymous namespace
59
60 bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
61                                         const IConnectableLayer& connectableLayer,
62                                         Optional<DataType> dataType,
63                                         std::string& outReasonIfUnsupported)
64 {
65     Optional<std::string&> reason = outReasonIfUnsupported;
66     bool result;
67     const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
68
69     auto const& layerSupportRegistry = LayerSupportRegistryInstance();
70     auto layerSupportFactory = layerSupportRegistry.GetFactory(backendId);
71     auto layerSupportObject = layerSupportFactory(EmptyInitializer());
72
73     switch(layer.GetType())
74     {
75         case LayerType::Activation:
76         {
77             auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
78             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
79             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
80             result = layerSupportObject->IsActivationSupported(
81                                            OverrideDataType(input, dataType),
82                                            OverrideDataType(output, dataType),
83                                            cLayer->GetParameters(),
84                                            reason);
85             break;
86         }
87         case LayerType::Addition:
88         {
89             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
90             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
91             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
92             result = layerSupportObject->IsAdditionSupported(
93                                         OverrideDataType(input0, dataType),
94                                         OverrideDataType(input1, dataType),
95                                         OverrideDataType(output, dataType),
96                                         reason);
97             break;
98         }
99         case LayerType::BatchNormalization:
100         {
101             auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
102             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
103             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
104             const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
105             const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
106             const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
107             const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
108             result = layerSupportObject->IsBatchNormalizationSupported(
109                                                    OverrideDataType(input, dataType),
110                                                    OverrideDataType(output, dataType),
111                                                    OverrideDataType(mean, dataType),
112                                                    OverrideDataType(var, dataType),
113                                                    OverrideDataType(beta, dataType),
114                                                    OverrideDataType(gamma, dataType),
115                                                    cLayer->GetParameters(),
116                                                    reason);
117             break;
118         }
119         case LayerType::Constant:
120         {
121             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
122             result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
123             break;
124         }
125         case LayerType::ConvertFp16ToFp32:
126         {
127             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
128             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
129             result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
130             break;
131         }
132         case LayerType::ConvertFp32ToFp16:
133         {
134             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
135             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
136             result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
137             break;
138         }
139         case LayerType::Convolution2d:
140         {
141             auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
142
143             const TensorInfo input  = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
144                                                        dataType);
145             const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
146             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
147
148             const Convolution2dDescriptor& descriptor  = cLayer->GetParameters();
149
150             // Construct optional biases object based on the value of m_BiasEnabled
151             Optional<TensorInfo> biases;
152             if (descriptor.m_BiasEnabled)
153             {
154                 biases =
155                     OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
156             }
157
158             result = layerSupportObject->IsConvolution2dSupported(
159                                               input,
160                                               output,
161                                               descriptor,
162                                               OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
163                                               biases,
164                                               reason);
165             break;
166         }
167         case LayerType::MemCopy:
168         {
169             // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
170             // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
171             result = backendId == Compute::CpuRef || backendId == Compute::Undefined
172                 || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
173             reason.value() = "Unsupported backend type";
174             break;
175         }
176         case LayerType::DepthwiseConvolution2d:
177         {
178             auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
179             const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
180                                                        dataType);
181             const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
182             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
183
184             const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
185
186             // Construct optional biases object based on the value of m_BiasEnabled
187             Optional<TensorInfo> biases;
188             if (descriptor.m_BiasEnabled)
189             {
190                 biases =
191                     OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
192             }
193
194             result = layerSupportObject->IsDepthwiseConvolutionSupported(
195                                                      input,
196                                                      output,
197                                                      descriptor,
198                                                      OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
199                                                      biases,
200                                                      reason);
201             break;
202         }
203         case LayerType::FakeQuantization:
204         {
205             auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
206             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
207             result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
208                                                                      cLayer->GetParameters(),
209                                                                      reason);
210             break;
211         }
212         case LayerType::Floor:
213         {
214             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
215             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
216             result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
217                                                           OverrideDataType(output, dataType),
218                                                           reason);
219             break;
220         }
221         case LayerType::FullyConnected:
222         {
223             auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
224             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
225             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
226             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
227
228             TensorInfo biasInfo;
229             const TensorInfo * biasInfoPtr = nullptr;
230             static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
231             static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
232             static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
233
234             const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
235             if (descriptor.m_BiasEnabled)
236             {
237                 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
238                 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
239                 biasInfoPtr = &biasInfo;
240             }
241             else
242             {
243                 // If biases are not enabled pass a dummy tensorinfo for the validation
244                 switch(input.GetDataType())
245                 {
246                     case DataType::Float16:
247                     {
248                         biasInfoPtr = &dummyFloat16Bias;
249                         break;
250                     }
251                     case DataType::Float32:
252                     {
253                         biasInfoPtr = &dummyFloat32Bias;
254                         break;
255                     }
256                     case DataType::QuantisedAsymm8:
257                     {
258                         biasInfoPtr = &dummyQA8Bias;
259                         break;
260                     }
261                     default:
262                     {
263                         BOOST_ASSERT_MSG(false, "Unexpected bias type");
264                     }
265                 }
266             }
267
268             result = layerSupportObject->IsFullyConnectedSupported(
269                                                OverrideDataType(input, dataType),
270                                                OverrideDataType(output, dataType),
271                                                OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
272                                                *biasInfoPtr,
273                                                descriptor,
274                                                reason);
275             break;
276         }
277         case LayerType::Input:
278         {
279             const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
280             result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
281             break;
282         }
283         case LayerType::L2Normalization:
284         {
285             auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
286             const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
287
288             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
289             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
290
291             result = layerSupportObject->IsL2NormalizationSupported(
292                                                 OverrideDataType(input, dataType),
293                                                 OverrideDataType(output, dataType),
294                                                 descriptor,
295                                                 reason);
296             break;
297         }
298         case LayerType::Lstm:
299         {
300             auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
301             const LstmDescriptor& descriptor = cLayer->GetParameters();
302
303             // All inputs.
304             const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
305                                                        dataType);
306             const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
307                                                                dataType);
308             const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
309                                                              dataType);
310             // All outputs
311             const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
312             const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
313             const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
314             const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
315
316             // Basic parameters
317             const TensorInfo& inputToForgetWeights
318                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
319             const TensorInfo& inputToCellWeights
320                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
321             const TensorInfo& inputToOutputWeights
322                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
323             const TensorInfo& recurrentToForgetWeights
324                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
325             const TensorInfo& recurrentToCellWeights
326                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
327             const TensorInfo& recurrentToOutputWeights
328                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
329             const TensorInfo& forgetGateBias
330                     = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
331             const TensorInfo& cellBias
332                     = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
333             const TensorInfo& outputGateBias
334                     = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
335
336             // Optional parameters
337             const TensorInfo* inputToInputWeights = nullptr;
338             const TensorInfo* recurrentToInputWeights = nullptr;
339             const TensorInfo* cellToInputWeights = nullptr;
340             const TensorInfo* inputGateBias = nullptr;
341             const TensorInfo* projectionWeights = nullptr;
342             const TensorInfo* projectionBias = nullptr;
343             const TensorInfo* cellToForgetWeights = nullptr;
344             const TensorInfo* cellToOutputWeights = nullptr;
345
346             TensorInfo optInputToInputWeights;
347             TensorInfo optRecurrentToInputWeights;
348             TensorInfo optCellToInputWeights;
349             TensorInfo optInputGateBias;
350             TensorInfo optProjectionWeights;
351             TensorInfo optProjectionBias;
352             TensorInfo optCellToForgetWeights;
353             TensorInfo optCellToOutputWeights;
354
355             if(!descriptor.m_CifgEnabled)
356             {
357                 optInputToInputWeights =
358                     OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
359                 inputToInputWeights = &optInputToInputWeights;
360
361                 optRecurrentToInputWeights =
362                     OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
363                 recurrentToInputWeights = &optRecurrentToInputWeights;
364                 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
365                 {
366                     optCellToInputWeights =
367                         OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
368                     cellToInputWeights = &optCellToInputWeights;
369                 }
370                 optInputGateBias =
371                        OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
372                 inputGateBias = &optInputGateBias;
373             }
374
375             if(descriptor.m_ProjectionEnabled)
376             {
377                 optProjectionWeights =
378                     OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
379                 projectionWeights = &optProjectionWeights;
380                 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
381                 {
382                     optProjectionBias =
383                         OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
384                     projectionBias = &optProjectionBias;
385                 }
386             }
387
388             if(descriptor.m_PeepholeEnabled)
389             {
390                 optCellToForgetWeights =
391                     OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
392                 cellToForgetWeights = &optCellToForgetWeights;
393                 optCellToOutputWeights =
394                     OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
395                 cellToOutputWeights = &optCellToOutputWeights;
396             }
397
398             result = layerSupportObject->IsLstmSupported(
399                                      input,
400                                      outputStateIn,
401                                      cellStateIn,
402                                      scratchBuffer,
403                                      outputStateOut,
404                                      cellStateOut,
405                                      output,
406                                      descriptor,
407                                      inputToForgetWeights,
408                                      inputToCellWeights,
409                                      inputToOutputWeights,
410                                      recurrentToForgetWeights,
411                                      recurrentToCellWeights,
412                                      recurrentToOutputWeights,
413                                      forgetGateBias,
414                                      cellBias,
415                                      outputGateBias,
416                                      inputToInputWeights,
417                                      recurrentToInputWeights,
418                                      cellToInputWeights,
419                                      inputGateBias,
420                                      projectionWeights,
421                                      projectionBias,
422                                      cellToForgetWeights,
423                                      cellToOutputWeights,
424                                      reason);
425             break;
426         }
427         case LayerType::Merger:
428         {
429             auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
430
431             // Get vector of all inputs.
432             auto getTensorInfo = [&dataType](const InputSlot& slot)
433                 {
434                     return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
435                 };
436             auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
437             auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
438             std::vector<TensorInfo> inputs(beginI, endI);
439
440             auto getTensorInfoPtr = [](const TensorInfo& info)
441                 {
442                     return &info;
443                 };
444             auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
445             auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
446             std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
447
448             result = layerSupportObject->IsMergerSupported(inputPtrs, cLayer->GetParameters(), reason);
449             break;
450         }
451         case LayerType::Multiplication:
452         {
453             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
454             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
455             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
456             result = layerSupportObject->IsMultiplicationSupported(
457                                                OverrideDataType(input0, dataType),
458                                                OverrideDataType(input1, dataType),
459                                                OverrideDataType(output, dataType),
460                                                reason);
461             break;
462         }
463         case LayerType::Normalization:
464         {
465             auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
466             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
467             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
468             result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
469                                                                   OverrideDataType(output, dataType),
470                                                                   cLayer->GetParameters(),
471                                                                   reason);
472             break;
473         }
474         case LayerType::Output:
475         {
476             const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
477             result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
478             break;
479         }
480         case LayerType::Permute:
481         {
482             auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
483             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
484             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
485             result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
486                                                             OverrideDataType(output, dataType),
487                                                             cLayer->GetParameters(),
488                                                             reason);
489             break;
490         }
491         case LayerType::Pad:
492         {
493             auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
494             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
495             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
496             result = layerSupportObject->IsPadSupported(
497                                     OverrideDataType(input, dataType),
498                                     OverrideDataType(output, dataType),
499                                     cLayer->GetParameters(),
500                                     reason);
501             break;
502         }
503         case LayerType::Pooling2d:
504         {
505             auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
506             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
507             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
508             result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
509                                                               OverrideDataType(output, dataType),
510                                                               cLayer->GetParameters(),
511                                                               reason);
512             break;
513         }
514         case LayerType::Division:
515         {
516             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
517             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
518             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
519             result = layerSupportObject->IsDivisionSupported(
520                                          OverrideDataType(input0, dataType),
521                                          OverrideDataType(input1, dataType),
522                                          OverrideDataType(output, dataType),
523                                          reason);
524             break;
525         }
526         case LayerType::Reshape:
527         {
528             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
529             result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
530             break;
531         }
532         case LayerType::ResizeBilinear:
533         {
534             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
535             result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
536             break;
537         }
538         case LayerType::Softmax:
539         {
540             auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
541             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
542             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
543             result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
544                                                             OverrideDataType(output, dataType),
545                                                             cLayer->GetParameters(),
546                                                             reason);
547             break;
548         }
549         case LayerType::SpaceToBatchNd:
550         {
551             auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
552             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
553             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
554             result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
555                                                                    OverrideDataType(output, dataType),
556                                                                    cLayer->GetParameters(),
557                                                                    reason);
558             break;
559         }
560         case LayerType::Splitter:
561         {
562             auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
563             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
564             result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
565                                                              cLayer->GetParameters(),
566                                                              reason);
567             break;
568         }
569         case LayerType::Subtraction:
570         {
571             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
572             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
573             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
574             result = layerSupportObject->IsSubtractionSupported(
575                                             OverrideDataType(input0, dataType),
576                                             OverrideDataType(input1, dataType),
577                                             OverrideDataType(output, dataType),
578                                             reason);
579             break;
580         }
581         case LayerType::Mean:
582         {
583             auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
584             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
585             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
586             result = layerSupportObject->IsMeanSupported(
587                                      OverrideDataType(input, dataType),
588                                      OverrideDataType(output, dataType),
589                                      cLayer->GetParameters(),
590                                      reason);
591             break;
592         }
593         default:
594         {
595             BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
596             reason.value() = "Unrecognised layer type";
597             result = false;
598             break;
599         }
600     }
601     return result;
602 }
603
604 bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
605                                         Optional<DataType> dataType,
606                                         std::string& outReasonIfUnsupported)
607 {
608     auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
609     return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
610 }
611
612 }