IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / armnn / test / UnitTests.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <Logging.hpp>
8 #include <armnn/Utils.hpp>
9 #include <reference/RefWorkloadFactory.hpp>
10 #include <backendsCommon/test/LayerTests.hpp>
11 #include "TensorHelpers.hpp"
12 #include <boost/test/unit_test.hpp>
13
14 inline void ConfigureLoggingTest()
15 {
16     // Configures logging for both the ARMNN library and this test program.
17     armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
18     armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, armnn::LogSeverity::Fatal);
19 }
20
21 // The following macros require the caller to have defined FactoryType, with one of the following using statements:
22 //
23 //      using FactoryType = armnn::RefWorkloadFactory;
24 //      using FactoryType = armnn::ClWorkloadFactory;
25 //      using FactoryType = armnn::NeonWorkloadFactory;
26
27 /// Executes BOOST_TEST on CompareTensors() return value so that the predicate_result message is reported.
28 /// If the test reports itself as not supported then the tensors are not compared.
29 /// Additionally this checks that the supportedness reported by the test matches the name of the test.
30 /// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
31 /// This is useful because it clarifies that the feature being tested is not actually supported
32 /// (a passed test with the name of a feature would imply that feature was supported).
33 /// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED.
34 /// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED.
35 template <typename T, std::size_t n>
36 void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
37 {
38     bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
39     BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.supported,
40         "The test name does not match the supportedness it is reporting");
41     if (testResult.supported)
42     {
43         BOOST_TEST(CompareTensors(testResult.output, testResult.outputExpected));
44     }
45 }
46
47 template <typename T, std::size_t n>
48 void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult)
49 {
50     bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
51     for (unsigned int i = 0; i < testResult.size(); ++i)
52     {
53         BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].supported,
54             "The test name does not match the supportedness it is reporting");
55         if (testResult[i].supported)
56         {
57             BOOST_TEST(CompareTensors(testResult[i].output, testResult[i].outputExpected));
58         }
59     }
60 }
61
62 template<typename FactoryType, typename TFuncPtr, typename... Args>
63 void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
64 {
65     std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
66     armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
67
68     FactoryType workloadFactory;
69     auto testResult = (*testFunction)(workloadFactory, args...);
70     CompareTestResultIfSupported(testName, testResult);
71 }
72
/// Declares a Boost auto test case called TestName which runs TestFunction through
/// RunTestFunction<FactoryType>(). Any additional macro arguments are passed on to the test
/// function. FactoryType must have been defined by the caller before the macro is used.
/// The stringified TestName is what gets checked for the "UNSUPPORTED" tag.
#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...)                        \
    BOOST_AUTO_TEST_CASE(TestName)                                               \
    {                                                                            \
        RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);   \
    }
78
79 template<typename FactoryType, typename TFuncPtr, typename... Args>
80 void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
81 {
82     FactoryType workloadFactory;
83     armnn::RefWorkloadFactory refWorkloadFactory;
84     auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...);
85     CompareTestResultIfSupported(testName, testResult);
86 }
87
/// Declares a Boost auto test case called TestName which runs TestFunction through
/// CompareRefTestFunction<FactoryType>(), i.e. with both a FactoryType instance and a
/// RefWorkloadFactory. Any additional macro arguments are passed on to the test function.
/// FactoryType must have been defined by the caller before the macro is used.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...)                   \
    BOOST_AUTO_TEST_CASE(TestName)                                                      \
    {                                                                                   \
        CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);   \
    }
93
/// Same as ARMNN_COMPARE_REF_AUTO_TEST_CASE, except that the test case is declared with
/// BOOST_FIXTURE_TEST_CASE so that it runs with the given Fixture type.
/// FactoryType must have been defined by the caller before the macro is used.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...)       \
    BOOST_FIXTURE_TEST_CASE(TestName, Fixture)                                          \
    {                                                                                   \
        CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);   \
    }