IVGCVSW-2093 Add SpaceToBatchNd layer and corresponding no-op factory implementations
[platform/upstream/armnn.git] / src / backends / WorkloadFactory.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <backends/WorkloadFactory.hpp>
6 #include <backends/LayerSupportRegistry.hpp>
7
8 #include <armnn/Types.hpp>
9 #include <armnn/LayerSupport.hpp>
10 #include <Layer.hpp>
11 #include <LayersFwd.hpp>
12 #include "CpuTensorHandle.hpp"
13
14 #include <boost/cast.hpp>
15 #include <cstring>
16 #include <boost/iterator/transform_iterator.hpp>
17
18 namespace armnn
19 {
20
21 namespace
22 {
23
24 const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
25 {
26     if (!type)
27     {
28         return info;
29     }
30
31     return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
32 }
33
34 Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
35 {
36     if (!weightsType)
37     {
38         return weightsType;
39     }
40
41     switch(weightsType.value())
42     {
43         case DataType::Float16:
44         case DataType::Float32:
45             return weightsType;
46         case DataType::QuantisedAsymm8:
47             return DataType::Signed32;
48         default:
49             BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
50     }
51     return EmptyOptional();
52 }
53
54 } // anonymous namespace
55
56 bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
57                                         const IConnectableLayer& connectableLayer,
58                                         Optional<DataType> dataType,
59                                         std::string& outReasonIfUnsupported)
60 {
61     Optional<std::string&> reason = outReasonIfUnsupported;
62     bool result;
63     const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
64
65     auto const& layerSupportRegistry = LayerSupportRegistryInstance();
66     auto layerSupportFactory = layerSupportRegistry.GetFactory(backendId);
67     auto layerSupportObject = layerSupportFactory(EmptyInitializer());
68
69     switch(layer.GetType())
70     {
71         case LayerType::Activation:
72         {
73             auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
74             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
75             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
76             result = layerSupportObject->IsActivationSupported(
77                                            OverrideDataType(input, dataType),
78                                            OverrideDataType(output, dataType),
79                                            cLayer->GetParameters(),
80                                            reason);
81             break;
82         }
83         case LayerType::Addition:
84         {
85             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
88             result = layerSupportObject->IsAdditionSupported(
89                                         OverrideDataType(input0, dataType),
90                                         OverrideDataType(input1, dataType),
91                                         OverrideDataType(output, dataType),
92                                         reason);
93             break;
94         }
95         case LayerType::BatchNormalization:
96         {
97             auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
98             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
99             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
100             const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
101             const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
102             const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
103             const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
104             result = layerSupportObject->IsBatchNormalizationSupported(
105                                                    OverrideDataType(input, dataType),
106                                                    OverrideDataType(output, dataType),
107                                                    OverrideDataType(mean, dataType),
108                                                    OverrideDataType(var, dataType),
109                                                    OverrideDataType(beta, dataType),
110                                                    OverrideDataType(gamma, dataType),
111                                                    cLayer->GetParameters(),
112                                                    reason);
113             break;
114         }
115         case LayerType::Constant:
116         {
117             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118             result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
119             break;
120         }
121         case LayerType::ConvertFp16ToFp32:
122         {
123             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
124             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
125             result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
126             break;
127         }
128         case LayerType::ConvertFp32ToFp16:
129         {
130             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
131             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
132             result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
133             break;
134         }
135         case LayerType::Convolution2d:
136         {
137             auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
138
139             const TensorInfo input  = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
140                                                        dataType);
141             const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
142             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
143
144             const Convolution2dDescriptor& descriptor  = cLayer->GetParameters();
145
146             // Construct optional biases object based on the value of m_BiasEnabled
147             Optional<TensorInfo> biases;
148             if (descriptor.m_BiasEnabled)
149             {
150                 biases =
151                     OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
152             }
153
154             result = layerSupportObject->IsConvolution2dSupported(
155                                               input,
156                                               output,
157                                               descriptor,
158                                               OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
159                                               biases,
160                                               reason);
161             break;
162         }
163         case LayerType::MemCopy:
164         {
165             // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
166             // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
167             result = backendId == Compute::CpuRef || backendId == Compute::Undefined
168                 || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
169             reason.value() = "Unsupported backend type";
170             break;
171         }
172         case LayerType::DepthwiseConvolution2d:
173         {
174             auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
175             const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
176                                                        dataType);
177             const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
178             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
179
180             const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
181
182             // Construct optional biases object based on the value of m_BiasEnabled
183             Optional<TensorInfo> biases;
184             if (descriptor.m_BiasEnabled)
185             {
186                 biases =
187                     OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
188             }
189
190             result = layerSupportObject->IsDepthwiseConvolutionSupported(
191                                                      input,
192                                                      output,
193                                                      descriptor,
194                                                      OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
195                                                      biases,
196                                                      reason);
197             break;
198         }
199         case LayerType::FakeQuantization:
200         {
201             auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
202             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
203             result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
204                                                                      cLayer->GetParameters(),
205                                                                      reason);
206             break;
207         }
208         case LayerType::Floor:
209         {
210             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
211             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
212             result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
213                                                           OverrideDataType(output, dataType),
214                                                           reason);
215             break;
216         }
217         case LayerType::FullyConnected:
218         {
219             auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
220             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
221             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
222             BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
223
224             TensorInfo biasInfo;
225             const TensorInfo * biasInfoPtr = nullptr;
226             static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
227             static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
228             static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
229
230             const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
231             if (descriptor.m_BiasEnabled)
232             {
233                 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
234                 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
235                 biasInfoPtr = &biasInfo;
236             }
237             else
238             {
239                 // If biases are not enabled pass a dummy tensorinfo for the validation
240                 switch(input.GetDataType())
241                 {
242                     case DataType::Float16:
243                     {
244                         biasInfoPtr = &dummyFloat16Bias;
245                         break;
246                     }
247                     case DataType::Float32:
248                     {
249                         biasInfoPtr = &dummyFloat32Bias;
250                         break;
251                     }
252                     case DataType::QuantisedAsymm8:
253                     {
254                         biasInfoPtr = &dummyQA8Bias;
255                         break;
256                     }
257                     default:
258                     {
259                         BOOST_ASSERT_MSG(false, "Unexpected bias type");
260                     }
261                 }
262             }
263
264             result = layerSupportObject->IsFullyConnectedSupported(
265                                                OverrideDataType(input, dataType),
266                                                OverrideDataType(output, dataType),
267                                                OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
268                                                *biasInfoPtr,
269                                                descriptor,
270                                                reason);
271             break;
272         }
273         case LayerType::Input:
274         {
275             const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
276             result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
277             break;
278         }
279         case LayerType::L2Normalization:
280         {
281             auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
282             const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
283
284             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
285             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
286
287             result = layerSupportObject->IsL2NormalizationSupported(
288                                                 OverrideDataType(input, dataType),
289                                                 OverrideDataType(output, dataType),
290                                                 descriptor,
291                                                 reason);
292             break;
293         }
294         case LayerType::Lstm:
295         {
296             auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
297             const LstmDescriptor& descriptor = cLayer->GetParameters();
298
299             // All inputs.
300             const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
301                                                        dataType);
302             const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
303                                                                dataType);
304             const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
305                                                              dataType);
306             // All outputs
307             const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
308             const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
309             const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
310             const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
311
312             // Basic parameters
313             const TensorInfo& inputToForgetWeights
314                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
315             const TensorInfo& inputToCellWeights
316                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
317             const TensorInfo& inputToOutputWeights
318                     = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
319             const TensorInfo& recurrentToForgetWeights
320                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
321             const TensorInfo& recurrentToCellWeights
322                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
323             const TensorInfo& recurrentToOutputWeights
324                     = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
325             const TensorInfo& forgetGateBias
326                     = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
327             const TensorInfo& cellBias
328                     = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
329             const TensorInfo& outputGateBias
330                     = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
331
332             // Optional parameters
333             const TensorInfo* inputToInputWeights = nullptr;
334             const TensorInfo* recurrentToInputWeights = nullptr;
335             const TensorInfo* cellToInputWeights = nullptr;
336             const TensorInfo* inputGateBias = nullptr;
337             const TensorInfo* projectionWeights = nullptr;
338             const TensorInfo* projectionBias = nullptr;
339             const TensorInfo* cellToForgetWeights = nullptr;
340             const TensorInfo* cellToOutputWeights = nullptr;
341
342             TensorInfo optInputToInputWeights;
343             TensorInfo optRecurrentToInputWeights;
344             TensorInfo optCellToInputWeights;
345             TensorInfo optInputGateBias;
346             TensorInfo optProjectionWeights;
347             TensorInfo optProjectionBias;
348             TensorInfo optCellToForgetWeights;
349             TensorInfo optCellToOutputWeights;
350
351             if(!descriptor.m_CifgEnabled)
352             {
353                 optInputToInputWeights =
354                     OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
355                 inputToInputWeights = &optInputToInputWeights;
356
357                 optRecurrentToInputWeights =
358                     OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
359                 recurrentToInputWeights = &optRecurrentToInputWeights;
360                 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
361                 {
362                     optCellToInputWeights =
363                         OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
364                     cellToInputWeights = &optCellToInputWeights;
365                 }
366                 optInputGateBias =
367                        OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
368                 inputGateBias = &optInputGateBias;
369             }
370
371             if(descriptor.m_ProjectionEnabled)
372             {
373                 optProjectionWeights =
374                     OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
375                 projectionWeights = &optProjectionWeights;
376                 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
377                 {
378                     optProjectionBias =
379                         OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
380                     projectionBias = &optProjectionBias;
381                 }
382             }
383
384             if(descriptor.m_PeepholeEnabled)
385             {
386                 optCellToForgetWeights =
387                     OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
388                 cellToForgetWeights = &optCellToForgetWeights;
389                 optCellToOutputWeights =
390                     OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
391                 cellToOutputWeights = &optCellToOutputWeights;
392             }
393
394             result = layerSupportObject->IsLstmSupported(
395                                      input,
396                                      outputStateIn,
397                                      cellStateIn,
398                                      scratchBuffer,
399                                      outputStateOut,
400                                      cellStateOut,
401                                      output,
402                                      descriptor,
403                                      inputToForgetWeights,
404                                      inputToCellWeights,
405                                      inputToOutputWeights,
406                                      recurrentToForgetWeights,
407                                      recurrentToCellWeights,
408                                      recurrentToOutputWeights,
409                                      forgetGateBias,
410                                      cellBias,
411                                      outputGateBias,
412                                      inputToInputWeights,
413                                      recurrentToInputWeights,
414                                      cellToInputWeights,
415                                      inputGateBias,
416                                      projectionWeights,
417                                      projectionBias,
418                                      cellToForgetWeights,
419                                      cellToOutputWeights,
420                                      reason);
421             break;
422         }
423         case LayerType::Merger:
424         {
425             auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
426
427             // Get vector of all inputs.
428             auto getTensorInfo = [&dataType](const InputSlot& slot)
429                 {
430                     return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
431                 };
432             auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
433             auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
434             std::vector<TensorInfo> inputs(beginI, endI);
435
436             auto getTensorInfoPtr = [](const TensorInfo& info)
437                 {
438                     return &info;
439                 };
440             auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
441             auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
442             std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
443
444             result = layerSupportObject->IsMergerSupported(inputPtrs, cLayer->GetParameters(), reason);
445             break;
446         }
447         case LayerType::Multiplication:
448         {
449             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
450             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
451             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
452             result = layerSupportObject->IsMultiplicationSupported(
453                                                OverrideDataType(input0, dataType),
454                                                OverrideDataType(input1, dataType),
455                                                OverrideDataType(output, dataType),
456                                                reason);
457             break;
458         }
459         case LayerType::Normalization:
460         {
461             auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
462             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
463             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
464             result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
465                                                                   OverrideDataType(output, dataType),
466                                                                   cLayer->GetParameters(),
467                                                                   reason);
468             break;
469         }
470         case LayerType::Output:
471         {
472             const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
473             result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
474             break;
475         }
476         case LayerType::Permute:
477         {
478             auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
479             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
480             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
481             result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
482                                                             OverrideDataType(output, dataType),
483                                                             cLayer->GetParameters(),
484                                                             reason);
485             break;
486         }
487         case LayerType::Pad:
488         {
489             auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
490             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
491             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
492             result = layerSupportObject->IsPadSupported(
493                                     OverrideDataType(input, dataType),
494                                     OverrideDataType(output, dataType),
495                                     cLayer->GetParameters(),
496                                     reason);
497             break;
498         }
499         case LayerType::Pooling2d:
500         {
501             auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
502             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
503             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
504             result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
505                                                               OverrideDataType(output, dataType),
506                                                               cLayer->GetParameters(),
507                                                               reason);
508             break;
509         }
510         case LayerType::Division:
511         {
512             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
513             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
514             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
515             result = layerSupportObject->IsDivisionSupported(
516                                          OverrideDataType(input0, dataType),
517                                          OverrideDataType(input1, dataType),
518                                          OverrideDataType(output, dataType),
519                                          reason);
520             break;
521         }
522         case LayerType::Reshape:
523         {
524             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
525             result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
526             break;
527         }
528         case LayerType::ResizeBilinear:
529         {
530             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
531             result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
532             break;
533         }
534         case LayerType::Softmax:
535         {
536             auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
537             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
538             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
539             result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
540                                                             OverrideDataType(output, dataType),
541                                                             cLayer->GetParameters(),
542                                                             reason);
543             break;
544         }
545         case LayerType::SpaceToBatchNd:
546         {
547             auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
548             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
549             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
550             result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
551                                                                    OverrideDataType(output, dataType),
552                                                                    cLayer->GetParameters(),
553                                                                    reason);
554             break;
555         }
556         case LayerType::Splitter:
557         {
558             auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
559             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
560             result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
561                                                              cLayer->GetParameters(),
562                                                              reason);
563             break;
564         }
565         case LayerType::Subtraction:
566         {
567             const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
568             const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
569             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
570             result = layerSupportObject->IsSubtractionSupported(
571                                             OverrideDataType(input0, dataType),
572                                             OverrideDataType(input1, dataType),
573                                             OverrideDataType(output, dataType),
574                                             reason);
575             break;
576         }
577         case LayerType::Mean:
578         {
579             auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
580             const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
581             const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
582             result = layerSupportObject->IsMeanSupported(
583                                      OverrideDataType(input, dataType),
584                                      OverrideDataType(output, dataType),
585                                      cLayer->GetParameters(),
586                                      reason);
587             break;
588         }
589         default:
590         {
591             BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
592             reason.value() = "Unrecognised layer type";
593             result = false;
594             break;
595         }
596     }
597     return result;
598 }
599
600 bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
601                                         Optional<DataType> dataType,
602                                         std::string& outReasonIfUnsupported)
603 {
604     auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
605     return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
606 }
607
608 }