IVGCVSW-3439 Create the Command Thread
authorFinnWilliamsArm <Finn.Williams@arm.com>
Tue, 17 Sep 2019 15:53:53 +0000 (16:53 +0100)
committerJim Flynn Arm <jim.flynn@arm.com>
Tue, 24 Sep 2019 23:44:51 +0000 (23:44 +0000)
Signed-off-by: FinnWilliamsArm <Finn.Williams@arm.com>
Change-Id: I9548c5937967f4c25841bb851273168379687bcd

CMakeLists.txt
include/armnn/Exceptions.hpp
src/profiling/CommandThread.cpp [new file with mode: 0644]
src/profiling/CommandThread.hpp [new file with mode: 0644]
src/profiling/SocketProfilingConnection.cpp
src/profiling/test/ProfilingTests.cpp

index 90eb0328dd4df6dfe7722e33e77dbe0a8280f0a9..3da7e8bcfa677f0257f0622216f44ab20e7a6319 100644 (file)
@@ -428,6 +428,8 @@ list(APPEND armnn_sources
     src/profiling/CommandHandlerKey.hpp
     src/profiling/CommandHandlerRegistry.cpp
     src/profiling/CommandHandlerRegistry.hpp
+    src/profiling/CommandThread.cpp
+    src/profiling/CommandThread.hpp
     src/profiling/ConnectionAcknowledgedCommandHandler.cpp
     src/profiling/ConnectionAcknowledgedCommandHandler.hpp
     src/profiling/CounterDirectory.cpp
index f8e0b430a62ddf80c20caf452295a153802ac8db..e21e974fc76419cb0bd8fce886e53d218f3b90b8 100644 (file)
@@ -125,6 +125,11 @@ class MemoryExportException : public Exception
     using Exception::Exception;
 };
 
+class TimeoutException : public Exception
+{
+    using Exception::Exception;
+};
+
 template <typename ExceptionType>
 void ConditionalThrow(bool condition, const std::string& message)
 {
diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp
new file mode 100644 (file)
index 0000000..4cd622c
--- /dev/null
@@ -0,0 +1,97 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <atomic>
+#include "CommandThread.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+CommandThread::CommandThread(uint32_t timeout,
+                             bool stopAfterTimeout,
+                             CommandHandlerRegistry& commandHandlerRegistry,
+                             PacketVersionResolver& packetVersionResolver,
+                             IProfilingConnection& socketProfilingConnection)
+    : m_Timeout(timeout)
+    , m_StopAfterTimeout(stopAfterTimeout)
+    , m_IsRunning(false)
+    , m_CommandHandlerRegistry(commandHandlerRegistry)
+    , m_PacketVersionResolver(packetVersionResolver)
+    , m_SocketProfilingConnection(socketProfilingConnection)
+{};
+
+void CommandThread::WaitForPacket()
+{
+    do {
+        try
+        {
+            Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
+            Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
+
+            CommandHandlerFunctor* commandHandlerFunctor =
+                m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
+            commandHandlerFunctor->operator()(packet);
+        }
+        catch(armnn::TimeoutException)
+        {
+            if(m_StopAfterTimeout)
+            {
+                m_IsRunning.store(false, std::memory_order_relaxed);
+                return;
+            }
+        }
+        catch(...)
+        {
+            //might want to differentiate the errors more
+            m_IsRunning.store(false, std::memory_order_relaxed);
+            return;
+        }
+
+    } while(m_KeepRunning.load(std::memory_order_relaxed));
+
+    m_IsRunning.store(false, std::memory_order_relaxed);
+}
+
+void CommandThread::Start()
+{
+    if (!m_CommandThread.joinable() && !IsRunning())
+    {
+        m_IsRunning.store(true, std::memory_order_relaxed);
+        m_KeepRunning.store(true, std::memory_order_relaxed);
+        m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
+    }
+}
+
+void CommandThread::Stop()
+{
+    m_KeepRunning.store(false, std::memory_order_relaxed);
+}
+
+void CommandThread::Join()
+{
+    m_CommandThread.join();
+}
+
+bool CommandThread::IsRunning() const
+{
+    return m_IsRunning.load(std::memory_order_relaxed);
+}
+
+bool CommandThread::StopAfterTimeout(bool stopAfterTimeout)
+{
+    if (!IsRunning())
+    {
+        m_StopAfterTimeout = stopAfterTimeout;
+        return true;
+    }
+    return false;
+}
+
+}//namespace profiling
+
+}//namespace armnn
\ No newline at end of file
diff --git a/src/profiling/CommandThread.hpp b/src/profiling/CommandThread.hpp
new file mode 100644 (file)
index 0000000..6237cd2
--- /dev/null
@@ -0,0 +1,53 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "CommandHandlerRegistry.hpp"
+#include "IProfilingConnection.hpp"
+#include "PacketVersionResolver.hpp"
+#include "ProfilingService.hpp"
+
+#include <atomic>
+#include <thread>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class CommandThread
+{
+public:
+    CommandThread(uint32_t timeout,
+                  bool stopAfterTimeout,
+                  CommandHandlerRegistry& commandHandlerRegistry,
+                  PacketVersionResolver& packetVersionResolver,
+                  IProfilingConnection& socketProfilingConnection);
+
+    void Start();
+    void Stop();
+    void Join();
+    bool IsRunning() const;
+    bool StopAfterTimeout(bool StopAfterTimeout);
+
+private:
+    void WaitForPacket();
+
+    uint32_t m_Timeout;
+    bool m_StopAfterTimeout;
+    std::atomic<bool> m_IsRunning;
+    std::atomic<bool> m_KeepRunning;
+    std::thread m_CommandThread;
+
+    CommandHandlerRegistry& m_CommandHandlerRegistry;
+    PacketVersionResolver& m_PacketVersionResolver;
+    IProfilingConnection& m_SocketProfilingConnection;
+};
+
+}//namespace profiling
+
+}//namespace armnn
\ No newline at end of file
index 188ca23e121d651b12f204fde53c0c38e18be7be..91d57cc9bda7523c72d29da9caf08937bb818425 100644 (file)
@@ -135,7 +135,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
     }
     else // it's 0 so a timeout.
     {
-        throw armnn::Exception(": Timeout while reading from socket.");
+        throw armnn::TimeoutException(": Timeout while reading from socket.");
     }
 }
 
