namespace profiling
{
-CommandThread::CommandThread(uint32_t timeout,
- bool stopAfterTimeout,
- CommandHandlerRegistry& commandHandlerRegistry,
- PacketVersionResolver& packetVersionResolver,
- IProfilingConnection& socketProfilingConnection)
- : m_Timeout(timeout)
- , m_StopAfterTimeout(stopAfterTimeout)
- , m_IsRunning(false)
- , m_CommandHandlerRegistry(commandHandlerRegistry)
- , m_PacketVersionResolver(packetVersionResolver)
- , m_SocketProfilingConnection(socketProfilingConnection)
-{};
-
-void CommandThread::WaitForPacket()
+void CommandThread::Start()
{
- do {
- try
- {
- Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
- Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
-
- CommandHandlerFunctor* commandHandlerFunctor =
- m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
- commandHandlerFunctor->operator()(packet);
- }
- catch(const armnn::TimeoutException&)
- {
- if(m_StopAfterTimeout)
- {
- m_IsRunning.store(false, std::memory_order_relaxed);
- return;
- }
- }
- catch(...)
- {
- //might want to differentiate the errors more
- m_IsRunning.store(false, std::memory_order_relaxed);
- return;
- }
-
- } while(m_KeepRunning.load(std::memory_order_relaxed));
+ if (IsRunning())
+ {
+ return;
+ }
- m_IsRunning.store(false, std::memory_order_relaxed);
+ m_IsRunning.store(true, std::memory_order_relaxed);
+ m_KeepRunning.store(true, std::memory_order_relaxed);
+ m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
}
-void CommandThread::Start()
+void CommandThread::Stop()
{
- if (!m_CommandThread.joinable() && !IsRunning())
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+
+ if (m_CommandThread.joinable())
{
- m_IsRunning.store(true, std::memory_order_relaxed);
- m_KeepRunning.store(true, std::memory_order_relaxed);
- m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
+ m_CommandThread.join();
}
}
-void CommandThread::Stop()
+bool CommandThread::IsRunning() const
{
- m_KeepRunning.store(false, std::memory_order_relaxed);
+ return m_IsRunning.load(std::memory_order_relaxed);
}
-void CommandThread::Join()
+void CommandThread::SetTimeout(uint32_t timeout)
{
- m_CommandThread.join();
+ m_Timeout.store(timeout, std::memory_order_relaxed);
}
-bool CommandThread::IsRunning() const
+void CommandThread::SetStopAfterTimeout(bool stopAfterTimeout)
{
- return m_IsRunning.load(std::memory_order_relaxed);
+ m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed);
}
-bool CommandThread::StopAfterTimeout(bool stopAfterTimeout)
+void CommandThread::WaitForPacket()
{
- if (!IsRunning())
+ do
{
- m_StopAfterTimeout = stopAfterTimeout;
- return true;
+ try
+ {
+ Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
+ Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
+
+ CommandHandlerFunctor* commandHandlerFunctor =
+ m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
+ BOOST_ASSERT(commandHandlerFunctor);
+ commandHandlerFunctor->operator()(packet);
+ }
+ catch (const armnn::TimeoutException&)
+ {
+ if (m_StopAfterTimeout)
+ {
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+ }
+ }
+ catch (...)
+ {
+ // Might want to differentiate the errors more
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+ }
+
}
- return false;
+ while (m_KeepRunning.load(std::memory_order_relaxed));
+
+ m_IsRunning.store(false, std::memory_order_relaxed);
}
-}//namespace profiling
+} // namespace profiling
-}//namespace armnn
+} // namespace armnn
bool stopAfterTimeout,
CommandHandlerRegistry& commandHandlerRegistry,
PacketVersionResolver& packetVersionResolver,
- IProfilingConnection& socketProfilingConnection);
+ IProfilingConnection& socketProfilingConnection)
+ : m_Timeout(timeout)
+ , m_StopAfterTimeout(stopAfterTimeout)
+ , m_IsRunning(false)
+ , m_KeepRunning(false)
+ , m_CommandThread()
+ , m_CommandHandlerRegistry(commandHandlerRegistry)
+ , m_PacketVersionResolver(packetVersionResolver)
+ , m_SocketProfilingConnection(socketProfilingConnection)
+ {}
+ ~CommandThread() { Stop(); }
void Start();
void Stop();
- void Join();
+
bool IsRunning() const;
- bool StopAfterTimeout(bool StopAfterTimeout);
+
+ void SetTimeout(uint32_t timeout);
+ void SetStopAfterTimeout(bool stopAfterTimeout);
private:
void WaitForPacket();
- uint32_t m_Timeout;
- bool m_StopAfterTimeout;
+ std::atomic<uint32_t> m_Timeout;
+ std::atomic<bool> m_StopAfterTimeout;
std::atomic<bool> m_IsRunning;
std::atomic<bool> m_KeepRunning;
std::thread m_CommandThread;
IProfilingConnection& m_SocketProfilingConnection;
};
-}//namespace profiling
+} // namespace profiling
-}//namespace armnn
\ No newline at end of file
+} // namespace armnn
commandThread0.Start();
commandThread0.Stop();
- commandThread0.Join();
BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
testProfilingConnectionTimeOutError);
commandThread1.Start();
- commandThread1.Join();
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ BOOST_CHECK(!commandThread1.IsRunning());
+ commandThread1.Stop();
BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
//now commandThread1 should persist after a timeout
- commandThread1.StopAfterTimeout(false);
+ commandThread1.SetStopAfterTimeout(false);
commandThread1.Start();
for (int i = 0; i < 100; i++)
}
commandThread1.Stop();
- commandThread1.Join();
BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
-
CommandThread commandThread2(1,
false,
commandHandlerRegistry,
if (!commandThread2.IsRunning())
{
//commandThread2 should stop once it encounters a non timing error
- commandThread2.Join();
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
BOOST_ERROR("commandThread2 has failed to stop");
+ commandThread2.Stop();
}
BOOST_AUTO_TEST_CASE(CheckEncodeVersion)