IVGCVSW-3432 Create CaptureData Holder
authorFrancis Murtagh <francis.murtagh@arm.com>
Wed, 4 Sep 2019 15:42:29 +0000 (16:42 +0100)
committerFrancis Murtagh <francis.murtagh@arm.com>
Wed, 4 Sep 2019 15:42:29 +0000 (16:42 +0100)
 * Create CaptureData and Holder classes
 * Add unit test

Signed-off-by: Ellen Norris-Thompson <ellen.norris-thompson@arm.com>
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: I9f2766a8a6081ae4f9988904af2ca24cd434ebca

CMakeLists.txt
src/profiling/Holder.cpp [new file with mode: 0644]
src/profiling/Holder.hpp [new file with mode: 0644]
src/profiling/test/ProfilingTests.cpp

index a285c36a6882b39c766426cf1cabece05cbba34d..547853605027d647e0dd0285036fdd5a2d5ba18e 100644 (file)
@@ -420,6 +420,8 @@ list(APPEND armnn_sources
     src/profiling/CounterDirectory.cpp
     src/profiling/CounterDirectory.hpp
     src/profiling/EncodeVersion.hpp
+    src/profiling/Holder.cpp
+    src/profiling/Holder.hpp
     src/profiling/IProfilingConnection.hpp
     src/profiling/Packet.cpp
     src/profiling/Packet.hpp
diff --git a/src/profiling/Holder.cpp b/src/profiling/Holder.cpp
new file mode 100644 (file)
index 0000000..9def49d
--- /dev/null
@@ -0,0 +1,57 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Holder.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+CaptureData& CaptureData::operator= (const CaptureData& captureData)
+{
+    m_CapturePeriod = captureData.m_CapturePeriod;
+    m_CounterIds    = captureData.m_CounterIds;
+
+    return *this;
+}
+
+void CaptureData::SetCapturePeriod(uint32_t capturePeriod)
+{
+    m_CapturePeriod = capturePeriod;
+}
+
+void CaptureData::SetCounterIds(std::vector<uint16_t>& counterIds)
+{
+    m_CounterIds = counterIds;
+}
+
+std::uint32_t CaptureData::GetCapturePeriod() const
+{
+    return m_CapturePeriod;
+}
+
+std::vector<uint16_t> CaptureData::GetCounterIds() const
+{
+    return m_CounterIds;
+}
+
+CaptureData Holder::GetCaptureData() const
+{
+    std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
+    return m_CaptureData;
+}
+
+void Holder::SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+{
+    std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
+    m_CaptureData.SetCapturePeriod(capturePeriod);
+    m_CaptureData.SetCounterIds(counterIds);
+}
+
+} // namespace profiling
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/profiling/Holder.hpp b/src/profiling/Holder.hpp
new file mode 100644 (file)
index 0000000..c22c72a
--- /dev/null
@@ -0,0 +1,53 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <mutex>
+#include <vector>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class CaptureData
+{
+public:
+    CaptureData()
+        : m_CapturePeriod(0), m_CounterIds() {};
+    CaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+        : m_CapturePeriod(capturePeriod), m_CounterIds(counterIds) {};
+    CaptureData(const CaptureData& captureData)
+        : m_CapturePeriod(captureData.m_CapturePeriod), m_CounterIds(captureData.m_CounterIds) {};
+
+    CaptureData& operator= (const CaptureData& captureData);
+
+    void SetCapturePeriod(uint32_t capturePeriod);
+    void SetCounterIds(std::vector<uint16_t>& counterIds);
+    uint32_t GetCapturePeriod() const;
+    std::vector<uint16_t> GetCounterIds() const;
+
+private:
+    uint32_t m_CapturePeriod;
+    std::vector<uint16_t> m_CounterIds;
+};
+
+class Holder
+{
+public:
+    Holder()
+        : m_CaptureData() {};
+    CaptureData GetCaptureData() const;
+    void SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds);
+
+private:
+    mutable std::mutex m_CaptureThreadMutex;
+    CaptureData m_CaptureData;
+};
+
+} // namespace profiling
+
+} // namespace armnn
index ce278abee7832640cf2459fde3e4ccd9fe4c81e0..c7b0bda0ff2ce25bb2703029be7bf2aba064fff4 100644 (file)
@@ -7,6 +7,7 @@
 #include "../CommandHandlerFunctor.hpp"
 #include "../CommandHandlerRegistry.hpp"
 #include "../EncodeVersion.hpp"
