src/armnn/backends/WorkloadFactory.cpp (armnn, Release 18.08)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#include "WorkloadFactory.hpp"
#include "RefWorkloadFactory.hpp"
#include "NeonWorkloadFactory.hpp"
#include "ClWorkloadFactory.hpp"

#include "armnn/Types.hpp"
#include "armnn/LayerSupport.hpp"
#include "Layer.hpp"
#include "LayersFwd.hpp"
#include "CpuTensorHandle.hpp"

#include <boost/cast.hpp>
#include <boost/optional.hpp> // boost::optional is used directly in this file.
#include <boost/iterator/transform_iterator.hpp>
#include <cstring>

namespace armnn
{

namespace
{
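    // Returns a copy of the given TensorInfo with its data type replaced by the
    // override type, or the original info unchanged when no override is set.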
    const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type)
    {
        if (type == boost::none)
        {
            return info;
        }

        return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset());
    }

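    // Maps a weights data type to the bias data type the backends expect:
    // float weights take float biases, while QuantisedAsymm8 weights take
    // Signed32 biases (i.e. uint8 weights pair with int32 bias accumulators).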
    boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType)
    {
        if (weightsType == boost::none)
        {
            return weightsType;
        }

        switch(weightsType.get())
        {
            case DataType::Float16:
            case DataType::Float32:
                return weightsType;
            case DataType::QuantisedAsymm8:
                return DataType::Signed32;
            default:
                BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
        }
        return boost::none;
    }
}

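// Checks whether the given layer can run on the requested compute device, writing a
// human-readable reason into outReasonIfUnsupported when it cannot. When dataType is
// set, every tensor is validated as if it had that type, which lets callers probe
// support for, say, an FP16 variant of an FP32 graph.
//
// A minimal usage sketch (assuming 'layer' refers to a layer of a loaded graph):
//
//     std::string reason;
//     if (!IWorkloadFactory::IsLayerSupported(Compute::CpuAcc, layer, boost::none, reason))
//     {
//         // Fall back to another backend, or surface 'reason' to the caller.
//     }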
bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boost::optional<DataType> dataType,
    std::string& outReasonIfUnsupported)
{
    constexpr size_t reasonCapacity = 1024;
    // Zero-initialise the buffer: it is copied into outReasonIfUnsupported below even
    // when the backend reports the layer as supported and never writes a reason.
    char reason[reasonCapacity] = "";
    bool result;
    switch(layer.GetType())
    {
        case LayerType::Activation:
        {
            auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsActivationSupported(compute,
                                           OverrideDataType(input, dataType),
                                           OverrideDataType(output, dataType),
                                           cLayer->GetParameters(),
                                           reason,
                                           reasonCapacity);
            break;
        }
        case LayerType::Addition:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsAdditionSupported(compute,
                                         OverrideDataType(input0, dataType),
                                         OverrideDataType(input1, dataType),
                                         OverrideDataType(output, dataType),
                                         reason,
                                         reasonCapacity);
            break;
        }
        case LayerType::BatchNormalization:
        {
            auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
            const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
            const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
            const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
            result = IsBatchNormalizationSupported(compute,
                                                   OverrideDataType(input, dataType),
                                                   OverrideDataType(output, dataType),
                                                   OverrideDataType(mean, dataType),
                                                   OverrideDataType(var, dataType),
                                                   OverrideDataType(beta, dataType),
                                                   OverrideDataType(gamma, dataType),
                                                   cLayer->GetParameters(),
                                                   reason, reasonCapacity);
            break;
        }
        case LayerType::Constant:
        {
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsConstantSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
            break;
        }
        case LayerType::ConvertFp16ToFp32:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsConvertFp16ToFp32Supported(compute, input, output, reason, reasonCapacity);
            break;
        }
        case LayerType::ConvertFp32ToFp16:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsConvertFp32ToFp16Supported(compute, input, output, reason, reasonCapacity);
            break;
        }
        case LayerType::Convolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
            const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType);
            const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            TensorInfo biasInfo;
            const TensorInfo* biasInfoPtr = nullptr;
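            // The support checks below always expect a bias TensorInfo, so static
            // dummies of the matching bias data type stand in when no bias is used.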
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const Convolution2dDescriptor& descriptor = cLayer->GetParameters();

            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled, pass a dummy TensorInfo for the validation.
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected input type");
                    }
                }
            }

            result = IsConvolution2dSupported(compute,
                                              input,
                                              output,
                                              descriptor,
                                              OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                              *biasInfoPtr,
                                              reason,
                                              reasonCapacity);
            break;
        }
        case LayerType::MemCopy:
        {
            // MemCopy is supported by the CpuRef, CpuAcc and GpuAcc backends
            // (Undefined is also treated as CpuRef to avoid breaking lots of unit tests).
            result = compute == Compute::CpuRef || compute == Compute::Undefined
                || compute == Compute::CpuAcc || compute == Compute::GpuAcc;
            if (!result)
            {
                strcpy(reason, "Unsupported backend type");
            }
            break;
        }
        case LayerType::DepthwiseConvolution2d:
        {
            auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            TensorInfo biasInfo;
            const TensorInfo* biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled, pass a dummy TensorInfo for the validation.
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected input type");
                    }
                }
            }

            result = IsDepthwiseConvolutionSupported(compute,
                                                     input,
                                                     output,
                                                     descriptor,
                                                     OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                                     *biasInfoPtr,
                                                     reason,
                                                     reasonCapacity);
            break;
        }
        case LayerType::FakeQuantization:
        {
            auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = IsFakeQuantizationSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(),
                                                 reason, reasonCapacity);
            break;
        }
        case LayerType::Floor:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsFloorSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
                                      reason, reasonCapacity);
            break;
        }
        case LayerType::FullyConnected:
        {
            auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);

            TensorInfo biasInfo;
            const TensorInfo* biasInfoPtr = nullptr;
            static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
            static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
            static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);

            const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
            if (descriptor.m_BiasEnabled)
            {
                BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
                biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
                biasInfoPtr = &biasInfo;
            }
            else
            {
                // If biases are not enabled, pass a dummy TensorInfo for the validation.
                switch(input.GetDataType())
                {
                    case DataType::Float16:
                    {
                        biasInfoPtr = &dummyFloat16Bias;
                        break;
                    }
                    case DataType::Float32:
                    {
                        biasInfoPtr = &dummyFloat32Bias;
                        break;
                    }
                    case DataType::QuantisedAsymm8:
                    {
                        biasInfoPtr = &dummyQA8Bias;
                        break;
                    }
                    default:
                    {
                        BOOST_ASSERT_MSG(false, "Unexpected input type");
                    }
                }
            }

            result = IsFullyConnectedSupported(compute,
                                               OverrideDataType(input, dataType),
                                               OverrideDataType(output, dataType),
                                               OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
                                               *biasInfoPtr,
                                               descriptor,
                                               reason,
                                               reasonCapacity);
            break;
        }
        case LayerType::Input:
        {
            const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsInputSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
            break;
        }
        case LayerType::L2Normalization:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsL2NormalizationSupported(compute, OverrideDataType(input, dataType),
                                                OverrideDataType(output, dataType), reason, reasonCapacity);
            break;
        }
        case LayerType::Lstm:
        {
            auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
            const LstmDescriptor& descriptor = cLayer->GetParameters();

            // All inputs.
            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
                                                       dataType);
            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
                                                               dataType);
            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
                                                             dataType);
            // All outputs.
            const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
            const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
            const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);

            // Basic parameters.
            const TensorInfo& inputToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& inputToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToForgetWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToCellWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
            const TensorInfo& recurrentToOutputWeights
                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
            const TensorInfo& forgetGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
            const TensorInfo& cellBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
            const TensorInfo& outputGateBias
                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);

            // Optional parameters.
            const TensorInfo* inputToInputWeights = nullptr;
            const TensorInfo* recurrentToInputWeights = nullptr;
            const TensorInfo* cellToInputWeights = nullptr;
            const TensorInfo* inputGateBias = nullptr;
            const TensorInfo* projectionWeights = nullptr;
            const TensorInfo* projectionBias = nullptr;
            const TensorInfo* cellToForgetWeights = nullptr;
            const TensorInfo* cellToOutputWeights = nullptr;

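            // Local storage for the optional tensors; the pointers above are only
            // redirected here when the corresponding LSTM feature is enabled.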
            TensorInfo optInputToInputWeights;
            TensorInfo optRecurrentToInputWeights;
            TensorInfo optCellToInputWeights;
            TensorInfo optInputGateBias;
            TensorInfo optProjectionWeights;
            TensorInfo optProjectionBias;
            TensorInfo optCellToForgetWeights;
            TensorInfo optCellToOutputWeights;

            if (!descriptor.m_CifgEnabled)
            {
                optInputToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
                inputToInputWeights = &optInputToInputWeights;

                optRecurrentToInputWeights =
                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
                recurrentToInputWeights = &optRecurrentToInputWeights;
                if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                {
                    optCellToInputWeights =
                        OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
                    cellToInputWeights = &optCellToInputWeights;
                }
                optInputGateBias =
                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
                inputGateBias = &optInputGateBias;
            }

            if (descriptor.m_ProjectionEnabled)
            {
                optProjectionWeights =
                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
                projectionWeights = &optProjectionWeights;
                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                {
                    optProjectionBias =
                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
                    projectionBias = &optProjectionBias;
                }
            }

            if (descriptor.m_PeepholeEnabled)
            {
                optCellToForgetWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
                cellToForgetWeights = &optCellToForgetWeights;
                optCellToOutputWeights =
                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
                cellToOutputWeights = &optCellToOutputWeights;
            }

            result = IsLstmSupported(compute,
                                     input,
                                     outputStateIn,
                                     cellStateIn,
                                     scratchBuffer,
                                     outputStateOut,
                                     cellStateOut,
                                     output,
                                     descriptor,
                                     inputToForgetWeights,
                                     inputToCellWeights,
                                     inputToOutputWeights,
                                     recurrentToForgetWeights,
                                     recurrentToCellWeights,
                                     recurrentToOutputWeights,
                                     forgetGateBias,
                                     cellBias,
                                     outputGateBias,
                                     inputToInputWeights,
                                     recurrentToInputWeights,
                                     cellToInputWeights,
                                     inputGateBias,
                                     projectionWeights,
                                     projectionBias,
                                     cellToForgetWeights,
                                     cellToOutputWeights,
                                     reason,
                                     reasonCapacity);
            break;
        }
        case LayerType::Merger:
        {
            auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);

            // Get vector of all inputs.
            auto getTensorInfo = [&dataType](const InputSlot& slot)
                {
                    return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
                };
            auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
            auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
            std::vector<TensorInfo> inputs(beginI, endI);

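            // Collect raw pointers to the elements of 'inputs'; these pointers are
            // only valid for as long as the 'inputs' vector is alive.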
            auto getTensorInfoPtr = [](const TensorInfo& info)
                {
                    return &info;
                };
            auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
            auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
            std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);

            result = IsMergerSupported(compute, inputPtrs, cLayer->GetParameters(), reason, reasonCapacity);
            break;
        }
        case LayerType::Multiplication:
        {
            const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsMultiplicationSupported(compute,
                                               OverrideDataType(input0, dataType),
                                               OverrideDataType(input1, dataType),
                                               OverrideDataType(output, dataType),
                                               reason,
                                               reasonCapacity);
            break;
        }
        case LayerType::Normalization:
        {
            auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsNormalizationSupported(compute, OverrideDataType(input, dataType),
                                              OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
                                              reasonCapacity);
            break;
        }
        case LayerType::Output:
        {
            const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = IsOutputSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
            break;
        }
        case LayerType::Permute:
        {
            auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsPermuteSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
                                        cLayer->GetParameters(), reason, reasonCapacity);
            break;
        }
        case LayerType::Pooling2d:
        {
            auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsPooling2dSupported(compute, OverrideDataType(input, dataType),
                                          OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
                                          reasonCapacity);
            break;
        }
        case LayerType::Reshape:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = IsReshapeSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
            break;
        }
        case LayerType::ResizeBilinear:
        {
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = IsResizeBilinearSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
            break;
        }
        case LayerType::Softmax:
        {
            auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
            result = IsSoftmaxSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
                                        cLayer->GetParameters(), reason, reasonCapacity);
            break;
        }
        case LayerType::Splitter:
        {
            auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
            const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
            result = IsSplitterSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(), reason,
                                         reasonCapacity);
            break;
        }
        default:
        {
            BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
            strcpy(reason, "Unrecognised layer type");
            result = false;
            break;
        }
    }
    outReasonIfUnsupported = reason;
    return result;
}

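// Convenience overload that queries support on the compute device the layer is
// currently assigned to.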
bool IWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
                                        std::string& outReasonIfUnsupported)
{
    return IsLayerSupported(layer.GetComputeDevice(), layer, dataType, outReasonIfUnsupported);
}

} // namespace armnn