From: Matteo Martincigh Date: Mon, 7 Oct 2019 12:05:13 +0000 (+0100) Subject: IVGCVSW-3937 Update the Send thread to send out the Metadata packet X-Git-Tag: submit/tizen/20200316.035456~182 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5d737fb3b06c17ff6b65fb307343ca1c0c680401;p=platform%2Fupstream%2Farmnn.git IVGCVSW-3937 Update the Send thread to send out the Metadata packet * The Send thread now automatically sends out Stream Metadata packets when the Profiling Service is in WaitingForAck state * Added a reference to the profiling state in the SendCounterPacket class * Moving the RuntimeException thrown in the Send thread to the main thread for rethrowing * The Stop method now rethrows the exception occurred in the send thread * The Stop method does not rethrow when destructing the object * Added unit tests Signed-off-by: Matteo Martincigh Change-Id: Ice7080bff63199eac84fc4fa1d37fb1a6fcdff89 --- diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp index dc5a950..b9f2b18 100644 --- a/src/profiling/SendCounterPacket.cpp +++ b/src/profiling/SendCounterPacket.cpp @@ -920,7 +920,7 @@ void SendCounterPacket::Start(IProfilingConnection& profilingConnection) m_SendThread = std::thread(&SendCounterPacket::Send, this, std::ref(profilingConnection)); } -void SendCounterPacket::Stop() +void SendCounterPacket::Stop(bool rethrowSendThreadExceptions) { // Signal the send thread to stop m_KeepRunning.store(false); @@ -934,6 +934,30 @@ void SendCounterPacket::Stop() // Wait for the send thread to complete operations m_SendThread.join(); } + + // Check if the send thread exception has to be rethrown + if (!rethrowSendThreadExceptions) + { + // No need to rethrow the send thread exception, return immediately + return; + } + + // Exception handling lock scope - Begin + { + // Lock the mutex to handle any exception coming from the send thread + std::unique_lock lock(m_WaitMutex); + + // Check if there's an exception to rethrow + if (m_SendThreadException) + { + // Rethrow the send thread exception + std::rethrow_exception(m_SendThreadException); + + // Nullify the exception as it has been rethrown + m_SendThreadException = nullptr; + } + } + // Exception handling lock scope - End } void SendCounterPacket::Send(IProfilingConnection& profilingConnection) @@ -946,20 +970,67 @@ void SendCounterPacket::Send(IProfilingConnection& profilingConnection) // Lock the mutex to wait on it std::unique_lock lock(m_WaitMutex); - if (m_Timeout < 0) + // Check the current state of the profiling service + ProfilingState currentState = m_StateMachine.GetCurrentState(); + switch (currentState) { - // Wait indefinitely until notified that something to read has become available in the buffer + case ProfilingState::Uninitialised: + case ProfilingState::NotConnected: + + // The send thread cannot be running when the profiling service is uninitialized or not connected, + // stop the thread immediately + m_KeepRunning.store(false); + m_IsRunning.store(false); + + // An exception should be thrown here, save it to be rethrown later from the main thread so that + // it can be caught by the consumer + m_SendThreadException = + std::make_exception_ptr(RuntimeException("The send thread should not be running with the " + "profiling service not yet initialized or connected")); + + return; + case ProfilingState::WaitingForAck: + + // Send out a StreamMetadata packet and wait for the profiling connection to be acknowledged. + // When a ConnectionAcknowledged packet is received, the profiling service state will be automatically + // updated by the command handler + + // Prepare a StreamMetadata packet and write it to the Counter Stream buffer + SendStreamMetaDataPacket(); + + // Flush the buffer manually to send the packet + FlushBuffer(profilingConnection); + + // Wait indefinitely until notified otherwise (it could that the profiling state has changed due to the + // connection being acknowledged, or that new data is ready to be sent, or that the send thread is + // being shut down, etc.) m_WaitCondition.wait(lock); - } - else - { - // Wait until the thread is notified of something to read from the buffer, - // or check anyway after the specified number of milliseconds - m_WaitCondition.wait_for(lock, std::chrono::milliseconds(m_Timeout)); + + // Do not flush the buffer again + continue; + case ProfilingState::Active: + default: + // Normal working state for the send thread + + // Check if the send thread is required to enforce a timeout wait policy + if (m_Timeout < 0) + { + // Wait indefinitely until notified that something to read has become available in the buffer + m_WaitCondition.wait(lock); + } + else + { + // Wait until the thread is notified of something to read from the buffer, + // or check anyway after the specified number of milliseconds + m_WaitCondition.wait_for(lock, std::chrono::milliseconds(m_Timeout)); + } + + break; } } // Wait condition lock scope - End + // Send all the available packets in the buffer FlushBuffer(profilingConnection); } @@ -1000,7 +1071,7 @@ void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection) // Mark the packet buffer as read m_BufferManager.MarkRead(packetBuffer); - // Get next available readable buffer + // Get the next available readable buffer packetBuffer = m_BufferManager.GetReadableBuffer(); } } diff --git a/src/profiling/SendCounterPacket.hpp b/src/profiling/SendCounterPacket.hpp index ed76937..9361efb 100644 --- a/src/profiling/SendCounterPacket.hpp +++ b/src/profiling/SendCounterPacket.hpp @@ -6,9 +6,10 @@ #pragma once #include "IBufferManager.hpp" -#include "ISendCounterPacket.hpp" #include "ICounterDirectory.hpp" +#include "ISendCounterPacket.hpp" #include "IProfilingConnection.hpp" +#include "ProfilingStateMachine.hpp" #include "ProfilingUtils.hpp" #include @@ -26,20 +27,25 @@ namespace profiling class SendCounterPacket : public ISendCounterPacket { public: - using CategoryRecord = std::vector; - using DeviceRecord = std::vector; - using CounterSetRecord = std::vector; - using EventRecord = std::vector; - + using CategoryRecord = std::vector; + using DeviceRecord = std::vector; + using CounterSetRecord = std::vector; + using EventRecord = std::vector; using IndexValuePairsVector = std::vector>; - SendCounterPacket(IBufferManager& buffer, int timeout = 1000) - : m_BufferManager(buffer) + SendCounterPacket(ProfilingStateMachine& profilingStateMachine, IBufferManager& buffer, int timeout = 1000) + : m_StateMachine(profilingStateMachine) + , m_BufferManager(buffer) , m_Timeout(timeout) , m_IsRunning(false) , m_KeepRunning(false) + , m_SendThreadException(nullptr) {} - ~SendCounterPacket() { Stop(); } + ~SendCounterPacket() + { + // Don't rethrow when destructing the object + Stop(false); + } void SendStreamMetaDataPacket() override; @@ -56,7 +62,7 @@ public: static const unsigned int MAX_METADATA_PACKET_LENGTH = 4096; void Start(IProfilingConnection& profilingConnection); - void Stop(); + void Stop(bool rethrowSendThreadExceptions = true); bool IsRunning() { return m_IsRunning.load(); } private: @@ -76,6 +82,7 @@ private: { SetReadyToRead(); } + if (writerBuffer != nullptr) { // Cancel the operation @@ -88,6 +95,7 @@ private: void FlushBuffer(IProfilingConnection& profilingConnection); + ProfilingStateMachine& m_StateMachine; IBufferManager& m_BufferManager; int m_Timeout; std::mutex m_WaitMutex; @@ -95,6 +103,7 @@ private: std::thread m_SendThread; std::atomic m_IsRunning; std::atomic m_KeepRunning; + std::exception_ptr m_SendThreadException; protected: // Helper methods, protected for testing diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 91568d1..24ab779 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -1767,6 +1767,8 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) { using boost::numeric_cast; + ProfilingStateMachine profilingStateMachine; + class TestCaptureThread : public IPeriodicCounterCapture { void Start() override {} @@ -1779,7 +1781,7 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) Holder holder; TestCaptureThread captureThread; MockBufferManager mockBuffer(512); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); uint32_t sizeOfUint32 = numeric_cast(sizeof(uint32_t)); uint32_t sizeOfUint16 = numeric_cast(sizeof(uint16_t)); @@ -2135,12 +2137,14 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread) std::unordered_map m_Data; }; + ProfilingStateMachine profilingStateMachine; + Holder data; std::vector captureIds1 = { 0, 1 }; std::vector captureIds2; MockBufferManager mockBuffer(512); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); std::vector counterIds; CaptureReader captureReader; @@ -2201,6 +2205,8 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest0) { using boost::numeric_cast; + ProfilingStateMachine profilingStateMachine; + const uint32_t packetId = 0x30000; const uint32_t version = 1; @@ -2209,7 +2215,7 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest0) Packet packetA(packetId, 0, packetData); MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); CounterDirectory counterDirectory; @@ -2234,6 +2240,8 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) { using boost::numeric_cast; + ProfilingStateMachine profilingStateMachine; + const uint32_t packetId = 0x30000; const uint32_t version = 1; @@ -2242,7 +2250,7 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) Packet packetA(packetId, 0, packetData); MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); CounterDirectory counterDirectory; const Device* device = counterDirectory.RegisterDevice("deviceA", 1); diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index 16302bc..1216420 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -21,14 +21,71 @@ using namespace armnn::profiling; +namespace +{ + +void SetNotConnectedProfilingState(ProfilingStateMachine& profilingStateMachine) +{ + ProfilingState currentState = profilingStateMachine.GetCurrentState(); + switch (currentState) + { + case ProfilingState::WaitingForAck: + profilingStateMachine.TransitionToState(ProfilingState::Active); + case ProfilingState::Uninitialised: + case ProfilingState::Active: + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + case ProfilingState::NotConnected: + return; + default: + BOOST_CHECK_MESSAGE(false, "Invalid profiling state"); + } +} + +void SetWaitingForAckProfilingState(ProfilingStateMachine& profilingStateMachine) +{ + ProfilingState currentState = profilingStateMachine.GetCurrentState(); + switch (currentState) + { + case ProfilingState::Uninitialised: + case ProfilingState::Active: + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + case ProfilingState::NotConnected: + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + case ProfilingState::WaitingForAck: + return; + default: + BOOST_CHECK_MESSAGE(false, "Invalid profiling state"); + } +} + +void SetActiveProfilingState(ProfilingStateMachine& profilingStateMachine) +{ + ProfilingState currentState = profilingStateMachine.GetCurrentState(); + switch (currentState) + { + case ProfilingState::Uninitialised: + profilingStateMachine.TransitionToState(ProfilingState::NotConnected); + case ProfilingState::NotConnected: + profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); + case ProfilingState::WaitingForAck: + profilingStateMachine.TransitionToState(ProfilingState::Active); + case ProfilingState::Active: + return; + default: + BOOST_CHECK_MESSAGE(false, "Invalid profiling state"); + } +} + +} // Anonymous namespace + BOOST_AUTO_TEST_SUITE(SendCounterPacketTests) BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) { MockBufferManager mockBuffer(512); - MockSendCounterPacket sendCounterPacket(mockBuffer); + MockSendCounterPacket mockSendCounterPacket(mockBuffer); - sendCounterPacket.SendStreamMetaDataPacket(); + mockSendCounterPacket.SendStreamMetaDataPacket(); auto packetBuffer = mockBuffer.GetReadableBuffer(); const char* buffer = reinterpret_cast(packetBuffer->GetReadableData()); @@ -38,7 +95,7 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) mockBuffer.MarkRead(packetBuffer); CounterDirectory counterDirectory; - sendCounterPacket.SendCounterDirectoryPacket(counterDirectory); + mockSendCounterPacket.SendCounterDirectoryPacket(counterDirectory); packetBuffer = mockBuffer.GetReadableBuffer(); buffer = reinterpret_cast(packetBuffer->GetReadableData()); @@ -50,7 +107,7 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) uint64_t timestamp = 0; std::vector> indexValuePairs; - sendCounterPacket.SendPeriodicCounterCapturePacket(timestamp, indexValuePairs); + mockSendCounterPacket.SendPeriodicCounterCapturePacket(timestamp, indexValuePairs); packetBuffer = mockBuffer.GetReadableBuffer(); buffer = reinterpret_cast(packetBuffer->GetReadableData()); @@ -61,7 +118,7 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) uint32_t capturePeriod = 0; std::vector selectedCounterIds; - sendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds); + mockSendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds); packetBuffer = mockBuffer.GetReadableBuffer(); buffer = reinterpret_cast(packetBuffer->GetReadableData()); @@ -73,9 +130,11 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest) BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) { + ProfilingStateMachine profilingStateMachine; + // Error no space left in buffer MockBufferManager mockBuffer1(10); - SendCounterPacket sendPacket1(mockBuffer1); + SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1); uint32_t capturePeriod = 1000; std::vector selectedCounterIds; @@ -84,7 +143,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) // Packet without any counters MockBufferManager mockBuffer2(512); - SendCounterPacket sendPacket2(mockBuffer2); + SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2); sendPacket2.SendPeriodicCounterSelectionPacket(capturePeriod, selectedCounterIds); auto readBuffer2 = mockBuffer2.GetReadableBuffer(); @@ -100,7 +159,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) // Full packet message MockBufferManager mockBuffer3(512); - SendCounterPacket sendPacket3(mockBuffer3); + SendCounterPacket sendPacket3(profilingStateMachine, mockBuffer3); selectedCounterIds.reserve(5); selectedCounterIds.emplace_back(100); @@ -134,9 +193,11 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterSelectionPacketTest) BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) { + ProfilingStateMachine profilingStateMachine; + // Error no space left in buffer MockBufferManager mockBuffer1(10); - SendCounterPacket sendPacket1(mockBuffer1); + SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1); auto captureTimestamp = std::chrono::steady_clock::now(); uint64_t time = static_cast(captureTimestamp.time_since_epoch().count()); @@ -147,7 +208,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) // Packet without any counters MockBufferManager mockBuffer2(512); - SendCounterPacket sendPacket2(mockBuffer2); + SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2); sendPacket2.SendPeriodicCounterCapturePacket(time, indexValuePairs); auto readBuffer2 = mockBuffer2.GetReadableBuffer(); @@ -164,7 +225,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest) // Full packet message MockBufferManager mockBuffer3(512); - SendCounterPacket sendPacket3(mockBuffer3); + SendCounterPacket sendPacket3(profilingStateMachine, mockBuffer3); indexValuePairs.reserve(5); indexValuePairs.emplace_back(std::make_pair(0, 100)); @@ -213,9 +274,11 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) uint32_t sizeUint32 = numeric_cast(sizeof(uint32_t)); + ProfilingStateMachine profilingStateMachine; + // Error no space left in buffer MockBufferManager mockBuffer1(10); - SendCounterPacket sendPacket1(mockBuffer1); + SendCounterPacket sendPacket1(profilingStateMachine, mockBuffer1); BOOST_CHECK_THROW(sendPacket1.SendStreamMetaDataPacket(), armnn::profiling::BufferExhaustion); // Full metadata packet @@ -234,7 +297,7 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) uint32_t packetEntries = 6; MockBufferManager mockBuffer2(512); - SendCounterPacket sendPacket2(mockBuffer2); + SendCounterPacket sendPacket2(profilingStateMachine, mockBuffer2); sendPacket2.SendStreamMetaDataPacket(); auto readBuffer2 = mockBuffer2.GetReadableBuffer(); @@ -328,8 +391,10 @@ BOOST_AUTO_TEST_CASE(SendStreamMetaDataPacketTest) BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a device for testing uint16_t deviceUid = 27; @@ -360,8 +425,10 @@ BOOST_AUTO_TEST_CASE(CreateDeviceRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a device for testing uint16_t deviceUid = 27; @@ -381,8 +448,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidDeviceRecordTest) BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a counter set for testing uint16_t counterSetUid = 27; @@ -413,8 +482,10 @@ BOOST_AUTO_TEST_CASE(CreateCounterSetRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a counter set for testing uint16_t counterSetUid = 27; @@ -434,8 +505,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCounterSetRecordTest) BOOST_AUTO_TEST_CASE(CreateEventRecordTest) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -554,8 +627,10 @@ BOOST_AUTO_TEST_CASE(CreateEventRecordTest) BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a counter for testing uint16_t counterUid = 44312; @@ -657,8 +732,10 @@ BOOST_AUTO_TEST_CASE(CreateEventRecordNoUnitsTest) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -695,8 +772,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest1) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -733,8 +812,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest2) BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a counter for testing uint16_t counterUid = 7256; @@ -771,8 +852,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidEventRecordTest3) BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a category for testing const std::string categoryName = "some_category"; @@ -972,8 +1055,10 @@ BOOST_AUTO_TEST_CASE(CreateCategoryRecordTest) BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a category for testing const std::string categoryName = "some invalid category"; @@ -995,8 +1080,10 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest1) BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest2) { + ProfilingStateMachine profilingStateMachine; + MockBufferManager mockBuffer(0); - SendCounterPacketTest sendCounterPacketTest(mockBuffer); + SendCounterPacketTest sendCounterPacketTest(profilingStateMachine, mockBuffer); // Create a category for testing const std::string categoryName = "some_category"; @@ -1035,6 +1122,8 @@ BOOST_AUTO_TEST_CASE(CreateInvalidCategoryRecordTest2) BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest1) { + ProfilingStateMachine profilingStateMachine; + // The counter directory used for testing CounterDirectory counterDirectory; @@ -1054,13 +1143,15 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest1) // Buffer with not enough space MockBufferManager mockBuffer(10); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::profiling::BufferExhaustion); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) { + ProfilingStateMachine profilingStateMachine; + // The counter directory used for testing CounterDirectory counterDirectory; @@ -1146,7 +1237,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); BOOST_CHECK_NO_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory)); // Get the readable buffer @@ -1535,6 +1626,8 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2) BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest3) { + ProfilingStateMachine profilingStateMachine; + // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1547,12 +1640,14 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest3) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest4) { + ProfilingStateMachine profilingStateMachine; + // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1565,12 +1660,14 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest4) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest5) { + ProfilingStateMachine profilingStateMachine; + // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1583,12 +1680,14 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest5) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest6) { + ProfilingStateMachine profilingStateMachine; + // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1617,12 +1716,14 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest6) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest7) { + ProfilingStateMachine profilingStateMachine; + // Using a mock counter directory that allows to register invalid objects MockCounterDirectory counterDirectory; @@ -1666,15 +1767,18 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest7) // Buffer with enough space MockBufferManager mockBuffer(1024); - SendCounterPacket sendCounterPacket(mockBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer); BOOST_CHECK_THROW(sendCounterPacket.SendCounterDirectoryPacket(counterDirectory), armnn::RuntimeException); } BOOST_AUTO_TEST_CASE(SendThreadTest0) { + ProfilingStateMachine profilingStateMachine; + SetActiveProfilingState(profilingStateMachine); + MockProfilingConnection mockProfilingConnection; MockStreamCounterBuffer mockStreamCounterBuffer(0); - SendCounterPacket sendCounterPacket(mockStreamCounterBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer); // Try to start the send thread many times, it must only start once @@ -1694,11 +1798,14 @@ BOOST_AUTO_TEST_CASE(SendThreadTest0) BOOST_AUTO_TEST_CASE(SendThreadTest1) { + ProfilingStateMachine profilingStateMachine; + SetActiveProfilingState(profilingStateMachine); + unsigned int totalWrittenSize = 0; MockProfilingConnection mockProfilingConnection; MockStreamCounterBuffer mockStreamCounterBuffer(1024); - SendCounterPacket sendCounterPacket(mockStreamCounterBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer); sendCounterPacket.Start(mockProfilingConnection); // Interleaving writes and reads to/from the buffer with pauses to test that the send thread actually waits for @@ -1802,11 +1909,14 @@ BOOST_AUTO_TEST_CASE(SendThreadTest1) BOOST_AUTO_TEST_CASE(SendThreadTest2) { + ProfilingStateMachine profilingStateMachine; + SetActiveProfilingState(profilingStateMachine); + unsigned int totalWrittenSize = 0; MockProfilingConnection mockProfilingConnection; MockStreamCounterBuffer mockStreamCounterBuffer(1024); - SendCounterPacket sendCounterPacket(mockStreamCounterBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer); sendCounterPacket.Start(mockProfilingConnection); // Adding many spurious "ready to read" signals throughout the test to check that the send thread is @@ -1922,11 +2032,14 @@ BOOST_AUTO_TEST_CASE(SendThreadTest2) BOOST_AUTO_TEST_CASE(SendThreadTest3) { + ProfilingStateMachine profilingStateMachine; + SetActiveProfilingState(profilingStateMachine); + unsigned int totalWrittenSize = 0; MockProfilingConnection mockProfilingConnection; MockStreamCounterBuffer mockStreamCounterBuffer(1024); - SendCounterPacket sendCounterPacket(mockStreamCounterBuffer); + SendCounterPacket sendCounterPacket(profilingStateMachine, mockStreamCounterBuffer); sendCounterPacket.Start(mockProfilingConnection); // Not using pauses or "grace periods" to stress test the send thread @@ -2025,9 +2138,12 @@ BOOST_AUTO_TEST_CASE(SendThreadTest3) BOOST_AUTO_TEST_CASE(SendThreadBufferTest) { + ProfilingStateMachine profilingStateMachine; + SetActiveProfilingState(profilingStateMachine); + MockProfilingConnection mockProfilingConnection; BufferManager bufferManager(1, 1024); - SendCounterPacket sendCounterPacket(bufferManager, -1); + SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager, -1); sendCounterPacket.Start(mockProfilingConnection); // Interleaving writes and reads to/from the buffer with pauses to test that the send thread actually waits for @@ -2152,9 +2268,12 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest) BOOST_AUTO_TEST_CASE(SendThreadBufferTest1) { - MockWriteProfilingConnection mockProfilingConnection; + ProfilingStateMachine profilingStateMachine; + SetActiveProfilingState(profilingStateMachine); + + MockProfilingConnection mockProfilingConnection; BufferManager bufferManager(3, 1024); - SendCounterPacket sendCounterPacket(bufferManager, -1); + SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager, -1); sendCounterPacket.Start(mockProfilingConnection); // SendStreamMetaDataPacket @@ -2203,8 +2322,7 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1) BOOST_TEST(reservedBuffer.get()); // Check that data was actually written to the profiling connection in any order - std::vector writtenData = mockProfilingConnection.GetWrittenData(); - std::vector expectedOutput{streamMetadataPacketsize, 32, 28}; + const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); BOOST_TEST(writtenData.size() == 3); bool foundStreamMetaDataPacket = std::find(writtenData.begin(), writtenData.end(), streamMetadataPacketsize) != writtenData.end(); @@ -2215,4 +2333,113 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1) BOOST_TEST(foundPeriodicCounterCapturePacket); } +BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket1) +{ + ProfilingStateMachine profilingStateMachine; + + MockProfilingConnection mockProfilingConnection; + BufferManager bufferManager(3, 1024); + SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager); + sendCounterPacket.Start(mockProfilingConnection); + + // The profiling state is set to "Uninitialized", so the send thread should throw an exception + + // Wait a bit to make sure that the send thread is properly started + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + BOOST_CHECK_THROW(sendCounterPacket.Stop(), armnn::RuntimeException); +} + +BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket2) +{ + ProfilingStateMachine profilingStateMachine; + SetNotConnectedProfilingState(profilingStateMachine); + + MockProfilingConnection mockProfilingConnection; + BufferManager bufferManager(3, 1024); + SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager); + sendCounterPacket.Start(mockProfilingConnection); + + // The profiling state is set to "NotConnected", so the send thread should throw an exception + + // Wait a bit to make sure that the send thread is properly started + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + BOOST_CHECK_THROW(sendCounterPacket.Stop(), armnn::RuntimeException); +} + +BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket3) +{ + ProfilingStateMachine profilingStateMachine; + SetWaitingForAckProfilingState(profilingStateMachine); + + // Calculate the size of a Stream Metadata packet + std::string processName = GetProcessName().substr(0, 60); + unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast(processName.size()) + 1; + unsigned int streamMetadataPacketsize = 118 + processNameSize; + + MockProfilingConnection mockProfilingConnection; + BufferManager bufferManager(3, 1024); + SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager); + sendCounterPacket.Start(mockProfilingConnection); + + // The profiling state is set to "WaitingForAck", so the send thread should send a Stream Metadata packet + + // Wait for a bit to make sure that we get the packet + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + BOOST_CHECK_NO_THROW(sendCounterPacket.Stop()); + + // Check that the buffer contains one Stream Metadata packet + const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == streamMetadataPacketsize); +} + +BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket4) +{ + ProfilingStateMachine profilingStateMachine; + SetWaitingForAckProfilingState(profilingStateMachine); + + // Calculate the size of a Stream Metadata packet + std::string processName = GetProcessName().substr(0, 60); + unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast(processName.size()) + 1; + unsigned int streamMetadataPacketsize = 118 + processNameSize; + + MockProfilingConnection mockProfilingConnection; + BufferManager bufferManager(3, 1024); + SendCounterPacket sendCounterPacket(profilingStateMachine, bufferManager); + sendCounterPacket.Start(mockProfilingConnection); + + // The profiling state is set to "WaitingForAck", so the send thread should send a Stream Metadata packet + + // Wait for a bit to make sure that we get the packet + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Check that the profiling state is still "WaitingForAck" + BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck)); + + // Check that the buffer contains one Stream Metadata packet + const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == streamMetadataPacketsize); + + mockProfilingConnection.Clear(); + + // Try triggering a new buffer read + sendCounterPacket.SetReadyToRead(); + + // Wait for a bit to make sure that we get the packet + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Check that the profiling state is still "WaitingForAck" + BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck)); + + // Check that the buffer contains one Stream Metadata packet + BOOST_TEST(writtenData.size() == 1); + BOOST_TEST(writtenData[0] == streamMetadataPacketsize); + + BOOST_CHECK_NO_THROW(sendCounterPacket.Stop()); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index 0323f62..cae02b0 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -19,7 +19,6 @@ namespace armnn namespace profiling { - class MockProfilingConnection : public IProfilingConnection { public: @@ -33,38 +32,20 @@ public: bool WritePacket(const unsigned char* buffer, uint32_t length) override { - return buffer != nullptr && length > 0; - } - - Packet ReadPacket(uint32_t timeout) override { return Packet(); } - -private: - bool m_IsOpen; -}; - -class MockWriteProfilingConnection : public IProfilingConnection -{ -public: - MockWriteProfilingConnection() - : m_IsOpen(true) - {} - - bool IsOpen() override { return m_IsOpen; } - - void Close() override { m_IsOpen = false; } + if (buffer == nullptr || length == 0) + { + return false; + } - bool WritePacket(const unsigned char* buffer, uint32_t length) override - { m_WrittenData.push_back(length); - return buffer != nullptr && length > 0; + return true; } Packet ReadPacket(uint32_t timeout) override { return Packet(); } - std::vector GetWrittenData() - { - return m_WrittenData; - } + const std::vector& GetWrittenData() const { return m_WrittenData; } + + void Clear() { m_WrittenData.clear(); } private: bool m_IsOpen; @@ -497,8 +478,8 @@ private: class SendCounterPacketTest : public SendCounterPacket { public: - SendCounterPacketTest(IBufferManager& buffer) - : SendCounterPacket(buffer) + SendCounterPacketTest(ProfilingStateMachine& profilingStateMachine, IBufferManager& buffer) + : SendCounterPacket(profilingStateMachine, buffer) {} bool CreateDeviceRecordTest(const DevicePtr& device,