MLCE-82 Add IsLayerSupported tests for MEAN
authorMatthew Bentham <matthew.bentham@arm.com>
Wed, 2 Jan 2019 13:26:31 +0000 (13:26 +0000)
committerMatthew Bentham <matthew.bentham@arm.com>
Thu, 10 Jan 2019 16:12:45 +0000 (16:12 +0000)
Change-Id: I43be451f490db0154021f47a2fd49d1269cf5b95

src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
src/backends/cl/test/ClLayerSupportTests.cpp
src/backends/neon/test/NeonLayerSupportTests.cpp

index d6528bb..78716ef 100644 (file)
@@ -592,4 +592,30 @@ bool IsConvertLayerSupportedTests(std::string& reasonIfUnsupported)
     return result;
 }
 
+template<typename FactoryType, armnn::DataType InputDataType , armnn::DataType OutputDataType>
+bool IsMeanLayerSupportedTests(std::string& reasonIfUnsupported)
+{
+    armnn::Graph graph;
+    static const std::vector<unsigned> axes = {1, 0};
+    armnn::MeanDescriptor desc(axes, false);
+
+    armnn::Layer* const layer = graph.AddLayer<armnn::MeanLayer>(desc, "LayerName");
+
+    armnn::Layer* const input = graph.AddLayer<armnn::InputLayer>(0, "input");
+    armnn::Layer* const output = graph.AddLayer<armnn::OutputLayer>(0, "output");
+
+    armnn::TensorInfo inputTensorInfo({4, 3, 2}, InputDataType);
+    armnn::TensorInfo outputTensorInfo({2}, OutputDataType);
+
+    input->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+    input->GetOutputHandler(0).SetTensorInfo(inputTensorInfo);
+    layer->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+    layer->GetOutputHandler(0).SetTensorInfo(outputTensorInfo);
+
+    bool result = FactoryType::IsLayerSupported(*layer, InputDataType, reasonIfUnsupported);
+
+    return result;
+}
+
+
 } //namespace
index acfd8c3..bcf057b 100644 (file)
@@ -7,6 +7,7 @@
 
 #include <layers/ConvertFp16ToFp32Layer.hpp>
 #include <layers/ConvertFp32ToFp16Layer.hpp>
+#include <layers/MeanLayer.hpp>
 #include <test/TensorHelpers.hpp>
 
 #include <backendsCommon/CpuTensorHandle.hpp>
@@ -106,4 +107,14 @@ BOOST_FIXTURE_TEST_CASE(IsConvertFp32ToFp16SupportedFp32OutputCl, ClContextContr
     BOOST_CHECK_EQUAL(reasonIfUnsupported, "Output should be Float16");
 }
 
+BOOST_FIXTURE_TEST_CASE(IsMeanSupportedCl, ClContextControlFixture)
+{
+    std::string reasonIfUnsupported;
+
+    bool result = IsMeanLayerSupportedTests<armnn::ClWorkloadFactory,
+      armnn::DataType::Float32, armnn::DataType::Float32>(reasonIfUnsupported);
+
+    BOOST_CHECK(result);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
index c6d2731..435afd2 100644 (file)
@@ -61,4 +61,14 @@ BOOST_AUTO_TEST_CASE(IsConvertFp32ToFp16SupportedNeon)
     BOOST_CHECK(result);
 }
 
+BOOST_AUTO_TEST_CASE(IsMeanSupportedNeon)
+{
+    std::string reasonIfUnsupported;
+
+    bool result = IsMeanLayerSupportedTests<armnn::NeonWorkloadFactory,
+      armnn::DataType::Float32, armnn::DataType::Float32>(reasonIfUnsupported);
+
+    BOOST_CHECK(result);
+}
+
 BOOST_AUTO_TEST_SUITE_END()