IVGCVSW-3826: Implement IProfiling functions
authorFinnWilliamsArm <Finn.Williams@arm.com>
Mon, 16 Sep 2019 11:06:47 +0000 (12:06 +0100)
committerfinn.williams <finn.williams@arm.com>
Mon, 16 Sep 2019 14:17:03 +0000 (14:17 +0000)
!armnn:1814

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Signed-off-by: FinnWilliamsArm <Finn.Williams@arm.com>
Change-Id: I82c7453d7969880e321572637adc0fb9c0e5fd7b

src/profiling/Packet.cpp
src/profiling/Packet.hpp
src/profiling/SocketProfilingConnection.cpp
src/profiling/test/ProfilingTests.cpp

index 44d5ac1..4cfa42b 100644 (file)
@@ -31,9 +31,9 @@ std::uint32_t Packet::GetLength() const
     return m_Length;
 }
 
-const char* Packet::GetData() const
+const char* const Packet::GetData() const
 {
-    return m_Data;
+    return m_Data.get();
 }
 
 std::uint32_t Packet::GetPacketClass() const
index c5e7f3c..1e047a6 100644 (file)
@@ -17,10 +17,8 @@ namespace profiling
 class Packet
 {
 public:
-    Packet(uint32_t header, uint32_t length, const char* data)
-        : m_Header(header),
-          m_Length(length),
-          m_Data(data)
+    Packet(uint32_t header, uint32_t length, std::unique_ptr<char[]>& data)
+    : m_Header(header), m_Length(length), m_Data(std::move(data))
     {
         m_PacketId = ((header >> 16) & 1023);
         m_PacketFamily = (header >> 26);
@@ -31,11 +29,21 @@ public:
         }
     }
 
+    Packet(Packet&& other) :
+           m_Header(other.m_Header),
+           m_PacketFamily(other.m_PacketFamily),
+           m_PacketId(other.m_PacketId),
+           m_Length(other.m_Length),
+           m_Data(std::move(other.m_Data)){};
+
+    Packet(const Packet& other) = delete;
+    Packet& operator=(const Packet&) = delete;
+
     uint32_t GetHeader() const;
     uint32_t GetPacketFamily() const;
     uint32_t GetPacketId() const;
     uint32_t GetLength() const;
-    const char* GetData() const;
+    const char* const GetData() const;
 
     uint32_t GetPacketClass() const;
     uint32_t GetPacketType() const;
@@ -45,7 +53,7 @@ private:
     uint32_t m_PacketFamily;
     uint32_t m_PacketId;
     uint32_t m_Length;
-    const char* m_Data;
+    std::unique_ptr<char[]> m_Data;
 };
 
 } // namespace profiling
