IVGCVSW-3173 Extend reference softmax workload to support qsymm16
[platform/upstream/armnn.git] / src / backends / reference / RefLayerSupport.cpp
index a9cddfd..e4bc9bf 100644 (file)
@@ -1060,11 +1060,24 @@ bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
                                          Optional<std::string&> reasonIfUnsupported) const
 {
     ignore_unused(output);
-    ignore_unused(descriptor);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &TrueFunc<>);
+    bool supported = true;
+    std::array<DataType,3> supportedTypes =
+    {
+            DataType::Float32,
+            DataType::QuantisedAsymm8,
+            DataType::QuantisedSymm16
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference concatenation: output type not supported");
+
+    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+                                  "Reference concatenation: input type not supported");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference concatenation: input type not supported");
+
+    return supported;
 }
 
 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,