2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
9 #include <IProfilingConnectionFactory.hpp>
10 #include <IProfilingServiceStatus.hpp>
11 #include <ProfilingService.hpp>
12 #include <ProfilingGuidGenerator.hpp>
13 #include <ProfilingUtils.hpp>
14 #include <SendCounterPacket.hpp>
15 #include <SendThread.hpp>
17 #include <armnn/Exceptions.hpp>
18 #include <armnn/Optional.hpp>
19 #include <armnn/Conversion.hpp>
20 #include <armnn/utility/Assert.hpp>
21 #include <armnn/utility/IgnoreUnused.hpp>
22 #include <armnn/utility/NumericCast.hpp>
25 #include <condition_variable>
// Test double for IProfilingConnection. Instead of sending data anywhere it
// records a (PacketType, length) entry for every buffer written through
// WritePacket(buffer, length), and hands back a single stored packet from
// ReadPacket. The accessors below take m_Mutex before touching shared state.
35 class MockProfilingConnection : public IProfilingConnection
38 MockProfilingConnection()
// PacketType enumerators. WritePacket converts ids below 8 straight into this
// enum with PacketType(packetId), so the enumerator order must line up with
// those packet ids (NOTE(review): confirm against the protocol definition).
47 ConnectionAcknowledge,
50 PeriodicCounterSelection,
51 PerJobCounterSelection,
52 TimelineMessageDirectory,
53 PeriodicCounterCapture,
54 ActivateTimelineReporting,
55 DeactivateTimelineReporting,
// Takes m_Mutex before reading the connection state; m_Mutex is declared
// mutable so this const method can lock it.
59 bool IsOpen() const override
61 std::lock_guard<std::mutex> lock(m_Mutex);
68 std::lock_guard<std::mutex> lock(m_Mutex);
// Classifies a raw packet by its 32-bit header and records (packetType, length).
// A null buffer or zero length is checked first, before the header is read
// (presumably returning false without recording anything).
// Header layout used here: bits 26-31 = packet family,
// bits 16-25 = packet id (10 bits, hence the 1023 mask).
// Family/id combinations that are not recognised are recorded as Unknown.
73 bool WritePacket(const unsigned char* buffer, uint32_t length) override
75 if (buffer == nullptr || length == 0)
80 uint32_t header = ReadUint32(buffer, 0);
82 uint32_t packetFamily = (header >> 26);
83 uint32_t packetId = ((header >> 16) & 1023);
85 PacketType packetType;
90 packetType = packetId < 8 ? PacketType(packetId) : PacketType::Unknown;
93 packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown;
96 packetType = packetId == 0 ? PacketType::PeriodicCounterCapture : PacketType::Unknown;
99 packetType = PacketType::Unknown;
// Header decoding uses no shared state, so the lock only covers the append.
102 std::lock_guard<std::mutex> lock(m_Mutex);
104 m_WrittenData.push_back({ packetType, length });
// Counts recorded packets that match packetInfo. When packetInfo.second (the
// length) is non-zero, both type and length must match; a zero length acts as
// a wildcard and only the packet type is compared.
108 long CheckForPacket(const std::pair<PacketType, uint32_t> packetInfo)
110 std::lock_guard<std::mutex> lock(m_Mutex);
112 if(packetInfo.second != 0)
114 return static_cast<long>(std::count(m_WrittenData.begin(), m_WrittenData.end(), packetInfo));
118 return static_cast<long>(std::count_if(m_WrittenData.begin(), m_WrittenData.end(),
119 [&packetInfo](const std::pair<PacketType, uint32_t> pair) { return packetInfo.first == pair.first; }));
// Stores a packet for ReadPacket to return, replacing any previously stored one.
123 bool WritePacket(arm::pipe::Packet&& packet)
125 std::lock_guard<std::mutex> lock(m_Mutex);
127 m_Packet = std::move(packet);
// Ignores timeout. Sleeps 5 ms before taking the lock (so writers are not
// blocked during the delay), then moves the stored packet out. m_Packet stays
// moved-from until the next WritePacket(Packet&&).
131 arm::pipe::Packet ReadPacket(uint32_t timeout) override
133 IgnoreUnused(timeout);
135 // Simulate a delay in the reading process. The default timeout is way too long.
136 std::this_thread::sleep_for(std::chrono::milliseconds(5));
137 std::lock_guard<std::mutex> lock(m_Mutex);
138 return std::move(m_Packet);
// Number of (type, length) entries recorded by WritePacket(buffer, length).
141 unsigned long GetWrittenDataSize()
143 std::lock_guard<std::mutex> lock(m_Mutex);
145 return static_cast<unsigned long>(m_WrittenData.size());
// Discards all recorded (type, length) entries.
150 std::lock_guard<std::mutex> lock(m_Mutex);
152 m_WrittenData.clear();
// (packet type, byte length) of each buffer accepted by WritePacket.
157 std::vector<std::pair<PacketType, uint32_t>> m_WrittenData;
// Packet handed out by ReadPacket; set by WritePacket(Packet&&).
158 arm::pipe::Packet m_Packet;
// Guards the mock's shared state; mutable so const members such as IsOpen can lock it.
159 mutable std::mutex m_Mutex;
// Connection factory for tests: ignores the supplied options and always
// returns a newly allocated MockProfilingConnection.
162 class MockProfilingConnectionFactory : public IProfilingConnectionFactory
165 IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
167 IgnoreUnused(options);
168 return std::make_unique<MockProfilingConnection>();
// Heap-backed IPacketBuffer for tests. m_Data owns m_MaxSize bytes and m_Size
// is the number of committed bytes. The readable and writable views return
// the same storage.
172 class MockPacketBuffer : public IPacketBuffer
175 MockPacketBuffer(unsigned int maxSize)
// Uses m_MaxSize, which relies on m_MaxSize being declared before m_Data
// (see the member declarations) so that it is initialised first.
178 , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
181 ~MockPacketBuffer() {}
183 const unsigned char* GetReadableData() const override { return m_Data.get(); }
185 unsigned int GetSize() const override { return m_Size; }
// MarkRead and Release both only reset the committed size; the storage is kept.
187 void MarkRead() override { m_Size = 0; }
// NOTE(review): size is not checked against m_MaxSize.
189 void Commit(unsigned int size) override { m_Size = size; }
191 void Release() override { m_Size = 0; }
193 unsigned char* GetWritableData() override { return m_Data.get(); }
// Frees the storage; afterwards both data getters return nullptr and both sizes are 0.
195 void Destroy() override {m_Data.reset(nullptr); m_Size = 0; m_MaxSize =0;}
// Capacity of m_Data in bytes.
198 unsigned int m_MaxSize;
200 std::unique_ptr<unsigned char[]> m_Data;
// IBufferManager test double that owns a single MockPacketBuffer and passes it
// back and forth by moving the pointer: Reserve and GetReadableBuffer hand it
// out, while Commit, Release and MarkRead take it back.
// NOTE(review): calling Reserve or GetReadableBuffer again before the buffer
// has been returned yields the moved-from (presumably null) pointer.
203 class MockBufferManager : public IBufferManager
206 MockBufferManager(unsigned int size)
207 : m_BufferSize(size),
208 m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
210 ~MockBufferManager() {}
// Hands out the single buffer; reservedSize is requestedSize clamped to m_BufferSize.
212 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
214 if (requestedSize > m_BufferSize)
216 reservedSize = m_BufferSize;
220 reservedSize = requestedSize;
223 return std::move(m_Buffer);
// Commits size bytes on the buffer and reclaims ownership of it.
226 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
228 packetBuffer->Commit(size);
229 m_Buffer = std::move(packetBuffer);
237 IPacketBufferPtr GetReadableBuffer() override
239 return std::move(m_Buffer);
// Release and MarkRead forward to the buffer and then reclaim ownership of it.
242 void Release(IPacketBufferPtr& packetBuffer) override
244 packetBuffer->Release();
245 m_Buffer = std::move(packetBuffer);
248 void MarkRead(IPacketBufferPtr& packetBuffer) override
250 packetBuffer->MarkRead();
251 m_Buffer = std::move(packetBuffer);
// A nullptr consumer is ignored, so an existing consumer cannot be cleared here.
254 void SetConsumer(IConsumer* consumer) override
256 if (consumer != nullptr)
258 m_Consumer = consumer;
// Calls SetReadyToRead on the registered consumer, if there is one.
262 void FlushReadList() override
264 // notify consumer that packet is ready to read
265 if (m_Consumer != nullptr)
267 m_Consumer->SetReadyToRead();
// Capacity of the single buffer; upper bound for reservedSize.
272 unsigned int m_BufferSize;
// The single buffer; moved out while a caller holds it.
273 IPacketBufferPtr m_Buffer;
// Not owned.
274 IConsumer* m_Consumer = nullptr;
// IBufferManager test double that allocates a fresh MockPacketBuffer on every
// Reserve and keeps running byte totals (committed, readable, read) for tests
// to inspect. The overridden buffer operations lock m_Mutex.
277 class MockStreamCounterBuffer : public IBufferManager
280 MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
281 : m_MaxBufferSize(maxBufferSize)
287 ~MockStreamCounterBuffer() {}
// Allocates a new buffer of exactly requestedSize bytes.
// Throws armnn::InvalidArgumentException if requestedSize exceeds m_MaxBufferSize.
289 IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
291 std::lock_guard<std::mutex> lock(m_Mutex);
294 if (requestedSize > m_MaxBufferSize)
296 throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
297 std::to_string(m_MaxBufferSize) + "] bytes");
299 reservedSize = requestedSize;
300 return std::make_unique<MockPacketBuffer>(requestedSize);
// Queues the committed buffer for reading and adds size to the committed total.
303 void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
305 std::lock_guard<std::mutex> lock(m_Mutex);
307 packetBuffer->Commit(size);
308 m_BufferList.push_back(std::move(packetBuffer));
309 m_CommittedSize += size;
// Forwards Release to the buffer.
317 void Release(IPacketBufferPtr& packetBuffer) override
319 std::lock_guard<std::mutex> lock(m_Mutex);
321 packetBuffer->Release();
// Returns the most recently committed buffer (last in, first out) and adds its
// size to the readable total. NOTE(review): the result for an empty list is
// presumably a null pointer.
324 IPacketBufferPtr GetReadableBuffer() override
326 std::lock_guard<std::mutex> lock(m_Mutex);
328 if (m_BufferList.empty())
332 IPacketBufferPtr buffer = std::move(m_BufferList.back());
333 m_BufferList.pop_back();
334 m_ReadableSize += buffer->GetSize();
// Adds the buffer's size to the read total. The size is read before MarkRead
// because MarkRead may reset it (MockPacketBuffer::MarkRead sets it to 0).
338 void MarkRead(IPacketBufferPtr& packetBuffer) override
340 std::lock_guard<std::mutex> lock(m_Mutex);
342 m_ReadSize += packetBuffer->GetSize();
343 packetBuffer->MarkRead();
// A nullptr consumer is ignored, so an existing consumer cannot be cleared here.
346 void SetConsumer(IConsumer* consumer) override
348 if (consumer != nullptr)
350 m_Consumer = consumer;
// Calls SetReadyToRead on the registered consumer, if there is one.
354 void FlushReadList() override
356 // notify consumer that packet is ready to read
357 if (m_Consumer != nullptr)
359 m_Consumer->SetReadyToRead();
// NOTE(review): these accessors read the totals without locking m_Mutex.
// Calling them while other threads still commit or read buffers is a data race.
363 unsigned int GetCommittedSize() const { return m_CommittedSize; }
364 unsigned int GetReadableSize() const { return m_ReadableSize; }
365 unsigned int GetReadSize() const { return m_ReadSize; }
368 // The maximum buffer size when creating a new buffer
369 unsigned int m_MaxBufferSize;
// Committed buffers waiting to be read; GetReadableBuffer pops from the back.
372 std::vector<IPacketBufferPtr> m_BufferList;
374 // The mutex to synchronize this mock's methods
377 // The total size of the buffers that have been committed for reading
378 unsigned int m_CommittedSize;
380 // The total size of the buffers that have been handed out for reading
381 unsigned int m_ReadableSize;
383 // The total size of the buffers that have already been read
384 unsigned int m_ReadSize;
386 // Consumer to notify when a packet is ready to read (not owned)
387 IConsumer* m_Consumer = nullptr;
// ISendCounterPacket test double. Each Send* method ignores its arguments.
// Instead it writes its own method name as a NUL-terminated string into a
// 1024-byte reservation from the buffer manager, then commits the full
// reserved size.
// NOTE(review): memcpy checks neither the message length against `reserved`
// nor the returned buffer against null. A manager that clamps the reservation
// (e.g. MockBufferManager) must still provide room for the longest message.
390 class MockSendCounterPacket : public ISendCounterPacket
393 MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
// Unlike the other Send* methods, passes false as Commit's notifyConsumer argument.
395 void SendStreamMetaDataPacket() override
397 std::string message("SendStreamMetaDataPacket");
398 unsigned int reserved = 0;
399 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
400 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
401 m_BufferManager.Commit(buffer, reserved, false);
404 void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
406 IgnoreUnused(counterDirectory);
408 std::string message("SendCounterDirectoryPacket");
409 unsigned int reserved = 0;
410 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
411 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
412 m_BufferManager.Commit(buffer, reserved);
415 void SendPeriodicCounterCapturePacket(uint64_t timestamp,
416 const std::vector<CounterValue>& values) override
418 IgnoreUnused(timestamp, values);
420 std::string message("SendPeriodicCounterCapturePacket");
421 unsigned int reserved = 0;
422 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
423 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
424 m_BufferManager.Commit(buffer, reserved);
427 void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
428 const std::vector<uint16_t>& selectedCounterIds) override
430 IgnoreUnused(capturePeriod, selectedCounterIds);
432 std::string message("SendPeriodicCounterSelectionPacket");
433 unsigned int reserved = 0;
434 IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
435 memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
436 m_BufferManager.Commit(buffer, reserved);
// Held by reference: the buffer manager is not owned and must outlive this object.
440 IBufferManager& m_BufferManager;
// Minimal ICounterDirectory for tests. It supports registering categories,
// devices, counter sets and counters, and looking up a category by name.
// GetDevice, GetCounterSet and GetCounter are stubs that return nullptr.
443 class MockCounterDirectory : public ICounterDirectory
446 MockCounterDirectory() = default;
447 ~MockCounterDirectory() = default;
449 // Register profiling objects
// Creates a category with the given name and stores it in m_Categories.
450 const Category* RegisterCategory(const std::string& categoryName)
452 // Create the category
453 CategoryPtr category = std::make_unique<Category>(categoryName);
454 ARMNN_ASSERT(category);
456 // Get the raw category pointer
457 const Category* categoryPtr = category.get();
458 ARMNN_ASSERT(categoryPtr);
460 // Register the category
461 m_Categories.insert(std::move(category));
// Creates a device with the next UID from GetNextUid() and stores it in
// m_Devices keyed by that UID.
466 const Device* RegisterDevice(const std::string& deviceName,
469 // Get the device UID
470 uint16_t deviceUid = GetNextUid();
473 DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
474 ARMNN_ASSERT(device);
476 // Get the raw device pointer
477 const Device* devicePtr = device.get();
478 ARMNN_ASSERT(devicePtr);
480 // Register the device
481 m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
// Creates a counter set with the next UID from GetNextUid(), stores it in
// m_CounterSets keyed by that UID and returns a non-owning pointer to it.
486 const CounterSet* RegisterCounterSet(
487 const std::string& counterSetName,
490 // Get the counter set UID
491 uint16_t counterSetUid = GetNextUid();
493 // Create the counter set
494 CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
495 ARMNN_ASSERT(counterSet);
497 // Get the raw counter set pointer
498 const CounterSet* counterSetPtr = counterSet.get();
499 ARMNN_ASSERT(counterSetPtr);
501 // Register the counter set
502 m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
504 return counterSetPtr;
// Registers a counter. One UID per core is obtained from
// GetNextCounterUids(uid, deviceCores), and every UID maps to the same shared
// Counter object. Each UID is also appended to the parent category's counter
// list, so the parent category must already be registered (asserted).
// Absent optionals default to 0 (numbers of cores, device UID, counter set UID)
// or "" (units).
// NOTE(review): backendId is ignored; the counter is always created with
// armnn::profiling::BACKEND_ID.
507 const Counter* RegisterCounter(const BackendId& backendId,
509 const std::string& parentCategoryName,
510 uint16_t counterClass,
511 uint16_t interpolation,
513 const std::string& name,
514 const std::string& description,
515 const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
516 const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
517 const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
518 const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
520 IgnoreUnused(backendId);
522 // Get the number of cores from the argument only
523 uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
525 // Get the device UID
526 uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
528 // Get the counter set UID
529 uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
531 // Get the counter UIDs and calculate the max counter UID
532 std::vector<uint16_t> counterUids = GetNextCounterUids(uid, deviceCores);
533 ARMNN_ASSERT(!counterUids.empty());
// With multiple cores the last UID is the maximum; otherwise there is a single UID.
534 uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
536 // Get the counter units
537 const std::string unitsValue = units.has_value() ? units.value() : "";
539 // Create the counter
540 CounterPtr counter = std::make_shared<Counter>(armnn::profiling::BACKEND_ID,
551 ARMNN_ASSERT(counter);
553 // Get the raw counter pointer
554 const Counter* counterPtr = counter.get();
555 ARMNN_ASSERT(counterPtr);
557 // Process multiple counters if necessary
558 for (uint16_t counterUid : counterUids)
560 // Connect the counter to the parent category
// NOTE(review): this lookup does not depend on counterUid and could be
// hoisted out of the loop.
561 Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
562 ARMNN_ASSERT(parentCategory);
563 parentCategory->m_Counters.push_back(counterUid);
565 // Register the counter
566 m_Counters.insert(std::make_pair(counterUid, counter));
572 // Getters for counts
573 uint16_t GetCategoryCount() const override { return armnn::numeric_cast<uint16_t>(m_Categories.size()); }
574 uint16_t GetDeviceCount() const override { return armnn::numeric_cast<uint16_t>(m_Devices.size()); }
575 uint16_t GetCounterSetCount() const override { return armnn::numeric_cast<uint16_t>(m_CounterSets.size()); }
576 uint16_t GetCounterCount() const override { return armnn::numeric_cast<uint16_t>(m_Counters.size()); }
578 // Getters for collections
579 const Categories& GetCategories() const override { return m_Categories; }
580 const Devices& GetDevices() const override { return m_Devices; }
581 const CounterSets& GetCounterSets() const override { return m_CounterSets; }
582 const Counters& GetCounters() const override { return m_Counters; }
584 // Getters for profiling objects
// Linear search of m_Categories by name (presumably nullptr when not found).
585 const Category* GetCategory(const std::string& name) const override
587 auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
589 ARMNN_ASSERT(category);
591 return category->m_Name == name;
594 if (it == m_Categories.end())
602 const Device* GetDevice(uint16_t uid) const override
605 return nullptr; // Not used by the unit tests
608 const CounterSet* GetCounterSet(uint16_t uid) const override
611 return nullptr; // Not used by the unit tests
614 const Counter* GetCounter(uint16_t uid) const override
617 return nullptr; // Not used by the unit tests
621 Categories m_Categories;
623 CounterSets m_CounterSets;
// ProfilingService test double. The profiling-enabled flag and the capture
// data are fixed at construction. Counter packets are produced by a real
// SendCounterPacket that writes into the supplied MockBufferManager, and GUIDs
// come from an owned ProfilingGuidGenerator.
627 class MockProfilingService : public ProfilingService
630 MockProfilingService(MockBufferManager& mockBufferManager,
631 bool isProfilingEnabled,
632 const CaptureData& captureData) :
633 m_SendCounterPacket(mockBufferManager),
634 m_IsProfilingEnabled(isProfilingEnabled),
635 m_CaptureData(captureData)
638 /// Return the next random Guid in the sequence
639 ProfilingDynamicGuid NextGuid() override
641 return m_GuidGenerator.NextGuid();
644 /// Create a ProfilingStaticGuid based on a hash of the string
645 ProfilingStaticGuid GenerateStaticId(const std::string& str) override
647 return m_GuidGenerator.GenerateStaticId(str);
650 std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override
655 const ICounterMappings& GetCounterMappings() const override
657 return m_CounterMapping;
660 ISendCounterPacket& GetSendCounterPacket() override
662 return m_SendCounterPacket;
664 // Returns the flag supplied at construction.
665 bool IsProfilingEnabled() const override
667 return m_IsProfilingEnabled;
// Returns a copy of the capture data supplied at construction.
670 CaptureData GetCaptureData() override
672 CaptureData copy(m_CaptureData);
// Records a global-to-backend counter id mapping in m_CounterMapping.
676 void RegisterMapping(uint16_t globalCounterId,
677 uint16_t backendCounterId,
678 const armnn::BackendId& backendId)
680 m_CounterMapping.RegisterMapping(globalCounterId, backendCounterId, backendId);
// Resets m_CounterMapping (presumably clearing all registered mappings).
685 m_CounterMapping.Reset();
689 ProfilingGuidGenerator m_GuidGenerator;
690 CounterIdMap m_CounterMapping;
691 SendCounterPacket m_SendCounterPacket;
692 bool m_IsProfilingEnabled;
693 CaptureData m_CaptureData;
// No-op IProfilingServiceStatus: activation notifications are ignored, and
// waiting for activation returns immediately without using the timeout.
696 class MockProfilingServiceStatus : public IProfilingServiceStatus
699 void NotifyProfilingServiceActive() override {}
700 void WaitForProfilingServiceActivation(unsigned int timeout) override { IgnoreUnused(timeout); }
703 } // namespace profiling