+#include "../Holder.hpp"
 #include "../Packet.hpp"
 #include "../PacketVersionResolver.hpp"
 #include "../ProfilingStateMachine.hpp"
@@ -357,4 +358,98 @@ BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
     BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
 }
 
+void CaptureDataWriteThreadImpl(Holder &holder, uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+{
+    holder.SetCaptureData(capturePeriod, counterIds);
+}
+
+void CaptureDataReadThreadImpl(Holder &holder, CaptureData& captureData)
+{
+    captureData = holder.GetCaptureData();
+}
+
+BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder)
+{
+    std::vector<uint16_t> counterIds1 = {};
+    uint32_t capturePeriod1(1);
+    std::vector<uint16_t> counterIds2 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+    uint32_t capturePeriod2(2);
+    std::vector<uint16_t> counterIds3 = {4, 5, 5, 6};
+    uint32_t capturePeriod3(3);
+
+    // Check CaptureData functions
+    CaptureData capture;
+    BOOST_CHECK(capture.GetCapturePeriod() == 0);
+    BOOST_CHECK((capture.GetCounterIds()).empty());
+    capture.SetCapturePeriod(capturePeriod2);
+    capture.SetCounterIds(counterIds2);
+    BOOST_CHECK(capture.GetCapturePeriod() == capturePeriod2);
+    BOOST_CHECK(capture.GetCounterIds() == counterIds2);
+
+    Holder holder;
+    BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0);
+    BOOST_CHECK(((holder.GetCaptureData()).GetCounterIds()).empty());
+
+    // Check Holder functions
+    std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod3, std::ref(counterIds3));
+    thread1.join();
+
+    BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == capturePeriod3);
+    BOOST_CHECK((holder.GetCaptureData()).GetCounterIds() == counterIds3);
+
+    CaptureData captureData;
+    std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
+    thread2.join();
+    BOOST_CHECK(captureData.GetCounterIds() == counterIds3);
+
+    std::thread thread3(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod2, std::ref(counterIds1));
+    std::thread thread4(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
+    std::thread thread5(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod1, std::ref(counterIds2));
+    thread3.join();
+    thread4.join();
+    thread5.join();
+
+    // Check CaptureData was written/read correctly from multiple threads
+    std::vector<uint16_t> captureIds = captureData.GetCounterIds();
+    uint32_t capturePeriod = captureData.GetCapturePeriod();
+    if (captureIds == counterIds1)
+    {
+        BOOST_CHECK(capturePeriod == capturePeriod2);
+    }
+    else if (captureIds == counterIds2)
+    {
+        BOOST_CHECK(capturePeriod == capturePeriod1);
+    }
+    else
+    {
+        BOOST_ERROR("Error in CaptureData read/write.");
+    }
+
+    std::vector<uint16_t> readIds = holder.GetCaptureData().GetCounterIds();
+    BOOST_CHECK(readIds == counterIds1 || readIds == counterIds2);
+
+    // Check assignment operator
+    CaptureData assignableCaptureData;
+    assignableCaptureData.SetCapturePeriod(capturePeriod3);
+    assignableCaptureData.SetCounterIds(counterIds3);
+
+    CaptureData secondCaptureData;
+    secondCaptureData.SetCapturePeriod(capturePeriod2);
+    secondCaptureData.SetCounterIds(counterIds2);
+
+    BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 2);
+    BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds2);
+
+    secondCaptureData = assignableCaptureData;
+    BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 3);
+    BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds3);
+
+    // Check copy constructor
+    CaptureData copyConstructedCaptureData(assignableCaptureData);
+
+    BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 3);
+    BOOST_CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds3);
+
+}
+
 BOOST_AUTO_TEST_SUITE_END()