IVGCVSW-3431 Create Profiling Service State Machine
authorNikhil Raj <nikhil.raj@arm.com>
Tue, 3 Sep 2019 14:55:33 +0000 (15:55 +0100)
committerNikhil Raj <nikhil.raj@arm.com>
Tue, 3 Sep 2019 14:55:33 +0000 (15:55 +0100)
Change-Id: I30ae52d38181a91ce642e24919ad788902e42eb4
Signed-off-by: Nikhil Raj <nikhil.raj@arm.com>
CMakeLists.txt
src/profiling/ProfilingStateMachine.cpp [new file with mode: 0644]
src/profiling/ProfilingStateMachine.hpp [new file with mode: 0644]
src/profiling/test/ProfilingTests.cpp

index 1a07f69..a285c36 100644 (file)
@@ -433,6 +433,8 @@ list(APPEND armnn_sources
     src/profiling/SendCounterPacket.cpp
     src/profiling/ProfilingUtils.hpp
     src/profiling/ProfilingUtils.cpp
+    src/profiling/ProfilingStateMachine.cpp
+    src/profiling/ProfilingStateMachine.hpp
     third-party/half/half.hpp
     )
 
diff --git a/src/profiling/ProfilingStateMachine.cpp b/src/profiling/ProfilingStateMachine.cpp
new file mode 100644 (file)
index 0000000..682e1b8
--- /dev/null
@@ -0,0 +1,93 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ProfilingStateMachine.hpp"
+
+#include <armnn/Exceptions.hpp>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+ProfilingState ProfilingStateMachine::GetCurrentState() const
+{
+    return m_State;
+}
+
+void ProfilingStateMachine::TransitionToState(ProfilingState newState)
+{
+     switch (newState)
+     {
+         case ProfilingState::Uninitialised:
+         {
+             ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+             do {
+                 if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised))
+                 {
+                     throw armnn::Exception(std::string("Cannot transition from state [")
+                                            + GetProfilingStateName(expectedState)
+                                            +"] to [" + GetProfilingStateName(newState) + "]");
+                 }
+             } while (!m_State.compare_exchange_strong(expectedState, newState,
+                      std::memory_order::memory_order_relaxed));
+
+             break;
+         }
+         case  ProfilingState::NotConnected:
+         {
+             ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+             do {
+                 if (!IsOneOfStates(expectedState, ProfilingState::Uninitialised, ProfilingState::NotConnected,
+                                    ProfilingState::Active))
+                 {
+                     throw armnn::Exception(std::string("Cannot transition from state [")
+                                            + GetProfilingStateName(expectedState)
+                                            +"] to [" + GetProfilingStateName(newState) + "]");
+                 }
+             } while (!m_State.compare_exchange_strong(expectedState, newState,
+                      std::memory_order::memory_order_relaxed));
+
+             break;
+         }
+         case ProfilingState::WaitingForAck:
+         {
+             ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+             do {
+                 if (!IsOneOfStates(expectedState, ProfilingState::NotConnected, ProfilingState::WaitingForAck))
+                 {
+                     throw armnn::Exception(std::string("Cannot transition from state [")
+                                            + GetProfilingStateName(expectedState)
+                                            +"] to [" + GetProfilingStateName(newState) + "]");
+                 }
+             } while (!m_State.compare_exchange_strong(expectedState, newState,
+                      std::memory_order::memory_order_relaxed));
+
+             break;
+         }
+         case ProfilingState::Active:
+         {
+             ProfilingState expectedState = m_State.load(std::memory_order::memory_order_relaxed);
+             do {
+                 if (!IsOneOfStates(expectedState, ProfilingState::WaitingForAck, ProfilingState::Active))
+                 {
+                     throw armnn::Exception(std::string("Cannot transition from state [")
+                                            + GetProfilingStateName(expectedState)
+                                            +"] to [" + GetProfilingStateName(newState) + "]");
+                 }
+             } while (!m_State.compare_exchange_strong(expectedState, newState,
+                      std::memory_order::memory_order_relaxed));
+
+             break;
+         }
+         default:
+             break;
+     }
+}
+
+} //namespace profiling
+
+} //namespace armnn
\ No newline at end of file
diff --git a/src/profiling/ProfilingStateMachine.hpp b/src/profiling/ProfilingStateMachine.hpp
new file mode 100644 (file)
index 0000000..66f8b2c
--- /dev/null
@@ -0,0 +1,69 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <atomic>
+
+namespace armnn
+{
+
+namespace  profiling
+{
+
+enum class ProfilingState
+{
+    Uninitialised,
+    NotConnected,
+    WaitingForAck,
+    Active
+};
+
+class ProfilingStateMachine
+{
+public:
+    ProfilingStateMachine(): m_State(ProfilingState::Uninitialised) {};
+    ProfilingStateMachine(ProfilingState state): m_State(state) {};
+
+    ProfilingState GetCurrentState() const;
+    void TransitionToState(ProfilingState newState);
+
+    bool IsOneOfStates(ProfilingState state1)
+    {
+        return false;
+    }
+
+    template<typename T, typename... Args >
+    bool IsOneOfStates(T state1, T state2, Args... args)
+    {
+        if (state1 == state2)
+        {
+            return true;
+        }
+        else
+        {
+            return IsOneOfStates(state1, args...);
+        }
+    }
+
+private:
+    std::atomic<ProfilingState> m_State;
+};
+
+constexpr char const* GetProfilingStateName(ProfilingState state)
+{
+    switch(state)
+    {
+        case ProfilingState::Uninitialised:       return "Uninitialised";
+        case ProfilingState::NotConnected:        return "NotConnected";
+        case ProfilingState::WaitingForAck:       return "WaitingForAck";
+        case ProfilingState::Active:              return "Active";
+        default:                                  return "Unknown";
+    }
+}
+
+} //namespace profiling
+
+} //namespace armnn
+
index 3fd8d79..ce278ab 100644 (file)
@@ -9,6 +9,7 @@
 #include "../EncodeVersion.hpp"
 #include "../Packet.hpp"
 #include "../PacketVersionResolver.hpp"
+#include "../ProfilingStateMachine.hpp"
 
 #include <boost/test/unit_test.hpp>
 
@@ -17,6 +18,7 @@
 #include <limits>
 #include <map>
 #include <random>
+#include <thread>
 
 BOOST_AUTO_TEST_SUITE(ExternalProfiling)
 
@@ -265,5 +267,94 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
         BOOST_TEST(resolvedVersion == expectedVersion);
     }
 }
+void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
+{
+    ProfilingState newState = ProfilingState::NotConnected;
+    states.GetCurrentState();
+    states.TransitionToState(newState);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
+{
+    ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
+    profilingState1.TransitionToState(ProfilingState::Uninitialised);
+    BOOST_CHECK(profilingState1.GetCurrentState() ==  ProfilingState::Uninitialised);
+
+    ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
+    profilingState2.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
+
+    ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
+    profilingState3.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
+
+    ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
+    profilingState4.TransitionToState(ProfilingState::WaitingForAck);
+    BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
+
+    ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
+    profilingState5.TransitionToState(ProfilingState::WaitingForAck);
+    BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
+
+    ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
+    profilingState6.TransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
+
+    ProfilingStateMachine profilingState7(ProfilingState::Active);
+    profilingState7.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
+
+    ProfilingStateMachine profilingState8(ProfilingState::Active);
+    profilingState8.TransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
+
+    ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
+    BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
+    BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
+    BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
+    BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
+    BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
+    BOOST_CHECK_THROW(profilingState14.TransitionToState(ProfilingState::NotConnected),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState15(ProfilingState::Active);
+    BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active);
+    BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
+
+    std::thread thread1 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+    std::thread thread2 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+    std::thread thread3 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+    std::thread thread4 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+    std::thread thread5 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+
+    thread1.join();
+    thread2.join();
+    thread3.join();
+    thread4.join();
+    thread5.join();
+
+    BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
+}
 
 BOOST_AUTO_TEST_SUITE_END()