index 32a41f37c2febd8c76230c5164bfae6d15b6d215..48723dbc34b7b6002c3dfa1bb11413eb686f85ea 100644 (file)
@@ -4,6 +4,7 @@
 //
 
 #include "SendCounterPacketTests.hpp"
+#include "../CommandThread.hpp"
 
 #include <CommandHandlerKey.hpp>
 #include <CommandHandlerFunctor.hpp>
@@ -87,6 +88,150 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
     BOOST_CHECK(vect == expectedVect);
 }
 
+class TestProfilingConnectionBase :public IProfilingConnection
+{
+public:
+    TestProfilingConnectionBase() = default;
+    ~TestProfilingConnectionBase() = default;
+
+    bool IsOpen()
+    {
+        return true;
+    }
+
+    void Close(){}
+
+    bool WritePacket(const char* buffer, uint32_t length)
+    {
+        return false;
+    }
+
+    Packet ReadPacket(uint32_t timeout)
+    {
+        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+        std::unique_ptr<char[]> packetData;
+        //Return connection acknowledged packet
+        return {65536 ,0 , packetData};
+    }
+};
+
+class TestProfilingConnectionTimeoutError :public TestProfilingConnectionBase
+{
+    int readRequests = 0;
+public:
+    Packet ReadPacket(uint32_t timeout) {
+        if (readRequests < 3)
+        {
+            readRequests++;
+            throw armnn::TimeoutException(": Simulate a timeout");
+        }
+        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+        std::unique_ptr<char[]> packetData;
+        //Return connection acknowledged packet after three timeouts
+        return {65536 ,0 , packetData};
+    }
+};
+
+class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase
+{
+public:
+
+    Packet ReadPacket(uint32_t timeout)
+    {
+        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+        throw armnn::Exception(": Simulate a non timeout error");
+    }
+};
+
+BOOST_AUTO_TEST_CASE(CheckCommandThread)
+{
+        PacketVersionResolver packetVersionResolver;
+        ProfilingStateMachine profilingStateMachine;
+
+        TestProfilingConnectionBase testProfilingConnectionBase;
+        TestProfilingConnectionTimeoutError testProfilingConnectionTimeOutError;
+        TestProfilingConnectionArmnnError testProfilingConnectionArmnnError;
+
+        ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(1, 4194304, profilingStateMachine);
+        CommandHandlerRegistry commandHandlerRegistry;
+
+        commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler, 1, 4194304);
+
+        profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+        profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+
+        CommandThread commandThread0(1,
+                                     true,
+                                     commandHandlerRegistry,
+                                     packetVersionResolver,
+                                     testProfilingConnectionBase);
+
+        commandThread0.Start();
+        commandThread0.Start();
+        commandThread0.Start();
+
+        commandThread0.Stop();
+        commandThread0.Join();
+
+        BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
+
+        profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+        profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+        //commandThread1 should give up after one timeout
+        CommandThread commandThread1(1,
+                                     true,
+                                     commandHandlerRegistry,
+                                     packetVersionResolver,
+                                     testProfilingConnectionTimeOutError);
+
+        commandThread1.Start();
+        commandThread1.Join();
+
+        BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
+        //now commandThread1 should persist after a timeout
+        commandThread1.StopAfterTimeout(false);
+        commandThread1.Start();
+
+        for (int i = 0; i < 100; i++)
+        {
+            if (profilingStateMachine.GetCurrentState() == ProfilingState::Active)
+            {
+                break;
+            }
+            else
+            {
+                std::this_thread::sleep_for(std::chrono::milliseconds(5));
+            }
+        }
+
+        commandThread1.Stop();
+        commandThread1.Join();
+
+        BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
+
+
+        CommandThread commandThread2(1,
+                                     false,
+                                     commandHandlerRegistry,
+                                     packetVersionResolver,
+                                     testProfilingConnectionArmnnError);
+
+        commandThread2.Start();
+
+        for (int i = 0; i < 100; i++)
+        {
+            if (!commandThread2.IsRunning())
+            {
+                //commandThread2 should stop once it encounters a non timing error
+                commandThread2.Join();
+                return;
+            }
+            std::this_thread::sleep_for(std::chrono::milliseconds(5));
+        }
+
+        BOOST_ERROR("commandThread2 has failed to stop");
+}
+
 BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
 {
     Version version1(12);