class Packet
{
public:
- Packet(uint32_t header, uint32_t length, const char* data)
- : m_Header(header),
- m_Length(length),
- m_Data(data)
+ Packet(uint32_t header, uint32_t length, std::unique_ptr<char[]>& data)
+ : m_Header(header), m_Length(length), m_Data(std::move(data))
{
m_PacketId = ((header >> 16) & 1023);
m_PacketFamily = (header >> 26);
}
}
+ Packet(Packet&& other) :
+ m_Header(other.m_Header),
+ m_PacketFamily(other.m_PacketFamily),
+ m_PacketId(other.m_PacketId),
+ m_Length(other.m_Length),
+ m_Data(std::move(other.m_Data)){};
+
+ Packet(const Packet& other) = delete;
+ Packet& operator=(const Packet&) = delete;
+
uint32_t GetHeader() const;
uint32_t GetPacketFamily() const;
uint32_t GetPacketId() const;
uint32_t GetLength() const;
- const char* GetData() const;
+ const char* const GetData() const;
uint32_t GetPacketClass() const;
uint32_t GetPacketType() const;
uint32_t m_PacketFamily;
uint32_t m_PacketId;
uint32_t m_Length;
- const char* m_Data;
+ std::unique_ptr<char[]> m_Data;
};
} // namespace profiling
bool SocketProfilingConnection::IsOpen()
{
- // Dummy return value, function not implemented
- return true;
+ if (m_Socket[0].fd > 0)
+ {
+ return true;
+ }
+ return false;
}
void SocketProfilingConnection::Close()
{
- // Function not implemented
+ if (0 == close(m_Socket[0].fd))
+ {
+ memset(m_Socket, 0, sizeof(m_Socket));
+ }
+ else
+ {
+ throw armnn::Exception(std::string(": Cannot close stream socket: ") + strerror(errno));
+ }
}
bool SocketProfilingConnection::WritePacket(const char* buffer, uint32_t length)
{
- // Dummy return value, function not implemented
+ if (-1 == write(m_Socket[0].fd, buffer, length))
+ {
+ return false;
+ }
return true;
}
Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
{
- // Dummy return value, function not implemented
- return {472580096, 0, nullptr};
+ // Poll for data on the socket or until timeout.
+ int pollResult = poll(m_Socket, 1, static_cast<int>(timeout));
+ if (pollResult > 0)
+ {
+ // Normal poll return but it could still contain an error signal.
+ if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
+ {
+ throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno));
+ }
+ else if (m_Socket[0].revents & (POLLIN)) // There is data to read.
+ {
+ // Read the header first.
+ char header[8];
+ if (8 != recv(m_Socket[0].fd, &header, sizeof header, 0))
+ {
+ // What do we do here if there's not a valid 8 byte header to read?
+ throw armnn::Exception(": Received packet did not contains a valid MIPE header. ");
+ }
+ // stream_metadata_identifier is the first 4 bytes.
+ uint32_t metadataIdentifier = static_cast<uint32_t>(header[0]) << 24 |
+ static_cast<uint32_t>(header[1]) << 16 |
+ static_cast<uint32_t>(header[2]) << 8 |
+ static_cast<uint32_t>(header[3]);
+ // data_length is the next 4 bytes.
+ uint32_t dataLength = static_cast<uint32_t>(header[4]) << 24 |
+ static_cast<uint32_t>(header[5]) << 16 |
+ static_cast<uint32_t>(header[6]) << 8 |
+ static_cast<uint32_t>(header[7]);
+
+ std::unique_ptr<char[]> packetData;
+ if (dataLength > 0)
+ {
+ packetData = std::make_unique<char[]>(dataLength);
+ }
+
+ if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0))
+ {
+ // What do we do here if we can't read in a full packet?
+ throw armnn::Exception(": Invalid MIPE packet.");
+ }
+ return {metadataIdentifier, dataLength, packetData};
+ }
+ else // Some unknown return signal.
+ {
+ throw armnn::Exception(": Poll returned an unexpected event." );
+ }
+ }
+ else if (pollResult == -1)
+ {
+ throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno));
+ }
+ else // it's 0 so a timeout.
+ {
+ throw armnn::Exception(": Timeout while reading from socket.");
+ }
}
} // namespace profiling
BOOST_AUTO_TEST_CASE(CheckPacketClass)
{
- const char* data = "test";
- unsigned int length = static_cast<unsigned int>(std::strlen(data));
-
- Packet packetTest1(472580096,length,data);
- BOOST_CHECK_THROW(Packet packetTest2(472580096,0,""), armnn::Exception);
-
- Packet packetTest3(472580096,0, nullptr);
-
- BOOST_CHECK(packetTest1.GetLength() == length);
- BOOST_CHECK(packetTest1.GetData() == data);
-
- BOOST_CHECK(packetTest1.GetPacketFamily() == 7);
- BOOST_CHECK(packetTest1.GetPacketId() == 43);
- BOOST_CHECK(packetTest1.GetPacketType() == 3);
- BOOST_CHECK(packetTest1.GetPacketClass() == 5);
+ uint32_t length = 4;
+ std::unique_ptr<char[]> packetData0 = std::make_unique<char[]>(length);
+ std::unique_ptr<char[]> packetData1 = std::make_unique<char[]>(0);
+ std::unique_ptr<char[]> nullPacketData;
+
+ Packet packetTest0(472580096, length, packetData0);
+
+ BOOST_CHECK(packetTest0.GetHeader() == 472580096);
+ BOOST_CHECK(packetTest0.GetPacketFamily() == 7);
+ BOOST_CHECK(packetTest0.GetPacketId() == 43);
+ BOOST_CHECK(packetTest0.GetLength() == length);
+ BOOST_CHECK(packetTest0.GetPacketType() == 3);
+ BOOST_CHECK(packetTest0.GetPacketClass() == 5);
+
+ BOOST_CHECK_THROW(Packet packetTest1(472580096, 0, packetData1), armnn::Exception);
+ BOOST_CHECK_NO_THROW(Packet packetTest2(472580096, 0, nullPacketData));
+
+ Packet packetTest3(472580096, 0, nullPacketData);
+ BOOST_CHECK(packetTest3.GetLength() == 0);
+ BOOST_CHECK(packetTest3.GetData() == nullptr);
+
+ const char* packetTest0Data = packetTest0.GetData();
+ Packet packetTest4(std::move(packetTest0));
+
+ BOOST_CHECK(packetTest0.GetData() == nullptr);
+ BOOST_CHECK(packetTest4.GetData() == packetTest0Data);
+
+ BOOST_CHECK(packetTest4.GetHeader() == 472580096);
+ BOOST_CHECK(packetTest4.GetPacketFamily() == 7);
+ BOOST_CHECK(packetTest4.GetPacketId() == 43);
+ BOOST_CHECK(packetTest4.GetLength() == length);
+ BOOST_CHECK(packetTest4.GetPacketType() == 3);
+ BOOST_CHECK(packetTest4.GetPacketClass() == 5);
}
// Create Derived Classes
it++;
BOOST_CHECK(it->first==keyC);
- Packet packetA(500000000, 0, nullptr);
- Packet packetB(600000000, 0, nullptr);
- Packet packetC(400000000, 0, nullptr);
+ std::unique_ptr<char[]> packetDataA;
+ std::unique_ptr<char[]> packetDataB;
+ std::unique_ptr<char[]> packetDataC;
+
+ Packet packetA(500000000, 0, packetDataA);
+ Packet packetB(600000000, 0, packetDataB);
+ Packet packetC(400000000, 0, packetDataC);
// Check the correct operator of derived class is called
registry.at(CommandHandlerKey(packetA.GetPacketId(), version))->operator()(packetA);
registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
- Packet packetA(500000000, 0, nullptr);
- Packet packetB(600000000, 0, nullptr);
- Packet packetC(400000000, 0, nullptr);
+ std::unique_ptr<char[]> packetDataA;
+ std::unique_ptr<char[]> packetDataB;
+ std::unique_ptr<char[]> packetDataC;
+
+ Packet packetA(500000000, 0, packetDataA);
+ Packet packetB(600000000, 0, packetDataB);
+ Packet packetC(400000000, 0, packetDataC);
// Check the correct operator of derived class is called
registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA);
// Data with period and counters
uint32_t period1 = 10;
uint32_t dataLength1 = 8;
- unsigned char data1[dataLength1];
uint32_t offset = 0;
+ std::unique_ptr<char[]> uniqueData1 = std::make_unique<char[]>(dataLength1);
+ unsigned char* data1 = reinterpret_cast<unsigned char*>(uniqueData1.get());
+
WriteUint32(data1, offset, period1);
offset += sizeOfUint32;
WriteUint16(data1, offset, 4000);
offset += sizeOfUint16;
WriteUint16(data1, offset, 5000);
- Packet packetA(packetId, dataLength1, reinterpret_cast<const char*>(data1));
+ Packet packetA(packetId, dataLength1, uniqueData1);
PeriodicCounterSelectionCommandHandler commandHandler(packetId, version, holder, captureThread,
sendCounterPacket);
// Data with period only
uint32_t period2 = 11;
uint32_t dataLength2 = 4;
- unsigned char data2[dataLength2];
- WriteUint32(data2, 0, period2);
+ std::unique_ptr<char[]> uniqueData2 = std::make_unique<char[]>(dataLength2);
+
+ WriteUint32(reinterpret_cast<unsigned char*>(uniqueData2.get()), 0, period2);
- Packet packetB(packetId, dataLength2, reinterpret_cast<const char*>(data2));
+ Packet packetB(packetId, dataLength2, uniqueData2);
commandHandler(packetB);