2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
8 #include <SendCounterPacket.hpp>
9 #include <ProfilingUtils.hpp>
11 #include <armnn/Exceptions.hpp>
12 #include <armnn/Optional.hpp>
13 #include <armnn/Conversion.hpp>
15 #include <boost/numeric/conversion/cast.hpp>
// Test double for IProfilingConnection: tracks an open/closed flag and accepts writes without
// transmitting anything.
23 class MockProfilingConnection : public IProfilingConnection
// NOTE(review): the constructor's initializer for m_IsOpen is not visible in this listing.
26 MockProfilingConnection()
// Returns the current open state; Close() clears it.
30 bool IsOpen() override { return m_IsOpen; }
32 void Close() override { m_IsOpen = false; }
// Simulates a send: succeeds for any non-null buffer with a non-zero length. The data is neither
// stored nor inspected, and the result does not depend on whether the connection is open.
34 bool WritePacket(const unsigned char* buffer, uint32_t length) override
36 return buffer != nullptr && length > 0;
// Always returns a default-constructed Packet; the timeout argument is not used.
39 Packet ReadPacket(uint32_t timeout) override { return Packet(); }
// Fixed-capacity IPacketBuffer backed by a heap-allocated byte array. The committed size (m_Size)
// is tracked separately from the capacity (m_MaxSize).
45 class MockPacketBuffer : public IPacketBuffer
// maxSize is the capacity, in bytes, of the backing array.
48 MockPacketBuffer(unsigned int maxSize)
// Reads m_MaxSize rather than maxSize. This works because m_MaxSize is declared (and therefore
// initialized) before m_Data; presumably it is set from maxSize on a line not visible here.
51 , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
54 ~MockPacketBuffer() {}
// Read-only pointer to the start of the backing array.
56 const unsigned char* const GetReadableData() const override { return m_Data.get(); }
// Number of bytes most recently committed; MarkRead() and Release() reset it to 0.
58 unsigned int GetSize() const override { return m_Size; }
60 void MarkRead() override { m_Size = 0; }
// Records the committed byte count; the value is not validated against m_MaxSize.
62 void Commit(unsigned int size) override { m_Size = size; }
64 void Release() override { m_Size = 0; }
// Writable pointer to the backing array; writers must stay within m_MaxSize bytes.
66 unsigned char* GetWritableData() override { return m_Data.get(); }
// Capacity of m_Data in bytes. Must remain declared before m_Data (the constructor depends on it).
69 unsigned int m_MaxSize;
// NOTE(review): the declaration of m_Size is not visible in this listing.
71 std::unique_ptr<unsigned char[]> m_Data;
// IBufferManager that owns a single MockPacketBuffer and lends it out by transferring ownership.
// The buffer comes back through Commit(), Release() or MarkRead().
// NOTE(review): none of the methods check packetBuffer / m_Buffer for null before use.
74 class MockBufferManager : public IBufferManager
// size is the capacity of the single underlying buffer.
77 MockBufferManager(unsigned int size)
// NOTE(review): m_BufferSize's initializer is not visible here; presumably it is set from size.
79 m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
81 ~MockBufferManager() {}
// Hands out the single buffer. reservedSize is set to m_BufferSize when requestedSize exceeds it,
// and to requestedSize otherwise (the else-branch braces are not visible in this listing).
// After the move m_Buffer is empty, so a second Reserve() before the buffer is handed back returns
// an empty pointer.
83 std::unique_ptr<IPacketBuffer> Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
85 if (requestedSize > m_BufferSize)
87 reservedSize = m_BufferSize;
91 reservedSize = requestedSize;
94 return std::move(m_Buffer);
// Commits size bytes on the buffer and takes ownership back, leaving packetBuffer empty.
97 void Commit(std::unique_ptr<IPacketBuffer>& packetBuffer, unsigned int size) override
99 packetBuffer->Commit(size);
100 m_Buffer = std::move(packetBuffer);
// Lends out the single buffer again, committed or not. m_Buffer stays empty until it is returned.
103 std::unique_ptr<IPacketBuffer> GetReadableBuffer() override
105 return std::move(m_Buffer);
// Releases the buffer (resetting its size) and takes ownership back.
108 void Release(std::unique_ptr<IPacketBuffer>& packetBuffer) override
110 packetBuffer->Release();
111 m_Buffer = std::move(packetBuffer);
// Marks the buffer as read (resetting its size) and takes ownership back.
114 void MarkRead(std::unique_ptr<IPacketBuffer>& packetBuffer) override
116 packetBuffer->MarkRead();
117 m_Buffer = std::move(packetBuffer);
// Upper bound reported through reservedSize in Reserve().
121 unsigned int m_BufferSize;
// The single buffer; empty while lent out.
122 std::unique_ptr<IPacketBuffer> m_Buffer;
// IBufferManager that does the following:
// - allocates a new MockPacketBuffer for every Reserve();
// - queues committed buffers in m_BufferList;
// - keeps running byte totals (committed / handed out for reading / read) that tests can query.
// The buffer-handling methods lock m_Mutex.
125 class MockStreamCounterBuffer : public IBufferManager
128 using IPacketBufferPtr = std::unique_ptr<IPacketBuffer>;
// maxBufferSize is the largest size a single Reserve() may request (default 4096).
130 MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
131 : m_MaxBufferSize(maxBufferSize)
// NOTE(review): the initializers of the three running totals are not visible in this listing;
// confirm they start at 0.
137 ~MockStreamCounterBuffer() {}
// Allocates a fresh buffer of exactly requestedSize bytes and reports that size in reservedSize.
// Throws armnn::InvalidArgumentException if requestedSize exceeds m_MaxBufferSize.
139 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
141 std::unique_lock<std::mutex> lock(m_Mutex);
144 if (requestedSize > m_MaxBufferSize)
146 throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
147 std::to_string(m_MaxBufferSize) + "] bytes");
149 reservedSize = requestedSize;
150 return std::make_unique<MockPacketBuffer>(requestedSize);
// Commits size bytes, moves the buffer onto the back of m_BufferList (leaving packetBuffer empty)
// and adds size to the committed total.
153 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override
155 std::unique_lock<std::mutex> lock(m_Mutex);
157 packetBuffer->Commit(size);
158 m_BufferList.push_back(std::move(packetBuffer));
159 m_CommittedSize += size;
// Calls Release() on the buffer (for a MockPacketBuffer this resets its size to 0). No transfer
// into m_BufferList is visible here, so ownership appears to stay with the caller.
162 void Release(IPacketBufferPtr& packetBuffer) override
164 std::unique_lock<std::mutex> lock(m_Mutex);
166 packetBuffer->Release();
// Takes the most recently committed buffer from the back of m_BufferList (LIFO order) and adds its
// size to the readable total. The empty-list branch body is not visible here; presumably it returns
// an empty pointer.
169 IPacketBufferPtr GetReadableBuffer() override
171 std::unique_lock<std::mutex> lock(m_Mutex);
173 if (m_BufferList.empty())
177 IPacketBufferPtr buffer = std::move(m_BufferList.back());
178 m_BufferList.pop_back();
179 m_ReadableSize += buffer->GetSize();
// Adds the buffer's size to the read total, then marks it read. The size must be captured first,
// because MockPacketBuffer::MarkRead() resets it to 0.
183 void MarkRead(IPacketBufferPtr& packetBuffer) override
185 std::unique_lock<std::mutex> lock(m_Mutex);
187 m_ReadSize += packetBuffer->GetSize();
188 packetBuffer->MarkRead();
// Accessors for the running totals.
// NOTE(review): these read the counters without taking m_Mutex, so they are only reliable once
// other threads have stopped using this object.
191 unsigned int GetCommittedSize() const { return m_CommittedSize; }
192 unsigned int GetReadableSize() const { return m_ReadableSize; }
193 unsigned int GetReadSize() const { return m_ReadSize; }
196 // The maximum buffer size when creating a new buffer
197 unsigned int m_MaxBufferSize;
// Committed buffers waiting to be handed out by GetReadableBuffer() (taken from the back).
200 std::vector<IPacketBufferPtr> m_BufferList;
202 // The mutex to synchronize this mock's methods (its declaration is not visible in this listing)
205 // The total size of the buffers that have been committed for reading
206 unsigned int m_CommittedSize;
208 // The total size of the buffers that have been handed out for reading by GetReadableBuffer()
209 unsigned int m_ReadableSize;
211 // The total size of the buffers that have already been read (passed to MarkRead())
212 unsigned int m_ReadSize;
// ISendCounterPacket test double. Each Send*() call:
// - ignores its arguments;
// - reserves 1024 bytes from the supplied buffer manager;
// - writes the method's own name, including the terminating NUL, into the buffer;
// - commits the full reserved size, not the message length.
// NOTE(review): neither the returned pointer nor `reserved` is checked before the memcpy. The write
// goes out of bounds or dereferences null in either of these cases:
// - the manager hands out a buffer smaller than the message length + 1 (e.g. a MockBufferManager
//   built with a small size);
// - the manager returns an empty pointer.
215 class MockSendCounterPacket : public ISendCounterPacket
// sendBuffer is held by reference and must outlive this object.
218 MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
220 void SendStreamMetaDataPacket() override
222 std::string message("SendStreamMetaDataPacket");
223 unsigned int reserved = 0;
224 std::unique_ptr<IPacketBuffer> buffer = m_BufferManager.Reserve(1024, reserved);
225 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
226 m_BufferManager.Commit(buffer, reserved);
229 void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
231 std::string message("SendCounterDirectoryPacket");
232 unsigned int reserved = 0;
233 std::unique_ptr<IPacketBuffer> buffer = m_BufferManager.Reserve(1024, reserved);
234 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
235 m_BufferManager.Commit(buffer, reserved);
238 void SendPeriodicCounterCapturePacket(uint64_t timestamp,
239 const std::vector<std::pair<uint16_t, uint32_t>>& values) override
241 std::string message("SendPeriodicCounterCapturePacket");
242 unsigned int reserved = 0;
243 std::unique_ptr<IPacketBuffer> buffer = m_BufferManager.Reserve(1024, reserved);
244 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
245 m_BufferManager.Commit(buffer, reserved);
248 void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
249 const std::vector<uint16_t>& selectedCounterIds) override
251 std::string message("SendPeriodicCounterSelectionPacket");
252 unsigned int reserved = 0;
253 std::unique_ptr<IPacketBuffer> buffer = m_BufferManager.Reserve(1024, reserved);
254 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
255 m_BufferManager.Commit(buffer, reserved);
// No-op in this mock.
258 void SetReadyToRead() override {}
// Non-owning reference to the buffer manager passed to the constructor.
261 IBufferManager& m_BufferManager;
// ICounterDirectory test double that records registered categories, devices, counter sets and
// counters. Of the lookup getters, only GetCategory() is implemented; the UID lookups return nullptr.
264 class MockCounterDirectory : public ICounterDirectory
267 MockCounterDirectory() = default;
268 ~MockCounterDirectory() = default;
270 // Register profiling objects
// Registers a new category. Missing device/counter-set UIDs default to 0.
// NOTE(review): the return statement (presumably returning categoryPtr) is not visible in this listing.
271 const Category* RegisterCategory(const std::string& categoryName,
272 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
273 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
275 // Get the device UID
276 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
278 // Get the counter set UID
279 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
281 // Create the category
282 CategoryPtr category = std::make_unique<Category>(categoryName, deviceUidValue, counterSetUidValue);
283 BOOST_ASSERT(category);
285 // Get the raw category pointer
286 const Category* categoryPtr = category.get();
287 BOOST_ASSERT(categoryPtr);
289 // Register the category
290 m_Categories.insert(std::move(category));
// Registers a new device under a UID from GetNextUid(). If a parent category name is given, it
// writes that UID into the category's m_DeviceUid.
// NOTE(review): the line declaring `cores` and the return statement are not visible in this listing.
// NOTE(review): if the parent category does not exist, BOOST_ASSERT is the only guard before the
// dereference, and it is normally compiled out under NDEBUG.
295 const Device* RegisterDevice(const std::string& deviceName,
297 const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
299 // Get the device UID
300 uint16_t deviceUid = GetNextUid();
303 DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
304 BOOST_ASSERT(device);
306 // Get the raw device pointer
307 const Device* devicePtr = device.get();
308 BOOST_ASSERT(devicePtr);
310 // Register the device
311 m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
313 // Connect the device to the parent category, if required
314 if (parentCategoryName.has_value())
316 // Set the device UID in the parent category
317 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
318 BOOST_ASSERT(parentCategory);
319 parentCategory->m_DeviceUid = deviceUid;
// Registers a new counter set under a UID from GetNextUid(). If a parent category name is given,
// it writes that UID into the category's m_CounterSetUid. The same NDEBUG caveat as
// RegisterDevice() applies to the parent lookup.
// NOTE(review): the line declaring `count` is not visible in this listing.
325 const CounterSet* RegisterCounterSet(
326 const std::string& counterSetName,
328 const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
330 // Get the counter set UID
331 uint16_t counterSetUid = GetNextUid();
333 // Create the counter set
334 CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
335 BOOST_ASSERT(counterSet);
337 // Get the raw counter set pointer
338 const CounterSet* counterSetPtr = counterSet.get();
339 BOOST_ASSERT(counterSetPtr);
341 // Register the counter set
342 m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
344 // Connect the counter set to the parent category, if required
345 if (parentCategoryName.has_value())
347 // Set the counter set UID in the parent category
348 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
349 BOOST_ASSERT(parentCategory);
350 parentCategory->m_CounterSetUid = counterSetUid;
353 return counterSetPtr;
// Registers a counter. One UID per core is obtained from GetNextCounterUids(). Every one of those
// UIDs maps to the same shared Counter object and is appended to the parent category's m_Counters.
// Missing optionals default to 0 (or "" for units).
// NOTE(review): the remaining Counter constructor arguments and the return statement are not
// visible in this listing.
// NOTE(review): GetCategory(parentCategoryName) is re-evaluated on every loop iteration although
// its argument does not change.
356 const Counter* RegisterCounter(const std::string& parentCategoryName,
357 uint16_t counterClass,
358 uint16_t interpolation,
360 const std::string& name,
361 const std::string& description,
362 const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
363 const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
364 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
365 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
367 // Get the number of cores from the argument only
368 uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
370 // Get the device UID
371 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
373 // Get the counter set UID
374 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
376 // Get the counter UIDs and calculate the max counter UID
377 std::vector<uint16_t> counterUids = GetNextCounterUids(deviceCores);
378 BOOST_ASSERT(!counterUids.empty());
379 uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
381 // Get the counter units
382 const std::string unitsValue = units.has_value() ? units.value() : "";
384 // Create the counter
385 CounterPtr counter = std::make_shared<Counter>(counterUids.front(),
395 BOOST_ASSERT(counter);
397 // Get the raw counter pointer
398 const Counter* counterPtr = counter.get();
399 BOOST_ASSERT(counterPtr);
401 // Process multiple counters if necessary
402 for (uint16_t counterUid : counterUids)
404 // Connect the counter to the parent category
405 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
406 BOOST_ASSERT(parentCategory);
407 parentCategory->m_Counters.push_back(counterUid);
409 // Register the counter
410 m_Counters.insert(std::make_pair(counterUid, counter));
416 // Getters for counts
// boost::numeric_cast throws if a collection holds more entries than uint16_t can represent.
417 uint16_t GetCategoryCount() const override { return boost::numeric_cast<uint16_t>(m_Categories.size()); }
418 uint16_t GetDeviceCount() const override { return boost::numeric_cast<uint16_t>(m_Devices.size()); }
419 uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); }
420 uint16_t GetCounterCount() const override { return boost::numeric_cast<uint16_t>(m_Counters.size()); }
422 // Getters for collections
423 const Categories& GetCategories() const override { return m_Categories; }
424 const Devices& GetDevices() const override { return m_Devices; }
425 const CounterSets& GetCounterSets() const override { return m_CounterSets; }
426 const Counters& GetCounters() const override { return m_Counters; }
428 // Getters for profiling objects
// Linear search of m_Categories by name. The not-found branch body is not visible here;
// presumably it returns nullptr.
429 const Category* GetCategory(const std::string& name) const override
431 auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
433 BOOST_ASSERT(category);
435 return category->m_Name == name;
438 if (it == m_Categories.end())
446 const Device* GetDevice(uint16_t uid) const override
448 return nullptr; // Not used by the unit tests
451 const CounterSet* GetCounterSet(uint16_t uid) const override
453 return nullptr; // Not used by the unit tests
456 const Counter* GetCounter(uint16_t uid) const override
458 return nullptr; // Not used by the unit tests
// NOTE(review): the declarations of m_Devices and m_Counters are not visible in this listing.
462 Categories m_Categories;
464 CounterSets m_CounterSets;
// Test subclass of SendCounterPacket that exposes its record-building helpers through public
// *Test wrappers. Presumably this exists because the helpers are not publicly accessible.
// Each wrapper forwards its arguments unchanged and returns the helper's result.
468 class SendCounterPacketTest : public SendCounterPacket
471 SendCounterPacketTest(IProfilingConnection& profilingconnection, IBufferManager& buffer)
472 : SendCounterPacket(profilingconnection, buffer)
475 bool CreateDeviceRecordTest(const DevicePtr& device,
476 DeviceRecord& deviceRecord,
477 std::string& errorMessage)
479 return CreateDeviceRecord(device, deviceRecord, errorMessage);
482 bool CreateCounterSetRecordTest(const CounterSetPtr& counterSet,
483 CounterSetRecord& counterSetRecord,
484 std::string& errorMessage)
486 return CreateCounterSetRecord(counterSet, counterSetRecord, errorMessage);
489 bool CreateEventRecordTest(const CounterPtr& counter,
490 EventRecord& eventRecord,
491 std::string& errorMessage)
493 return CreateEventRecord(counter, eventRecord, errorMessage);
496 bool CreateCategoryRecordTest(const CategoryPtr& category,
497 const Counters& counters,
498 CategoryRecord& categoryRecord,
499 std::string& errorMessage)
501 return CreateCategoryRecord(category, counters, categoryRecord, errorMessage);
505 } // namespace profiling