index 21a7a1d..188ca23 100644 (file)
@@ -52,25 +52,91 @@ SocketProfilingConnection::SocketProfilingConnection()
 
 bool SocketProfilingConnection::IsOpen()
 {
-    // Dummy return value, function not implemented
-    return true;
+    if (m_Socket[0].fd > 0)
+    {
+        return true;
+    }
+    return false;
 }
 
 void SocketProfilingConnection::Close()
 {
-    // Function not implemented
+    if (0 == close(m_Socket[0].fd))
+    {
+        memset(m_Socket, 0, sizeof(m_Socket));
+    }
+    else
+    {
+        throw armnn::Exception(std::string(": Cannot close stream socket: ") + strerror(errno));
+    }
 }
 
 bool SocketProfilingConnection::WritePacket(const char* buffer, uint32_t length)
 {
-    // Dummy return value, function not implemented
+    if (-1 == write(m_Socket[0].fd, buffer, length))
+    {
+        return false;
+    }
     return true;
 }
 
 Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
 {
-    // Dummy return value, function not implemented
-    return {472580096, 0, nullptr};
+    // Poll for data on the socket or until timeout.
+    int pollResult = poll(m_Socket, 1, static_cast<int>(timeout));
+    if (pollResult > 0)
+    {
+        // Normal poll return but it could still contain an error signal.
+        if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
+        {
+            throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno));
+        }
+        else if (m_Socket[0].revents & (POLLIN)) // There is data to read.
+        {
+            // Read the header first.
+            char header[8];
+            if (8 != recv(m_Socket[0].fd, &header, sizeof header, 0))
+            {
+                // What do we do here if there's not a valid 8 byte header to read?
+                throw armnn::Exception(": Received packet did not contains a valid MIPE header. ");
+            }
+            // stream_metadata_identifier is the first 4 bytes.
+            uint32_t metadataIdentifier = static_cast<uint32_t>(header[0]) << 24 |
+                                          static_cast<uint32_t>(header[1]) << 16 |
+                                          static_cast<uint32_t>(header[2]) << 8  |
+                                          static_cast<uint32_t>(header[3]);
+            // data_length is the next 4 bytes.
+            uint32_t dataLength = static_cast<uint32_t>(header[4]) << 24 |
+                                  static_cast<uint32_t>(header[5]) << 16 |
+                                  static_cast<uint32_t>(header[6]) << 8  |
+                                  static_cast<uint32_t>(header[7]);
+
+            std::unique_ptr<char[]> packetData;
+            if (dataLength > 0)
+            {
+                packetData = std::make_unique<char[]>(dataLength);
+            }
+
+            if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0))
+            {
+                // What do we do here if we can't read in a full packet?
+                throw armnn::Exception(": Invalid MIPE packet.");
+            }
+            return {metadataIdentifier, dataLength, packetData};
+        }
+        else // Some unknown return signal.
+        {
+            throw armnn::Exception(": Poll returned an unexpected event." );
+        }
+    }
+    else if (pollResult == -1)
+    {
+        throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno));
+    }
+    else // it's 0 so a timeout.
+    {
+        throw armnn::Exception(": Timeout while reading from socket.");
+    }
 }
 
 } // namespace profiling
index 4913dde..55524a4 100644 (file)
@@ -114,21 +114,39 @@ BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
 
 BOOST_AUTO_TEST_CASE(CheckPacketClass)
 {
-    const char* data = "test";
-    unsigned int length = static_cast<unsigned int>(std::strlen(data));
-
-    Packet packetTest1(472580096,length,data);
-    BOOST_CHECK_THROW(Packet packetTest2(472580096,0,""), armnn::Exception);
-
-    Packet packetTest3(472580096,0, nullptr);
-
-    BOOST_CHECK(packetTest1.GetLength() == length);
-    BOOST_CHECK(packetTest1.GetData() == data);
-
-    BOOST_CHECK(packetTest1.GetPacketFamily() == 7);
-    BOOST_CHECK(packetTest1.GetPacketId() == 43);
-    BOOST_CHECK(packetTest1.GetPacketType() == 3);
-    BOOST_CHECK(packetTest1.GetPacketClass() == 5);
+    uint32_t length = 4;
+    std::unique_ptr<char[]> packetData0 = std::make_unique<char[]>(length);
+    std::unique_ptr<char[]> packetData1 = std::make_unique<char[]>(0);
+    std::unique_ptr<char[]> nullPacketData;
+
+    Packet packetTest0(472580096, length, packetData0);
+
+    BOOST_CHECK(packetTest0.GetHeader() == 472580096);
+    BOOST_CHECK(packetTest0.GetPacketFamily() == 7);
+    BOOST_CHECK(packetTest0.GetPacketId() == 43);
+    BOOST_CHECK(packetTest0.GetLength() == length);
+    BOOST_CHECK(packetTest0.GetPacketType() == 3);
+    BOOST_CHECK(packetTest0.GetPacketClass() == 5);
+
+    BOOST_CHECK_THROW(Packet packetTest1(472580096, 0, packetData1), armnn::Exception);
+    BOOST_CHECK_NO_THROW(Packet packetTest2(472580096, 0, nullPacketData));
+
+    Packet packetTest3(472580096, 0, nullPacketData);
+    BOOST_CHECK(packetTest3.GetLength() == 0);
+    BOOST_CHECK(packetTest3.GetData() == nullptr);
+
+    const char* packetTest0Data = packetTest0.GetData();
+    Packet packetTest4(std::move(packetTest0));
+
+    BOOST_CHECK(packetTest0.GetData() == nullptr);
+    BOOST_CHECK(packetTest4.GetData() == packetTest0Data);
+
+    BOOST_CHECK(packetTest4.GetHeader() == 472580096);
+    BOOST_CHECK(packetTest4.GetPacketFamily() == 7);
+    BOOST_CHECK(packetTest4.GetPacketId() == 43);
+    BOOST_CHECK(packetTest4.GetLength() == length);
+    BOOST_CHECK(packetTest4.GetPacketType() == 3);
+    BOOST_CHECK(packetTest4.GetPacketClass() == 5);
 }
 
 // Create Derived Classes
