IVGCVSW-3429 Add a utility Version class
[platform/upstream/armnn.git] / src / profiling / test / ProfilingTests.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "../CommandHandlerKey.hpp"
7 #include "../CommandHandlerFunctor.hpp"
8 #include "../CommandHandlerRegistry.hpp"
9 #include "../EncodeVersion.hpp"
10 #include "../Packet.hpp"
11
12 #include <cstdint>
13 #include <cstring>
14 #include <boost/test/unit_test.hpp>
15 #include <map>
16
17 BOOST_AUTO_TEST_SUITE(ExternalProfiling)
18
19 BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
20 {
21     CommandHandlerKey testKey0(1, 1);
22     CommandHandlerKey testKey1(1, 1);
23     CommandHandlerKey testKey2(1, 1);
24     CommandHandlerKey testKey3(0, 0);
25     CommandHandlerKey testKey4(2, 2);
26     CommandHandlerKey testKey5(0, 2);
27
28     BOOST_CHECK(testKey1<testKey4);
29     BOOST_CHECK(testKey1>testKey3);
30     BOOST_CHECK(testKey1<=testKey4);
31     BOOST_CHECK(testKey1>=testKey3);
32     BOOST_CHECK(testKey1<=testKey2);
33     BOOST_CHECK(testKey1>=testKey2);
34     BOOST_CHECK(testKey1==testKey2);
35     BOOST_CHECK(testKey1==testKey1);
36
37     BOOST_CHECK(!(testKey1==testKey5));
38     BOOST_CHECK(!(testKey1!=testKey1));
39     BOOST_CHECK(testKey1!=testKey5);
40
41     BOOST_CHECK(testKey1==testKey2 && testKey2==testKey1);
42     BOOST_CHECK(testKey0==testKey1 && testKey1==testKey2 && testKey0==testKey2);
43
44     BOOST_CHECK(testKey1.GetPacketId()==1);
45     BOOST_CHECK(testKey1.GetVersion()==1);
46
47     std::vector<CommandHandlerKey> vect =
48         {
49             CommandHandlerKey(0,1), CommandHandlerKey(2,0), CommandHandlerKey(1,0),
50             CommandHandlerKey(2,1), CommandHandlerKey(1,1), CommandHandlerKey(0,1),
51             CommandHandlerKey(2,0), CommandHandlerKey(0,0)
52         };
53
54     std::sort(vect.begin(), vect.end());
55
56     std::vector<CommandHandlerKey> expectedVect =
57         {
58             CommandHandlerKey(0,0), CommandHandlerKey(0,1), CommandHandlerKey(0,1),
59             CommandHandlerKey(1,0), CommandHandlerKey(1,1), CommandHandlerKey(2,0),
60             CommandHandlerKey(2,0), CommandHandlerKey(2,1)
61         };
62
63     BOOST_CHECK(vect == expectedVect);
64 }
65
66 BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
67 {
68     mlutil::Version version1(12);
69
70     BOOST_CHECK(version1.GetMajor() == 0);
71     BOOST_CHECK(version1.GetMinor() == 0);
72     BOOST_CHECK(version1.GetPatch() == 12);
73
74     mlutil::Version version2(4108);
75
76     BOOST_CHECK(version2.GetMajor() == 0);
77     BOOST_CHECK(version2.GetMinor() == 1);
78     BOOST_CHECK(version2.GetPatch() == 12);
79
80     mlutil::Version version3(4198412);
81
82     BOOST_CHECK(version3.GetMajor() == 1);
83     BOOST_CHECK(version3.GetMinor() == 1);
84     BOOST_CHECK(version3.GetPatch() == 12);
85
86     mlutil::Version version4(0);
87
88     BOOST_CHECK(version4.GetMajor() == 0);
89     BOOST_CHECK(version4.GetMinor() == 0);
90     BOOST_CHECK(version4.GetPatch() == 0);
91
92     mlutil::Version version5(1,0,0);
93     BOOST_CHECK(version5.GetEncodedValue() == 4194304);
94 }
95
96 BOOST_AUTO_TEST_CASE(CheckPacketClass)
97 {
98     const char* data = "test";
99     unsigned int length = static_cast<unsigned int>(std::strlen(data));
100
101     Packet packetTest1(472580096,length,data);
102     BOOST_CHECK_THROW(Packet packetTest2(472580096,0,""), armnn::Exception);
103
104     Packet packetTest3(472580096,0, nullptr);
105
106     BOOST_CHECK(packetTest1.GetLength() == length);
107     BOOST_CHECK(packetTest1.GetData() == data);
108
109     BOOST_CHECK(packetTest1.GetPacketFamily() == 7);
110     BOOST_CHECK(packetTest1.GetPacketId() == 43);
111     BOOST_CHECK(packetTest1.GetPacketType() == 3);
112     BOOST_CHECK(packetTest1.GetPacketClass() == 5);
113 }
114
115 // Create Derived Classes
116 class TestFunctorA : public CommandHandlerFunctor
117 {
118 public:
119     using CommandHandlerFunctor::CommandHandlerFunctor;
120
121     int GetCount() { return m_Count; }
122
123     void operator()(const Packet& packet) override
124     {
125         m_Count++;
126     }
127
128 private:
129     int m_Count = 0;
130 };
131
132 class TestFunctorB : public TestFunctorA
133 {
134     using TestFunctorA::TestFunctorA;
135 };
136
137 class TestFunctorC : public TestFunctorA
138 {
139     using TestFunctorA::TestFunctorA;
140 };
141
142 BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
143 {
144     // Hard code the version as it will be the same during a single profiling session
145     uint32_t version = 1;
146
147     TestFunctorA testFunctorA(461, version);
148     TestFunctorB testFunctorB(963, version);
149     TestFunctorC testFunctorC(983, version);
150
151     CommandHandlerKey keyA(testFunctorA.GetPacketId(), testFunctorA.GetVersion());
152     CommandHandlerKey keyB(testFunctorB.GetPacketId(), testFunctorB.GetVersion());
153     CommandHandlerKey keyC(testFunctorC.GetPacketId(), testFunctorC.GetVersion());
154
155     // Create the unwrapped map to simulate the Command Handler Registry
156     std::map<CommandHandlerKey, CommandHandlerFunctor*> registry;
157
158     registry.insert(std::make_pair(keyB, &testFunctorB));
159     registry.insert(std::make_pair(keyA, &testFunctorA));
160     registry.insert(std::make_pair(keyC, &testFunctorC));
161
162     // Check the order of the map is correct
163     auto it = registry.begin();
164     BOOST_CHECK(it->first==keyA);
165     it++;
166     BOOST_CHECK(it->first==keyB);
167     it++;
168     BOOST_CHECK(it->first==keyC);
169
170     Packet packetA(500000000, 0, nullptr);
171     Packet packetB(600000000, 0, nullptr);
172     Packet packetC(400000000, 0, nullptr);
173
174     // Check the correct operator of derived class is called
175     registry.at(CommandHandlerKey(packetA.GetPacketId(), version))->operator()(packetA);
176     BOOST_CHECK(testFunctorA.GetCount() == 1);
177     BOOST_CHECK(testFunctorB.GetCount() == 0);
178     BOOST_CHECK(testFunctorC.GetCount() == 0);
179
180     registry.at(CommandHandlerKey(packetB.GetPacketId(), version))->operator()(packetB);
181     BOOST_CHECK(testFunctorA.GetCount() == 1);
182     BOOST_CHECK(testFunctorB.GetCount() == 1);
183     BOOST_CHECK(testFunctorC.GetCount() == 0);
184
185     registry.at(CommandHandlerKey(packetC.GetPacketId(), version))->operator()(packetC);
186     BOOST_CHECK(testFunctorA.GetCount() == 1);
187     BOOST_CHECK(testFunctorB.GetCount() == 1);
188     BOOST_CHECK(testFunctorC.GetCount() == 1);
189 }
190
191 BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
192 {
193     // Hard code the version as it will be the same during a single profiling session
194     uint32_t version = 1;
195
196     TestFunctorA testFunctorA(461, version);
197     TestFunctorB testFunctorB(963, version);
198     TestFunctorC testFunctorC(983, version);
199
200     // Create the Command Handler Registry
201     CommandHandlerRegistry registry;
202
203     // Register multiple different derived classes
204     registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion());
205     registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
206     registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
207
208     Packet packetA(500000000, 0, nullptr);
209     Packet packetB(600000000, 0, nullptr);
210     Packet packetC(400000000, 0, nullptr);
211
212     // Check the correct operator of derived class is called
213     registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA);
214     BOOST_CHECK(testFunctorA.GetCount() == 1);
215     BOOST_CHECK(testFunctorB.GetCount() == 0);
216     BOOST_CHECK(testFunctorC.GetCount() == 0);
217
218     registry.GetFunctor(packetB.GetPacketId(), version)->operator()(packetB);
219     BOOST_CHECK(testFunctorA.GetCount() == 1);
220     BOOST_CHECK(testFunctorB.GetCount() == 1);
221     BOOST_CHECK(testFunctorC.GetCount() == 0);
222
223     registry.GetFunctor(packetC.GetPacketId(), version)->operator()(packetC);
224     BOOST_CHECK(testFunctorA.GetCount() == 1);
225     BOOST_CHECK(testFunctorB.GetCount() == 1);
226     BOOST_CHECK(testFunctorC.GetCount() == 1);
227
228     // Re-register an existing key with a new function
229     registry.RegisterFunctor(&testFunctorC, testFunctorA.GetPacketId(), version);
230     registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetC);
231     BOOST_CHECK(testFunctorA.GetCount() == 1);
232     BOOST_CHECK(testFunctorB.GetCount() == 1);
233     BOOST_CHECK(testFunctorC.GetCount() == 2);
234
235     // Check that non-existent key returns nullptr for its functor
236     BOOST_CHECK_THROW(registry.GetFunctor(0, 0), armnn::Exception);
237 }
238
239 BOOST_AUTO_TEST_SUITE_END()