IVGCVSW-3440 Fix intermittently failing send thread test
[platform/upstream/armnn.git] / src / profiling / test / SendCounterPacketTests.hpp
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <SendCounterPacket.hpp>
9 #include <ProfilingUtils.hpp>
10
11 #include <armnn/Exceptions.hpp>
12 #include <armnn/Optional.hpp>
13 #include <armnn/Conversion.hpp>
14
15 #include <boost/numeric/conversion/cast.hpp>
16
17 namespace armnn
18 {
19
20 namespace profiling
21 {
22
23 class MockProfilingConnection : public IProfilingConnection
24 {
25 public:
26     MockProfilingConnection()
27         : m_IsOpen(true)
28     {}
29
30     bool IsOpen() override { return m_IsOpen; }
31
32     void Close() override { m_IsOpen = false; }
33
34     bool WritePacket(const unsigned char* buffer, uint32_t length) override
35     {
36         return buffer != nullptr && length > 0;
37     }
38
39     Packet ReadPacket(uint32_t timeout) override { return Packet(); }
40
41 private:
42     bool m_IsOpen;
43 };
44
45 class MockPacketBuffer : public IPacketBuffer
46 {
47 public:
48     MockPacketBuffer(unsigned int maxSize)
49         : m_MaxSize(maxSize)
50         , m_Size(0)
51         , m_Data(std::make_unique<unsigned char[]>(m_MaxSize))
52     {}
53
54     ~MockPacketBuffer() {}
55
56     const unsigned char* const GetReadableData() const override { return m_Data.get(); }
57
58     unsigned int GetSize() const override { return m_Size; }
59
60     void MarkRead() override { m_Size = 0; }
61
62     void Commit(unsigned int size) override { m_Size = size; }
63
64     void Release() override { m_Size = 0; }
65
66     unsigned char* GetWritableData() override { return m_Data.get(); }
67
68 private:
69     unsigned int m_MaxSize;
70     unsigned int m_Size;
71     std::unique_ptr<unsigned char[]> m_Data;
72 };
73
74 class MockBufferManager : public IBufferManager
75 {
76 public:
77     MockBufferManager(unsigned int size)
78     : m_BufferSize(size),
79       m_Buffer(std::make_unique<MockPacketBuffer>(size)) {}
80
81     ~MockBufferManager() {}
82
83     std::unique_ptr<IPacketBuffer> Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
84     {
85         if (requestedSize > m_BufferSize)
86         {
87             reservedSize = m_BufferSize;
88         }
89         else
90         {
91             reservedSize = requestedSize;
92         }
93
94         return std::move(m_Buffer);
95     }
96
97     void Commit(std::unique_ptr<IPacketBuffer>& packetBuffer, unsigned int size) override
98     {
99         packetBuffer->Commit(size);
100         m_Buffer = std::move(packetBuffer);
101     }
102
103     std::unique_ptr<IPacketBuffer> GetReadableBuffer() override
104     {
105         return std::move(m_Buffer);
106     }
107
108     void Release(std::unique_ptr<IPacketBuffer>& packetBuffer) override
109     {
110         packetBuffer->Release();
111         m_Buffer = std::move(packetBuffer);
112     }
113
114     void MarkRead(std::unique_ptr<IPacketBuffer>& packetBuffer) override
115     {
116         packetBuffer->MarkRead();
117         m_Buffer = std::move(packetBuffer);
118     }
119
120 private:
121     unsigned int m_BufferSize;
122     std::unique_ptr<IPacketBuffer> m_Buffer;
123 };
124
125 class MockStreamCounterBuffer : public IBufferManager
126 {
127 public:
128     using IPacketBufferPtr = std::unique_ptr<IPacketBuffer>;
129
130     MockStreamCounterBuffer(unsigned int maxBufferSize = 4096)
131         : m_MaxBufferSize(maxBufferSize)
132         , m_BufferList()
133         , m_CommittedSize(0)
134         , m_ReadableSize(0)
135         , m_ReadSize(0)
136     {}
137     ~MockStreamCounterBuffer() {}
138
139     IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
140     {
141         std::unique_lock<std::mutex> lock(m_Mutex);
142
143         reservedSize = 0;
144         if (requestedSize > m_MaxBufferSize)
145         {
146             throw armnn::InvalidArgumentException("The maximum buffer size that can be requested is [" +
147                                                   std::to_string(m_MaxBufferSize) + "] bytes");
148         }
149         reservedSize = requestedSize;
150         return std::make_unique<MockPacketBuffer>(requestedSize);
151     }
152
153     void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override
154     {
155         std::unique_lock<std::mutex> lock(m_Mutex);
156
157         packetBuffer->Commit(size);
158         m_BufferList.push_back(std::move(packetBuffer));
159         m_CommittedSize += size;
160     }
161
162     void Release(IPacketBufferPtr& packetBuffer) override
163     {
164         std::unique_lock<std::mutex> lock(m_Mutex);
165
166         packetBuffer->Release();
167     }
168
169     IPacketBufferPtr GetReadableBuffer() override
170     {
171         std::unique_lock<std::mutex> lock(m_Mutex);
172
173         if (m_BufferList.empty())
174         {
175             return nullptr;
176         }
177         IPacketBufferPtr buffer = std::move(m_BufferList.back());
178         m_BufferList.pop_back();
179         m_ReadableSize += buffer->GetSize();
180         return buffer;
181     }
182
183     void MarkRead(IPacketBufferPtr& packetBuffer) override
184     {
185         std::unique_lock<std::mutex> lock(m_Mutex);
186
187         m_ReadSize += packetBuffer->GetSize();
188         packetBuffer->MarkRead();
189     }
190
191     unsigned int GetCommittedSize() const { return m_CommittedSize; }
192     unsigned int GetReadableSize()  const { return m_ReadableSize;  }
193     unsigned int GetReadSize()      const { return m_ReadSize;      }
194
195 private:
196     // The maximum buffer size when creating a new buffer
197     unsigned int m_MaxBufferSize;
198
199     // A list of buffers
200     std::vector<IPacketBufferPtr> m_BufferList;
201
202     // The mutex to synchronize this mock's methods
203     std::mutex m_Mutex;
204
205     // The total size of the buffers that has been committed for reading
206     unsigned int m_CommittedSize;
207
208     // The total size of the buffers that can be read
209     unsigned int m_ReadableSize;
210
211     // The total size of the buffers that has already been read
212     unsigned int m_ReadSize;
213 };
214
215 class MockSendCounterPacket : public ISendCounterPacket
216 {
217 public:
218     MockSendCounterPacket(IBufferManager& sendBuffer) : m_BufferManager(sendBuffer) {}
219
220     void SendStreamMetaDataPacket() override
221     {
222         std::string message("SendStreamMetaDataPacket");
223         unsigned int reserved = 0;
224         std::unique_ptr<IPacketBuffer> buffer = m_BufferManager.Reserve(1024, reserved);
225         memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
226         m_BufferManager.Commit(buffer, reserved);
227     }
228
229     void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override
230     {
231         std::string message("SendCounterDirectoryPacket");
232         unsigned int reserved = 0;
233         std::unique_ptr<IPacketBuffer> buffer = m_BufferManager.Reserve(1024, reserved);
234         memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
235         m_BufferManager.Commit(buffer, reserved);
236     }
237
238     void SendPeriodicCounterCapturePacket(uint64_t timestamp,
239                                           const std::vector<std::pair<uint16_t, uint32_t>>& values) override
240     {
241         std::string message("SendPeriodicCounterCapturePacket");
242         unsigned int reserved = 0;
243         std::unique_ptr<IPacketBuffer> buffer = m_BufferManager.Reserve(1024, reserved);
244         memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
245         m_BufferManager.Commit(buffer, reserved);
246     }
247
248     void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
249                                             const std::vector<uint16_t>& selectedCounterIds) override
250     {
251         std::string message("SendPeriodicCounterSelectionPacket");
252         unsigned int reserved = 0;
253         std::unique_ptr<IPacketBuffer> buffer = m_BufferManager.Reserve(1024, reserved);
254         memcpy(buffer->GetWritableData(), message.c_str(), static_cast<unsigned int>(message.size()) + 1);
255         m_BufferManager.Commit(buffer, reserved);
256     }
257
258     void SetReadyToRead() override {}
259
260 private:
261     IBufferManager& m_BufferManager;
262 };
263
264 class MockCounterDirectory : public ICounterDirectory
265 {
266 public:
267     MockCounterDirectory() = default;
268     ~MockCounterDirectory() = default;
269
270     // Register profiling objects
271     const Category* RegisterCategory(const std::string& categoryName,
272                                      const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
273                                      const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
274     {
275         // Get the device UID
276         uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
277
278         // Get the counter set UID
279         uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
280
281         // Create the category
282         CategoryPtr category = std::make_unique<Category>(categoryName, deviceUidValue, counterSetUidValue);
283         BOOST_ASSERT(category);
284
285         // Get the raw category pointer
286         const Category* categoryPtr = category.get();
287         BOOST_ASSERT(categoryPtr);
288
289         // Register the category
290         m_Categories.insert(std::move(category));
291
292         return categoryPtr;
293     }
294
295     const Device* RegisterDevice(const std::string& deviceName,
296                                  uint16_t cores = 0,
297                                  const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
298     {
299         // Get the device UID
300         uint16_t deviceUid = GetNextUid();
301
302         // Create the device
303         DevicePtr device = std::make_unique<Device>(deviceUid, deviceName, cores);
304         BOOST_ASSERT(device);
305
306         // Get the raw device pointer
307         const Device* devicePtr = device.get();
308         BOOST_ASSERT(devicePtr);
309
310         // Register the device
311         m_Devices.insert(std::make_pair(deviceUid, std::move(device)));
312
313         // Connect the counter set to the parent category, if required
314         if (parentCategoryName.has_value())
315         {
316             // Set the counter set UID in the parent category
317             Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
318             BOOST_ASSERT(parentCategory);
319             parentCategory->m_DeviceUid = deviceUid;
320         }
321
322         return devicePtr;
323     }
324
325     const CounterSet* RegisterCounterSet(
326             const std::string& counterSetName,
327             uint16_t count = 0,
328             const armnn::Optional<std::string>& parentCategoryName = armnn::EmptyOptional())
329     {
330         // Get the counter set UID
331         uint16_t counterSetUid = GetNextUid();
332
333         // Create the counter set
334         CounterSetPtr counterSet = std::make_unique<CounterSet>(counterSetUid, counterSetName, count);
335         BOOST_ASSERT(counterSet);
336
337         // Get the raw counter set pointer
338         const CounterSet* counterSetPtr = counterSet.get();
339         BOOST_ASSERT(counterSetPtr);
340
341         // Register the counter set
342         m_CounterSets.insert(std::make_pair(counterSetUid, std::move(counterSet)));
343
344         // Connect the counter set to the parent category, if required
345         if (parentCategoryName.has_value())
346         {
347             // Set the counter set UID in the parent category
348             Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName.value()));
349             BOOST_ASSERT(parentCategory);
350             parentCategory->m_CounterSetUid = counterSetUid;
351         }
352
353         return counterSetPtr;
354     }
355
356     const Counter* RegisterCounter(const std::string& parentCategoryName,
357                                    uint16_t counterClass,
358                                    uint16_t interpolation,
359                                    double multiplier,
360                                    const std::string& name,
361                                    const std::string& description,
362                                    const armnn::Optional<std::string>& units = armnn::EmptyOptional(),
363                                    const armnn::Optional<uint16_t>& numberOfCores = armnn::EmptyOptional(),
364                                    const armnn::Optional<uint16_t>& deviceUid = armnn::EmptyOptional(),
365                                    const armnn::Optional<uint16_t>& counterSetUid = armnn::EmptyOptional())
366     {
367         // Get the number of cores from the argument only
368         uint16_t deviceCores = numberOfCores.has_value() ? numberOfCores.value() : 0;
369
370         // Get the device UID
371         uint16_t deviceUidValue = deviceUid.has_value() ? deviceUid.value() : 0;
372
373         // Get the counter set UID
374         uint16_t counterSetUidValue = counterSetUid.has_value() ? counterSetUid.value() : 0;
375
376         // Get the counter UIDs and calculate the max counter UID
377         std::vector<uint16_t> counterUids = GetNextCounterUids(deviceCores);
378         BOOST_ASSERT(!counterUids.empty());
379         uint16_t maxCounterUid = deviceCores <= 1 ? counterUids.front() : counterUids.back();
380
381         // Get the counter units
382         const std::string unitsValue = units.has_value() ? units.value() : "";
383
384         // Create the counter
385         CounterPtr counter = std::make_shared<Counter>(counterUids.front(),
386                                                        maxCounterUid,
387                                                        counterClass,
388                                                        interpolation,
389                                                        multiplier,
390                                                        name,
391                                                        description,
392                                                        unitsValue,
393                                                        deviceUidValue,
394                                                        counterSetUidValue);
395         BOOST_ASSERT(counter);
396
397         // Get the raw counter pointer
398         const Counter* counterPtr = counter.get();
399         BOOST_ASSERT(counterPtr);
400
401         // Process multiple counters if necessary
402         for (uint16_t counterUid : counterUids)
403         {
404             // Connect the counter to the parent category
405             Category* parentCategory = const_cast<Category*>(GetCategory(parentCategoryName));
406             BOOST_ASSERT(parentCategory);
407             parentCategory->m_Counters.push_back(counterUid);
408
409             // Register the counter
410             m_Counters.insert(std::make_pair(counterUid, counter));
411         }
412
413         return counterPtr;
414     }
415
416     // Getters for counts
417     uint16_t GetCategoryCount()   const override { return boost::numeric_cast<uint16_t>(m_Categories.size());  }
418     uint16_t GetDeviceCount()     const override { return boost::numeric_cast<uint16_t>(m_Devices.size());     }
419     uint16_t GetCounterSetCount() const override { return boost::numeric_cast<uint16_t>(m_CounterSets.size()); }
420     uint16_t GetCounterCount()    const override { return boost::numeric_cast<uint16_t>(m_Counters.size());    }
421
422     // Getters for collections
423     const Categories&  GetCategories()  const override { return m_Categories;  }
424     const Devices&     GetDevices()     const override { return m_Devices;     }
425     const CounterSets& GetCounterSets() const override { return m_CounterSets; }
426     const Counters&    GetCounters()    const override { return m_Counters;    }
427
428     // Getters for profiling objects
429     const Category* GetCategory(const std::string& name) const override
430     {
431         auto it = std::find_if(m_Categories.begin(), m_Categories.end(), [&name](const CategoryPtr& category)
432         {
433             BOOST_ASSERT(category);
434
435             return category->m_Name == name;
436         });
437
438         if (it == m_Categories.end())
439         {
440             return nullptr;
441         }
442
443         return it->get();
444     }
445
446     const Device* GetDevice(uint16_t uid) const override
447     {
448         return nullptr; // Not used by the unit tests
449     }
450
451     const CounterSet* GetCounterSet(uint16_t uid) const override
452     {
453         return nullptr; // Not used by the unit tests
454     }
455
456     const Counter* GetCounter(uint16_t uid) const override
457     {
458         return nullptr; // Not used by the unit tests
459     }
460
461 private:
462     Categories  m_Categories;
463     Devices     m_Devices;
464     CounterSets m_CounterSets;
465     Counters    m_Counters;
466 };
467
468 class SendCounterPacketTest : public SendCounterPacket
469 {
470 public:
471     SendCounterPacketTest(IProfilingConnection& profilingconnection, IBufferManager& buffer)
472         : SendCounterPacket(profilingconnection, buffer)
473     {}
474
475     bool CreateDeviceRecordTest(const DevicePtr& device,
476                                 DeviceRecord& deviceRecord,
477                                 std::string& errorMessage)
478     {
479         return CreateDeviceRecord(device, deviceRecord, errorMessage);
480     }
481
482     bool CreateCounterSetRecordTest(const CounterSetPtr& counterSet,
483                                     CounterSetRecord& counterSetRecord,
484                                     std::string& errorMessage)
485     {
486         return CreateCounterSetRecord(counterSet, counterSetRecord, errorMessage);
487     }
488
489     bool CreateEventRecordTest(const CounterPtr& counter,
490                                EventRecord& eventRecord,
491                                std::string& errorMessage)
492     {
493         return CreateEventRecord(counter, eventRecord, errorMessage);
494     }
495
496     bool CreateCategoryRecordTest(const CategoryPtr& category,
497                                   const Counters& counters,
498                                   CategoryRecord& categoryRecord,
499                                   std::string& errorMessage)
500     {
501         return CreateCategoryRecord(category, counters, categoryRecord, errorMessage);
502     }
503 };
504
505 } // namespace profiling
506
507 } // namespace armnn