IVGCVSW-5301 Remove all boost::numeric_cast from armnn/src/profiling
[platform/upstream/armnn.git] / src / profiling / test / ProfilingMocks.hpp
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <Holder.hpp>
9 #include <IProfilingConnectionFactory.hpp>
10 #include <IProfilingServiceStatus.hpp>
11 #include <ProfilingService.hpp>
12 #include <ProfilingGuidGenerator.hpp>
13 #include <ProfilingUtils.hpp>
14 #include <SendCounterPacket.hpp>
15 #include <SendThread.hpp>
16
17 #include <armnn/Exceptions.hpp>
18 #include <armnn/Optional.hpp>
19 #include <armnn/Conversion.hpp>
20 #include <armnn/utility/Assert.hpp>
21 #include <armnn/utility/IgnoreUnused.hpp>
22 #include <armnn/utility/NumericCast.hpp>
23
24 #include <atomic>
25 #include <condition_variable>
26 #include <mutex>
27 #include <thread>
28
29 namespace armnn
30 {
31
32 namespace profiling
33 {
34
35 class MockProfilingConnection : public IProfilingConnection
36 {
37 public:
38     MockProfilingConnection()
39         : m_IsOpen(true)
40         , m_WrittenData()
41         , m_Packet()
42     {}
43
44     enum class PacketType
45     {
46         StreamMetaData,
47         ConnectionAcknowledge,
48         CounterDirectory,
49         ReqCounterDirectory,
50         PeriodicCounterSelection,
51         PerJobCounterSelection,
52         TimelineMessageDirectory,
53         PeriodicCounterCapture,
54         ActivateTimelineReporting,
55         DeactivateTimelineReporting,
56         Unknown
57     };
58
59     bool IsOpen() const override
60     {
61         std::lock_guard<std::mutex> lock(m_Mutex);
62
63         return m_IsOpen;
64     }
65
66     void Close() override
67     {
68         std::lock_guard<std::mutex> lock(m_Mutex);
69
70         m_IsOpen = false;
71     }
72
73     bool WritePacket(const unsigned char* buffer, uint32_t length) override
74     {
75         if (buffer == nullptr || length == 0)
76         {
77             return false;
78         }
79
80         uint32_t header = ReadUint32(buffer, 0);
81
82         uint32_t packetFamily = (header >> 26);
83         uint32_t packetId = ((header >> 16) & 1023);
84
85         PacketType packetType;
86
87         switch (packetFamily)
88         {
89             case 0:
90                 packetType = packetId < 8 ? PacketType(packetId) : PacketType::Unknown;
91                 break;
92             case 1:
93                 packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown;
94                 break;
95             case 3:
96                 packetType = packetId == 0 ? PacketType::PeriodicCounterCapture : PacketType::Unknown;
97                 break;
98             default:
99                 packetType = PacketType::Unknown;
100         }
101
102         std::lock_guard<std::mutex> lock(m_Mutex);
103
104         m_WrittenData.push_back({ packetType, length });
105         return true;
106     }
107
108     long CheckForPacket(const std::pair<PacketType, uint32_t> packetInfo)
109     {
110         std::lock_guard<std::mutex> lock(m_Mutex);
111
112         if(packetInfo.second != 0)
113         {
114             return static_cast<long>(std::count(m_WrittenData.begin(), m_WrittenData.end(), packetInfo));
115         }
116         else
117         {
118             return static_cast<long>(std::count_if(m_WrittenData.begin(), m_WrittenData.end(),
119             [&packetInfo](const std::pair<PacketType, uint32_t> pair) { return packetInfo.first == pair.first; }));
120         }
121     }
122
123     bool WritePacket(arm::pipe::Packet&& packet)
124     {
125         std::lock_guard<std::mutex> lock(m_Mutex);
126
127         m_Packet = std::move(packet);
128         return true;
129     }
130
131     arm::pipe::Packet ReadPacket(uint32_t timeout) override
132     {
133         IgnoreUnused(timeout);
134
135         // Simulate a delay in the reading process. The default timeout is way too long.
136         std::this_thread::sleep_for(std::chrono::milliseconds(5));
137         std::lock_guard<std::mutex> lock(m_Mutex);
138         return std::move(m_Packet);
139     }
140
141     unsigned long GetWrittenDataSize()
142     {
143         std::lock_guard<std::mutex> lock(m_Mutex);
144
145         return static_cast<unsigned long>(m_WrittenData.size());
146     }
147
148     void Clear()
149     {
150         std::lock_guard<std::mutex> lock(m_Mutex);
151
152         m_WrittenData.clear();
153     }
154
155 private:
156     bool m_IsOpen;
157     std::vector<std::pair<PacketType, uint32_t>> m_WrittenData;
158     arm::pipe::Packet m_Packet;
159     mutable std::mutex m_Mutex;
160 };
161
162 class MockProfilingConnectionFactory : public IProfilingConnectionFactory
163 {
164 public:
165     IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
166     {
167         IgnoreUnused(options);
168         return std::make_unique<MockProfilingConnection>();
169     }
170 };
171
172 class MockPacketBuffer : public IPacketBuffer
173 {
174 public:
175     MockPacketBuffer(unsigned int maxSize)
176         : m_MaxSize(maxSize)
177         , m_Size(0)
178         , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
179     {}
180
181     ~MockPacketBuffer() {}
182
183     const unsigned char* GetReadableData() const override { return m_Data.get(); }
184
185     unsigned int GetSize() const override { return m_Size; }
186
187     void MarkRead() override { m_Size = 0; }
188
189     void Commit(unsigned int size) override { m_Size = size; }
190
191     void Release() override { m_Size = 0; }
192
193     unsigned char* GetWritableData() override { return m_Data.get(); }
194
195     void Destroy() override {m_Data.reset(nullptr); m_Size = 0; m_MaxSize =0;}
196
197 private:
198     unsigned int m_MaxSize;
199     unsigned int m_Size;
200     std::unique_ptr<unsigned char[]> m_Data;
201 };
202
203 class MockBufferManager : public IBufferManager
204 {
205 public:
206     MockBufferManager(unsigned int size)
207     : m_BufferSize(size),
208       m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
209
210     ~MockBufferManager() {}
211
212     IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
213     {
214         if (requestedSize > m_BufferSize)
215         {
216             reservedSize = m_BufferSize;
217         }
218         else
219         {
220             reservedSize = requestedSize;
221         }
222
223         return std::move(m_Buffer);
224     }
225
226     void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
227     {
228         packetBuffer->Commit(size);
229         m_Buffer = std::move(packetBuffer);
230
231         if (notifyConsumer)
232         {
233             FlushReadList();
234         }
235     }
236
237     IPacketBufferPtr GetReadableBuffer() override
238     {
239         return std::move(m_Buffer);
240     }
241
242     void Release(IPacketBufferPtr& packetBuffer) override
243     {
244         packetBuffer->Release();
245         m_Buffer = std::move(packetBuffer);
246     }
247
248     void MarkRead(IPacketBufferPtr& packetBuffer) override
249     {
250         packetBuffer->MarkRead();
251         m_Buffer = std::move(packetBuffer);
252     }
253
254     void SetConsumer(IConsumer* consumer) override
255    {
256         if (consumer != nullptr)
257         {
258             m_Consumer = consumer;
259         }
260    }
261
262     void FlushReadList() override
263     {
264         // notify consumer that packet is ready to read
265         if (m_Consumer != nullptr)
266         {
267             m_Consumer->SetReadyToRead();
268         }
269     }
270
271 private:
272     unsigned int m_BufferSize;
273     IPacketBufferPtr m_Buffer;
274     IConsumer* m_Consumer = nullptr;
275 };
276
277 class MockStreamCounterBuffer : public IBufferManager
278 {
279 public:
280     MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
281         : m_MaxBufferSize(maxBufferSize)
282         , m_BufferList()
283         , m_CommittedSize(0)
284         , m_ReadableSize(0)
285         , m_ReadSize(0)
286     {}
287     ~MockStreamCounterBuffer() {}
288
289     IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
290     {
291         std::lock_guard<std::mutex> lock(m_Mutex);
292
293         reservedSize = 0;
294         if (requestedSize > m_MaxBufferSize)
295         {
296             throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
297                                                   std::to_string(m_MaxBufferSize) + "] bytes");
298         }
299         reservedSize = requestedSize;
300         return std::make_unique<MockPacketBuffer>(requestedSize);
301     }
302
303     void Commit(IPacketBufferPtr& packetBuffer, unsigned int size, bool notifyConsumer = true) override
304     {
305         std::lock_guard<std::mutex> lock(m_Mutex);
306
307         packetBuffer->Commit(size);
308         m_BufferList.push_back(std::move(packetBuffer));
309         m_CommittedSize += size;
310
311         if (notifyConsumer)
312         {
313             FlushReadList();
314         }
315     }
316
317     void Release(IPacketBufferPtr& packetBuffer) override
318     {
319         std::lock_guard<std::mutex> lock(m_Mutex);
320
321         packetBuffer->Release();
322     }
323
324     IPacketBufferPtr GetReadableBuffer() override
325     {
326         std::lock_guard<std::mutex> lock(m_Mutex);
327
328         if (m_BufferList.empty())
329         {
330             return nullptr;
331         }
332         IPacketBufferPtr buffer = std::move(m_BufferList.back());
333         m_BufferList.pop_back();
334         m_ReadableSize += buffer->GetSize();
335         return buffer;
336     }
337
338     void MarkRead(IPacketBufferPtr& packetBuffer) override
339     {
340         std::lock_guard<std::mutex> lock(m_Mutex);
341
342         m_ReadSize += packetBuffer->GetSize();
343         packetBuffer->MarkRead();
344     }
345
346     void SetConsumer(IConsumer* consumer) override
347     {
348         if (consumer != nullptr)
349         {
350             m_Consumer = consumer;
351         }
352     }
353
354     void FlushReadList() override
355     {
356         // notify consumer that packet is ready to read
357         if (m_Consumer != nullptr)
358         {
359             m_Consumer->SetReadyToRead();
360         }
361     }
362
363     unsigned int GetCommittedSize() const { return m_CommittedSize; }
364     unsigned int GetReadableSize()  const { return m_ReadableSize;  }
365     unsigned int GetReadSize()      const { return m_ReadSize;      }
366
367 private:
368     // The maximum buffer size when creating a new buffer
369     unsigned int m_MaxBufferSize;
370
371     // A list of buffers
372     std::vector<IPacketBufferPtr> m_BufferList;
373
374     // The mutex to synchronize this mock's methods
375     std::mutex m_Mutex;
376
377     // The total size of the buffers that has been committed for reading
378     unsigned int m_CommittedSize;
379
380     // The total size of the buffers that can be read
381     unsigned int m_ReadableSize;
382
383     // The total size of the buffers that has already been read
384     unsigned int m_ReadSize;
385
386     // Consumer thread to notify packet is ready to read
387     IConsumer* m_Consumer = nullptr;
388 };
389
390 class MockSendCounterPacket : public ISendCounterPacket
391 {
392 public:
393     MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
394
395     void SendStreamMetaDataPacket() override
396     {
397         std::string message("SendStreamMetaDataPacket");
398         unsigned int reserved = 0;
399         IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
400         memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
401         m_BufferManager.Commit(buffer, reserved, false);
402     }
403
404     void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
405     {
406         IgnoreUnused(counterDirectory);
407
408         std::string message("SendCounterDirectoryPacket");
409         unsigned int reserved = 0;
410         IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
411         memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
412         m_BufferManager.Commit(buffer, reserved);
413     }
414
415     void SendPeriodicCounterCapturePacket(uint64_t timestamp,
416                                           const std::vector<CounterValue>& values) override
417     {
418         IgnoreUnused(timestamp, values);
419
420         std::string message("SendPeriodicCounterCapturePacket");
421         unsigned int reserved = 0;
422         IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
423         memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
424         m_BufferManager.Commit(buffer, reserved);
425     }
426
427     void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
428                                             const std::vector<uint16_t>& selectedCounterIds) override
429     {
430         IgnoreUnused(capturePeriod, selectedCounterIds);
431
432         std::string message("SendPeriodicCounterSelectionPacket");
433         unsigned int reserved = 0;
434         IPacketBufferPtr buffer = m_BufferManager.Reserve(1024, reserved);
435         memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
436         m_BufferManager.Commit(buffer, reserved);
437     }
438
439 private:
440     IBufferManager& m_BufferManager;
441 };
442
443 class MockCounterDirectory : public ICounterDirectory
444 {
445 public:
446     MockCounterDirectory() = default;
447     ~MockCounterDirectory() = default;
448
449     // Register profiling objects
450     const Category* RegisterCategory(const std::string& categoryName)
451     {
452         // Create the category
453         CategoryPtr category = std::make_unique<Category>(categoryName);
454         ARMNN_ASSERT(category);
455
456         // Get the raw category pointer
457         const Category* categoryPtr = category.get();
458         ARMNN_ASSERT(categoryPtr);
459
460         // Register the category
461         m_Categories.insert(std::move(category));
462
463         return categoryPtr;
464     }
465
466     const Device* RegisterDevice(const std::string& deviceName,
467                                  uint16_t cores = 0)
468     {
469         // Get the device UID
470         uint16_t deviceUid = GetNextUid();
471
472         // Create the device
473         DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
474         ARMNN_ASSERT(device);
475
476         // Get the raw device pointer
477         const Device* devicePtr = device.get();
478         ARMNN_ASSERT(devicePtr);
479
480         // Register the device
481         m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
482
483         return devicePtr;
484     }
485
486     const CounterSet* RegisterCounterSet(
487             const std::string& counterSetName,
488             uint16_t count = 0)
489     {
490         // Get the counter set UID
491         uint16_t counterSetUid = GetNextUid();
492
493         // Create the counter set
494         CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
495         ARMNN_ASSERT(counterSet);
496
497         // Get the raw counter set pointer
498         const CounterSet* counterSetPtr = counterSet.get();
499         ARMNN_ASSERT(counterSetPtr);
500
501         // Register the counter set
502         m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
503
504         return counterSetPtr;
505     }
506
507     const Counter* RegisterCounter(const BackendId& backendId,
508                                    const uint16_t uid,
509                                    const std::string& parentCategoryName,
510                                    uint16_t counterClass,
511                                    uint16_t interpolation,
512                                    double multiplier,
513                                    const std::string& name,
514                                    const std::string& description,
515                                    const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
516                                    const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
517                                    const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
518                                    const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
519     {
520         IgnoreUnused(backendId);
521
522         // Get the number of cores from the argument only
523         uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
524
525         // Get the device UID
526         uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
527
528         // Get the counter set UID
529         uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
530
531         // Get the counter UIDs and calculate the max counter UID
532         std::vector<uint16_t> counterUids = GetNextCounterUids(uid, deviceCores);
533         ARMNN_ASSERT(!counterUids.empty());
534         uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
535
536         // Get the counter units
537         const std::string unitsValue = units.has_value() ? units.value() : "";
538
539         // Create the counter
540         CounterPtr counter = std::make_shared<Counter>(armnn::profiling::BACKEND_ID,
541                                                        counterUids.front(),
542                                                        maxCounterUid,
543                                                        counterClass,
544                                                        interpolation,
545                                                        multiplier,
546                                                        name,
547                                                        description,
548                                                        unitsValue,
549                                                        deviceUidValue,
550                                                        counterSetUidValue);
551         ARMNN_ASSERT(counter);
552
553         // Get the raw counter pointer
554         const Counter* counterPtr = counter.get();
555         ARMNN_ASSERT(counterPtr);
556
557         // Process multiple counters if necessary
558         for (uint16_t counterUid : counterUids)
559         {
560             // Connect the counter to the parent category
561             Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
562             ARMNN_ASSERT(parentCategory);
563             parentCategory->m_Counters.push_back(counterUid);
564
565             // Register the counter
566             m_Counters.insert(std::make_pair(counterUid, counter));
567         }
568
569         return counterPtr;
570     }
571
572     // Getters for counts
573     uint16_t GetCategoryCount()   const override { return armnn::numeric_cast<uint16_t>(m_Categories.size());  }
574     uint16_t GetDeviceCount()     const override { return armnn::numeric_cast<uint16_t>(m_Devices.size());     }
575     uint16_t GetCounterSetCount() const override { return armnn::numeric_cast<uint16_t>(m_CounterSets.size()); }
576     uint16_t GetCounterCount()    const override { return armnn::numeric_cast<uint16_t>(m_Counters.size());    }
577
578     // Getters for collections
579     const Categories&  GetCategories()  const override { return m_Categories;  }
580     const Devices&     GetDevices()     const override { return m_Devices;     }
581     const CounterSets& GetCounterSets() const override { return m_CounterSets; }
582     const Counters&    GetCounters()    const override { return m_Counters;    }
583
584     // Getters for profiling objects
585     const Category* GetCategory(const std::string& name) const override
586     {
587         auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
588         {
589             ARMNN_ASSERT(category);
590
591             return category->m_Name == name;
592         });
593
594         if (it == m_Categories.end())
595         {
596             return nullptr;
597         }
598
599         return it->get();
600     }
601
602     const Device* GetDevice(uint16_t uid) const override
603     {
604         IgnoreUnused(uid);
605         return nullptr; // Not used by the unit tests
606     }
607
608     const CounterSet* GetCounterSet(uint16_t uid) const override
609     {
610         IgnoreUnused(uid);
611         return nullptr; // Not used by the unit tests
612     }
613
614     const Counter* GetCounter(uint16_t uid) const override
615     {
616         IgnoreUnused(uid);
617         return nullptr; // Not used by the unit tests
618     }
619
620 private:
621     Categories  m_Categories;
622     Devices     m_Devices;
623     CounterSets m_CounterSets;
624     Counters    m_Counters;
625 };
626
627 class MockProfilingService : public ProfilingService
628 {
629 public:
630     MockProfilingService(MockBufferManager& mockBufferManager,
631                          bool isProfilingEnabled,
632                          const CaptureData& captureData) :
633         m_SendCounterPacket(mockBufferManager),
634         m_IsProfilingEnabled(isProfilingEnabled),
635         m_CaptureData(captureData)
636     {}
637
638     /// Return the next random Guid in the sequence
639     ProfilingDynamicGuid NextGuid() override
640     {
641         return m_GuidGenerator.NextGuid();
642     }
643
644     /// Create a ProfilingStaticGuid based on a hash of the string
645     ProfilingStaticGuid GenerateStaticId(const std::string& str) override
646     {
647         return m_GuidGenerator.GenerateStaticId(str);
648     }
649
650     std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override
651     {
652         return nullptr;
653     }
654
655     const ICounterMappings& GetCounterMappings() const override
656     {
657         return m_CounterMapping;
658     }
659
660     ISendCounterPacket& GetSendCounterPacket() override
661     {
662         return m_SendCounterPacket;
663     }
664
665     bool IsProfilingEnabled() const override
666     {
667         return m_IsProfilingEnabled;
668     }
669
670     CaptureData GetCaptureData() override
671     {
672         CaptureData copy(m_CaptureData);
673         return copy;
674     }
675
676     void RegisterMapping(uint16_t globalCounterId,
677                          uint16_t backendCounterId,
678                          const armnn::BackendId& backendId)
679     {
680         m_CounterMapping.RegisterMapping(globalCounterId, backendCounterId, backendId);
681     }
682
683     void Reset()
684     {
685         m_CounterMapping.Reset();
686     }
687
688 private:
689     ProfilingGuidGenerator m_GuidGenerator;
690     CounterIdMap           m_CounterMapping;
691     SendCounterPacket      m_SendCounterPacket;
692     bool                   m_IsProfilingEnabled;
693     CaptureData            m_CaptureData;
694 };
695
696 class MockProfilingServiceStatus : public IProfilingServiceStatus
697 {
698 public:
699     void NotifyProfilingServiceActive() override {}
700     void WaitForProfilingServiceActivation(unsigned int timeout) override { IgnoreUnused(timeout); }
701 };
702
703 } // namespace profiling
704
705 } // namespace armnn