2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
8 #include <armnn/Utils.hpp>
9 #include <reference/RefWorkloadFactory.hpp>
10 #include <backendsCommon/test/LayerTests.hpp>
11 #include "TensorHelpers.hpp"
12 #include <boost/test/unit_test.hpp>
14 inline void ConfigureLoggingTest()
16 // Configures logging for both the ARMNN library and this test program.
17 armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
18 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, armnn::LogSeverity::Fatal);
// The macros below require the including test file to define FactoryType beforehand,
// typically with one of the following type aliases:
//
//     using FactoryType = armnn::RefWorkloadFactory;
//     using FactoryType = armnn::ClWorkloadFactory;
//     using FactoryType = armnn::NeonWorkloadFactory;
27 /// Executes BOOST_TEST on CompareTensors() return value so that the predicate_result message is reported.
28 /// If the test reports itself as not supported then the tensors are not compared.
29 /// Additionally this checks that the supportedness reported by the test matches the name of the test.
30 /// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
31 /// This is useful because it clarifies that the feature being tested is not actually supported
32 /// (a passed test with the name of a feature would imply that feature was supported).
33 /// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED.
34 /// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED.
35 template <typename T, std::size_t n>
36 void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
38 bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
39 BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.supported,
40 "The test name does not match the supportedness it is reporting");
41 if (testResult.supported)
43 BOOST_TEST(CompareTensors(testResult.output, testResult.outputExpected));
47 template <typename T, std::size_t n>
48 void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult)
50 bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
51 for (unsigned int i = 0; i < testResult.size(); ++i)
53 BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].supported,
54 "The test name does not match the supportedness it is reporting");
55 if (testResult[i].supported)
57 BOOST_TEST(CompareTensors(testResult[i].output, testResult[i].outputExpected));
62 template<typename FactoryType, typename TFuncPtr, typename... Args>
63 void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
65 std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
66 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
68 FactoryType workloadFactory;
69 auto testResult = (*testFunction)(workloadFactory, args...);
70 CompareTestResultIfSupported(testName, testResult);
73 #define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \
74 BOOST_AUTO_TEST_CASE(TestName) \
76 RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
79 template<typename FactoryType, typename TFuncPtr, typename... Args>
80 void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
82 FactoryType workloadFactory;
83 armnn::RefWorkloadFactory refWorkloadFactory;
84 auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...);
85 CompareTestResultIfSupported(testName, testResult);
88 #define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
89 BOOST_AUTO_TEST_CASE(TestName) \
91 CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
94 #define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
95 BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
97 CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \