IVGCVSW-3431 Create Profiling Service State Machine
[platform/upstream/armnn.git] / src / profiling / test / ProfilingTests.cpp
index 3fd8d79..ce278ab 100644 (file)
@@ -9,6 +9,7 @@
 #include "../EncodeVersion.hpp"
 #include "../Packet.hpp"
 #include "../PacketVersionResolver.hpp"
+#include "../ProfilingStateMachine.hpp"
 
 #include <boost/test/unit_test.hpp>
 
@@ -17,6 +18,7 @@
 #include <limits>
 #include <map>
 #include <random>
+#include <thread>
 
 BOOST_AUTO_TEST_SUITE(ExternalProfiling)
 
@@ -265,5 +267,94 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
         BOOST_TEST(resolvedVersion == expectedVersion);
     }
 }
+void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
+{
+    ProfilingState newState = ProfilingState::NotConnected;
+    states.GetCurrentState();
+    states.TransitionToState(newState);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
+{
+    ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
+    profilingState1.TransitionToState(ProfilingState::Uninitialised);
+    BOOST_CHECK(profilingState1.GetCurrentState() ==  ProfilingState::Uninitialised);
+
+    ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
+    profilingState2.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
+
+    ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
+    profilingState3.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
+
+    ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
+    profilingState4.TransitionToState(ProfilingState::WaitingForAck);
+    BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
+
+    ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
+    profilingState5.TransitionToState(ProfilingState::WaitingForAck);
+    BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
+
+    ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
+    profilingState6.TransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
+
+    ProfilingStateMachine profilingState7(ProfilingState::Active);
+    profilingState7.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
+
+    ProfilingStateMachine profilingState8(ProfilingState::Active);
+    profilingState8.TransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
+
+    ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
+    BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
+    BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
+    BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
+    BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
+    BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
+    BOOST_CHECK_THROW(profilingState14.TransitionToState(ProfilingState::NotConnected),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState15(ProfilingState::Active);
+    BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active);
+    BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck),
+                      armnn::Exception);
+
+    ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
+
+    std::thread thread1 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+    std::thread thread2 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+    std::thread thread3 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+    std::thread thread4 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+    std::thread thread5 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
+
+    thread1.join();
+    thread2.join();
+    thread3.join();
+    thread4.join();
+    thread5.join();
+
+    BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
+}
 
 BOOST_AUTO_TEST_SUITE_END()