--- /dev/null
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Holder.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+CaptureData& CaptureData::operator= (const CaptureData& captureData)
+{
+ m_CapturePeriod = captureData.m_CapturePeriod;
+ m_CounterIds = captureData.m_CounterIds;
+
+ return *this;
+}
+
+void CaptureData::SetCapturePeriod(uint32_t capturePeriod)
+{
+ m_CapturePeriod = capturePeriod;
+}
+
+void CaptureData::SetCounterIds(std::vector<uint16_t>& counterIds)
+{
+ m_CounterIds = counterIds;
+}
+
+std::uint32_t CaptureData::GetCapturePeriod() const
+{
+ return m_CapturePeriod;
+}
+
+std::vector<uint16_t> CaptureData::GetCounterIds() const
+{
+ return m_CounterIds;
+}
+
+CaptureData Holder::GetCaptureData() const
+{
+ std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
+ return m_CaptureData;
+}
+
+void Holder::SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+{
+ std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
+ m_CaptureData.SetCapturePeriod(capturePeriod);
+ m_CaptureData.SetCounterIds(counterIds);
+}
+
+} // namespace profiling
+
+} // namespace armnn
\ No newline at end of file
--- /dev/null
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <mutex>
+#include <vector>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class CaptureData
+{
+public:
+ CaptureData()
+ : m_CapturePeriod(0), m_CounterIds() {};
+ CaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+ : m_CapturePeriod(capturePeriod), m_CounterIds(counterIds) {};
+ CaptureData(const CaptureData& captureData)
+ : m_CapturePeriod(captureData.m_CapturePeriod), m_CounterIds(captureData.m_CounterIds) {};
+
+ CaptureData& operator= (const CaptureData& captureData);
+
+ void SetCapturePeriod(uint32_t capturePeriod);
+ void SetCounterIds(std::vector<uint16_t>& counterIds);
+ uint32_t GetCapturePeriod() const;
+ std::vector<uint16_t> GetCounterIds() const;
+
+private:
+ uint32_t m_CapturePeriod;
+ std::vector<uint16_t> m_CounterIds;
+};
+
+class Holder
+{
+public:
+ Holder()
+ : m_CaptureData() {};
+ CaptureData GetCaptureData() const;
+ void SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds);
+
+private:
+ mutable std::mutex m_CaptureThreadMutex;
+ CaptureData m_CaptureData;
+};
+
+} // namespace profiling
+
+} // namespace armnn
#include "../CommandHandlerFunctor.hpp"
#include "../CommandHandlerRegistry.hpp"
#include "../EncodeVersion.hpp"
+#include "../Holder.hpp"
#include "../Packet.hpp"
#include "../PacketVersionResolver.hpp"
#include "../ProfilingStateMachine.hpp"
BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
}
+void CaptureDataWriteThreadImpl(Holder &holder, uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+{
+ holder.SetCaptureData(capturePeriod, counterIds);
+}
+
+void CaptureDataReadThreadImpl(Holder &holder, CaptureData& captureData)
+{
+ captureData = holder.GetCaptureData();
+}
+
+BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder)
+{
+ std::vector<uint16_t> counterIds1 = {};
+ uint32_t capturePeriod1(1);
+ std::vector<uint16_t> counterIds2 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+ uint32_t capturePeriod2(2);
+ std::vector<uint16_t> counterIds3 = {4, 5, 5, 6};
+ uint32_t capturePeriod3(3);
+
+ // Check CaptureData functions
+ CaptureData capture;
+ BOOST_CHECK(capture.GetCapturePeriod() == 0);
+ BOOST_CHECK((capture.GetCounterIds()).empty());
+ capture.SetCapturePeriod(capturePeriod2);
+ capture.SetCounterIds(counterIds2);
+ BOOST_CHECK(capture.GetCapturePeriod() == capturePeriod2);
+ BOOST_CHECK(capture.GetCounterIds() == counterIds2);
+
+ Holder holder;
+ BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0);
+ BOOST_CHECK(((holder.GetCaptureData()).GetCounterIds()).empty());
+
+ // Check Holder functions
+ std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod3, std::ref(counterIds3));
+ thread1.join();
+
+ BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == capturePeriod3);
+ BOOST_CHECK((holder.GetCaptureData()).GetCounterIds() == counterIds3);
+
+ CaptureData captureData;
+ std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
+ thread2.join();
+ BOOST_CHECK(captureData.GetCounterIds() == counterIds3);
+
+ std::thread thread3(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod2, std::ref(counterIds1));
+ std::thread thread4(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
+ std::thread thread5(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod1, std::ref(counterIds2));
+ thread3.join();
+ thread4.join();
+ thread5.join();
+
+ // Check CaptureData was written/read correctly from multiple threads
+ std::vector<uint16_t> captureIds = captureData.GetCounterIds();
+ uint32_t capturePeriod = captureData.GetCapturePeriod();
+ if (captureIds == counterIds1)
+ {
+ BOOST_CHECK(capturePeriod == capturePeriod2);
+ }
+ else if (captureIds == counterIds2)
+ {
+ BOOST_CHECK(capturePeriod == capturePeriod1);
+ }
+ else
+ {
+ BOOST_ERROR("Error in CaptureData read/write.");
+ }
+
+ std::vector<uint16_t> readIds = holder.GetCaptureData().GetCounterIds();
+ BOOST_CHECK(readIds == counterIds1 || readIds == counterIds2);
+
+ // Check assignment operator
+ CaptureData assignableCaptureData;
+ assignableCaptureData.SetCapturePeriod(capturePeriod3);
+ assignableCaptureData.SetCounterIds(counterIds3);
+
+ CaptureData secondCaptureData;
+ secondCaptureData.SetCapturePeriod(capturePeriod2);
+ secondCaptureData.SetCounterIds(counterIds2);
+
+ BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 2);
+ BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds2);
+
+ secondCaptureData = assignableCaptureData;
+ BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 3);
+ BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds3);
+
+ // Check copy constructor
+ CaptureData copyConstructedCaptureData(assignableCaptureData);
+
+ BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 3);
+ BOOST_CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds3);
+
+}
+
BOOST_AUTO_TEST_SUITE_END()