IVGCVSW-3429 Add a utility Version class
[platform/upstream/armnn.git] / src / profiling / test / ProfilingTests.cpp
index 26cbfd7..8a2f2bd 100644 (file)
@@ -5,6 +5,8 @@
 
 #include "../CommandHandlerKey.hpp"
 #include "../CommandHandlerFunctor.hpp"
+#include "../CommandHandlerRegistry.hpp"
+#include "../EncodeVersion.hpp"
 #include "../Packet.hpp"
 
 #include <cstdint>
@@ -61,6 +63,36 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
     BOOST_CHECK(vect == expectedVect);
 }
 
+BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
+{
+    mlutil::Version version1(12);
+
+    BOOST_CHECK(version1.GetMajor() == 0);
+    BOOST_CHECK(version1.GetMinor() == 0);
+    BOOST_CHECK(version1.GetPatch() == 12);
+
+    mlutil::Version version2(4108);
+
+    BOOST_CHECK(version2.GetMajor() == 0);
+    BOOST_CHECK(version2.GetMinor() == 1);
+    BOOST_CHECK(version2.GetPatch() == 12);
+
+    mlutil::Version version3(4198412);
+
+    BOOST_CHECK(version3.GetMajor() == 1);
+    BOOST_CHECK(version3.GetMinor() == 1);
+    BOOST_CHECK(version3.GetPatch() == 12);
+
+    mlutil::Version version4(0);
+
+    BOOST_CHECK(version4.GetMajor() == 0);
+    BOOST_CHECK(version4.GetMinor() == 0);
+    BOOST_CHECK(version4.GetPatch() == 0);
+
+    mlutil::Version version5(1,0,0);
+    BOOST_CHECK(version5.GetEncodedValue() == 4194304);
+}
+
 BOOST_AUTO_TEST_CASE(CheckPacketClass)
 {
     const char* data = "test";
@@ -80,35 +112,35 @@ BOOST_AUTO_TEST_CASE(CheckPacketClass)
     BOOST_CHECK(packetTest1.GetPacketClass() == 5);
 }
 
-BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
+// Create Derived Classes
+class TestFunctorA : public CommandHandlerFunctor
 {
-    // Create Derived Classes
-    class TestFunctorA : public CommandHandlerFunctor
-    {
-    public:
-        using CommandHandlerFunctor::CommandHandlerFunctor;
+public:
+    using CommandHandlerFunctor::CommandHandlerFunctor;
 
-        int GetCount() { return m_Count; }
+    int GetCount() { return m_Count; }
 
-        void operator()(const Packet& packet) override
-        {
-            m_Count++;
-        }
+    void operator()(const Packet& packet) override
+    {
+        m_Count++;
+    }
 
-    private:
-        int m_Count = 0;
-    };
+private:
+    int m_Count = 0;
+};
 
-    class TestFunctorB : public TestFunctorA
-    {
-        using TestFunctorA::TestFunctorA;
-    };
+class TestFunctorB : public TestFunctorA
+{
+    using TestFunctorA::TestFunctorA;
+};
 
-    class TestFunctorC : public TestFunctorA
-    {
-        using TestFunctorA::TestFunctorA;
-    };
+class TestFunctorC : public TestFunctorA
+{
+    using TestFunctorA::TestFunctorA;
+};
 
+BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
+{
     // Hard code the version as it will be the same during a single profiling session
     uint32_t version = 1;
 
@@ -156,4 +188,52 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
     BOOST_CHECK(testFunctorC.GetCount() == 1);
 }
 
+BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
+{
+    // Hard code the version as it will be the same during a single profiling session
+    uint32_t version = 1;
+
+    TestFunctorA testFunctorA(461, version);
+    TestFunctorB testFunctorB(963, version);
+    TestFunctorC testFunctorC(983, version);
+
+    // Create the Command Handler Registry
+    CommandHandlerRegistry registry;
+
+    // Register multiple different derived classes
+    registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion());
+    registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
+    registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
+
+    Packet packetA(500000000, 0, nullptr);
+    Packet packetB(600000000, 0, nullptr);
+    Packet packetC(400000000, 0, nullptr);
+
+    // Check the correct operator of derived class is called
+    registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA);
+    BOOST_CHECK(testFunctorA.GetCount() == 1);
+    BOOST_CHECK(testFunctorB.GetCount() == 0);
+    BOOST_CHECK(testFunctorC.GetCount() == 0);
+
+    registry.GetFunctor(packetB.GetPacketId(), version)->operator()(packetB);
+    BOOST_CHECK(testFunctorA.GetCount() == 1);
+    BOOST_CHECK(testFunctorB.GetCount() == 1);
+    BOOST_CHECK(testFunctorC.GetCount() == 0);
+
+    registry.GetFunctor(packetC.GetPacketId(), version)->operator()(packetC);
+    BOOST_CHECK(testFunctorA.GetCount() == 1);
+    BOOST_CHECK(testFunctorB.GetCount() == 1);
+    BOOST_CHECK(testFunctorC.GetCount() == 1);
+
+    // Re-register an existing key with a new function
+    registry.RegisterFunctor(&testFunctorC, testFunctorA.GetPacketId(), version);
+    registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetC);
+    BOOST_CHECK(testFunctorA.GetCount() == 1);
+    BOOST_CHECK(testFunctorB.GetCount() == 1);
+    BOOST_CHECK(testFunctorC.GetCount() == 2);
+
+    // Check that non-existent key returns nullptr for its functor
+    BOOST_CHECK_THROW(registry.GetFunctor(0, 0), armnn::Exception);
+}
+
 BOOST_AUTO_TEST_SUITE_END()