IVGCVSW-3550 Create Command Handler Registry
authorFrancis Murtagh <francis.murtagh@arm.com>
Fri, 16 Aug 2019 16:45:07 +0000 (17:45 +0100)
committerFrancis Murtagh <francis.murtagh@arm.com>
Fri, 16 Aug 2019 16:45:12 +0000 (17:45 +0100)
Change-Id: I51e34068d79ba660ae2f16b22ad2bb8191d473fa
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
CMakeLists.txt
src/profiling/CommandHandlerRegistry.cpp [new file with mode: 0644]
src/profiling/CommandHandlerRegistry.hpp [new file with mode: 0644]
src/profiling/test/ProfilingTests.cpp

index a77b112..df4b742 100644 (file)
@@ -415,6 +415,8 @@ list(APPEND armnn_sources
     src/profiling/CommandHandlerFunctor.hpp
     src/profiling/CommandHandlerKey.cpp
     src/profiling/CommandHandlerKey.hpp
+    src/profiling/CommandHandlerRegistry.cpp
+    src/profiling/CommandHandlerRegistry.hpp
     src/profiling/Packet.cpp
     src/profiling/Packet.hpp
     third-party/half/half.hpp
diff --git a/src/profiling/CommandHandlerRegistry.cpp b/src/profiling/CommandHandlerRegistry.cpp
new file mode 100644 (file)
index 0000000..d392db0
--- /dev/null
@@ -0,0 +1,29 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "CommandHandlerRegistry.hpp"
+
+#include <boost/assert.hpp>
+#include <boost/log/trivial.hpp>
+
+void CommandHandlerRegistry::RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version)
+{
+    BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr.");
+    CommandHandlerKey key(packetId, version);
+    registry[key] = functor;
+}
+
+CommandHandlerFunctor* CommandHandlerRegistry::GetFunctor(uint32_t packetId, uint32_t version) const
+{
+    CommandHandlerKey key(packetId, version);
+
+    // Check that the requested key exists
+    if (registry.find(key) == registry.end())
+    {
+        throw armnn::Exception("Functor with requested PacketId or Version does not exist.");
+    }
+
+    return registry.at(key);
+}
diff --git a/src/profiling/CommandHandlerRegistry.hpp b/src/profiling/CommandHandlerRegistry.hpp
new file mode 100644 (file)
index 0000000..ba81f17
--- /dev/null
@@ -0,0 +1,36 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "CommandHandlerFunctor.hpp"
+#include "CommandHandlerKey.hpp"
+
+#include <boost/functional/hash.hpp>
+#include <unordered_map>
+
+struct CommandHandlerHash
+{
+    std::size_t operator() (const CommandHandlerKey& commandHandlerKey) const
+    {
+        std::size_t seed = 0;
+        boost::hash_combine(seed, commandHandlerKey.GetPacketId());
+        boost::hash_combine(seed, commandHandlerKey.GetVersion());
+        return seed;
+    }
+};
+
+class CommandHandlerRegistry
+{
+public:
+    CommandHandlerRegistry() = default;
+
+    void RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version);
+
+    CommandHandlerFunctor* GetFunctor(uint32_t packetId, uint32_t version) const;
+
+private:
+    std::unordered_map<CommandHandlerKey, CommandHandlerFunctor*, CommandHandlerHash> registry;
+};
\ No newline at end of file
index 26cbfd7..a8ec027 100644 (file)
@@ -5,6 +5,7 @@
 
 #include "../CommandHandlerKey.hpp"
 #include "../CommandHandlerFunctor.hpp"
+#include "../CommandHandlerRegistry.hpp"
 #include "../Packet.hpp"
 
 #include <cstdint>
@@ -80,35 +81,35 @@ BOOST_AUTO_TEST_CASE(CheckPacketClass)
     BOOST_CHECK(packetTest1.GetPacketClass() == 5);
 }
 
-BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
+// Create Derived Classes
+class TestFunctorA : public CommandHandlerFunctor
 {
-    // Create Derived Classes
-    class TestFunctorA : public CommandHandlerFunctor
-    {
-    public:
-        using CommandHandlerFunctor::CommandHandlerFunctor;
+public:
+    using CommandHandlerFunctor::CommandHandlerFunctor;
 
-        int GetCount() { return m_Count; }
+    int GetCount() { return m_Count; }
 
-        void operator()(const Packet& packet) override
-        {
-            m_Count++;
-        }
+    void operator()(const Packet& packet) override
+    {
+        m_Count++;
+    }
 
-    private:
-        int m_Count = 0;
-    };
+private:
+    int m_Count = 0;
+};
 
-    class TestFunctorB : public TestFunctorA
-    {
-        using TestFunctorA::TestFunctorA;
-    };
+class TestFunctorB : public TestFunctorA
+{
+    using TestFunctorA::TestFunctorA;
+};
 
-    class TestFunctorC : public TestFunctorA
-    {
-        using TestFunctorA::TestFunctorA;
-    };
+class TestFunctorC : public TestFunctorA
+{
+    using TestFunctorA::TestFunctorA;
+};
 
+BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
+{
     // Hard code the version as it will be the same during a single profiling session
     uint32_t version = 1;
 
@@ -156,4 +157,52 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
     BOOST_CHECK(testFunctorC.GetCount() == 1);
 }
 
+BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
+{
+    // Hard code the version as it will be the same during a single profiling session
+    uint32_t version = 1;
+
+    TestFunctorA testFunctorA(461, version);
+    TestFunctorB testFunctorB(963, version);
+    TestFunctorC testFunctorC(983, version);
+
+    // Create the Command Handler Registry
+    CommandHandlerRegistry registry;
+
+    // Register multiple different derived classes
+    registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion());
+    registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
+    registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
+
+    Packet packetA(500000000, 0, nullptr);
+    Packet packetB(600000000, 0, nullptr);
+    Packet packetC(400000000, 0, nullptr);
+
+    // Check the correct operator of derived class is called
+    registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA);
+    BOOST_CHECK(testFunctorA.GetCount() == 1);
+    BOOST_CHECK(testFunctorB.GetCount() == 0);
+    BOOST_CHECK(testFunctorC.GetCount() == 0);
+
+    registry.GetFunctor(packetB.GetPacketId(), version)->operator()(packetB);
+    BOOST_CHECK(testFunctorA.GetCount() == 1);
+    BOOST_CHECK(testFunctorB.GetCount() == 1);
+    BOOST_CHECK(testFunctorC.GetCount() == 0);
+
+    registry.GetFunctor(packetC.GetPacketId(), version)->operator()(packetC);
+    BOOST_CHECK(testFunctorA.GetCount() == 1);
+    BOOST_CHECK(testFunctorB.GetCount() == 1);
+    BOOST_CHECK(testFunctorC.GetCount() == 1);
+
+    // Re-register an existing key with a new function
+    registry.RegisterFunctor(&testFunctorC, testFunctorA.GetPacketId(), version);
+    registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetC);
+    BOOST_CHECK(testFunctorA.GetCount() == 1);
+    BOOST_CHECK(testFunctorB.GetCount() == 1);
+    BOOST_CHECK(testFunctorC.GetCount() == 2);
+
+    // Check that non-existent key returns nullptr for its functor
+    BOOST_CHECK_THROW(registry.GetFunctor(0, 0), armnn::Exception);
+}
+
 BOOST_AUTO_TEST_SUITE_END()