@@ -186,9 +204,13 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
     it++;
     BOOST_CHECK(it->first==keyC);
 
-    Packet packetA(500000000, 0, nullptr);
-    Packet packetB(600000000, 0, nullptr);
-    Packet packetC(400000000, 0, nullptr);
+    std::unique_ptr<char[]> packetDataA;
+    std::unique_ptr<char[]> packetDataB;
+    std::unique_ptr<char[]> packetDataC;
+
+    Packet packetA(500000000, 0, packetDataA);
+    Packet packetB(600000000, 0, packetDataB);
+    Packet packetC(400000000, 0, packetDataC);
 
     // Check the correct operator of derived class is called
     registry.at(CommandHandlerKey(packetA.GetPacketId(), version))->operator()(packetA);
@@ -224,9 +246,13 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
     registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
     registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
 
-    Packet packetA(500000000, 0, nullptr);
-    Packet packetB(600000000, 0, nullptr);
-    Packet packetC(400000000, 0, nullptr);
+    std::unique_ptr<char[]> packetDataA;
+    std::unique_ptr<char[]> packetDataB;
+    std::unique_ptr<char[]> packetDataC;
+
+    Packet packetA(500000000, 0, packetDataA);
+    Packet packetB(600000000, 0, packetDataB);
+    Packet packetC(400000000, 0, packetDataC);
 
     // Check the correct operator of derived class is called
     registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA);
@@ -561,16 +587,18 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
     // Data with period and counters
     uint32_t period1 = 10;
     uint32_t dataLength1 = 8;
-    unsigned char data1[dataLength1];
     uint32_t offset = 0;
 
+    std::unique_ptr<char[]> uniqueData1 = std::make_unique<char[]>(dataLength1);
+    unsigned char* data1 = reinterpret_cast<unsigned char*>(uniqueData1.get());
+
     WriteUint32(data1, offset, period1);
     offset += sizeOfUint32;
     WriteUint16(data1, offset, 4000);
     offset += sizeOfUint16;
     WriteUint16(data1, offset, 5000);
 
-    Packet packetA(packetId, dataLength1, reinterpret_cast<const char*>(data1));
+    Packet packetA(packetId, dataLength1, uniqueData1);
 
     PeriodicCounterSelectionCommandHandler commandHandler(packetId, version, holder, captureThread,
                                                           sendCounterPacket);
@@ -611,11 +639,12 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
     // Data with period only
     uint32_t period2 = 11;
     uint32_t dataLength2 = 4;
-    unsigned char data2[dataLength2];
 
-    WriteUint32(data2, 0, period2);
+    std::unique_ptr<char[]> uniqueData2 = std::make_unique<char[]>(dataLength2);
+
+    WriteUint32(reinterpret_cast<unsigned char*>(uniqueData2.get()), 0, period2);
 
-    Packet packetB(packetId, dataLength2, reinterpret_cast<const char*>(data2));
+    Packet packetB(packetId, dataLength2, uniqueData2);
 
     commandHandler(packetB);