2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
6 #if (defined(__aarch64__)) || (defined(__x86_64__)) // disable test failing on FireFly/Armv7
8 #include "arm_compute/runtime/CL/CLScheduler.h"
9 #include "backends/ClContextControl.hpp"
10 #include "backends/ClWorkloadFactory.hpp"
11 #include "backends/CpuTensorHandle.hpp"
12 #include <boost/format.hpp>
14 #include "OpenClTimer.hpp"
15 #include "backends/test/TensorCopyUtils.hpp"
16 #include "TensorHelpers.hpp"
17 #include <boost/test/unit_test.hpp>
18 #include "backends/WorkloadFactory.hpp"
19 #include "backends/test/WorkloadTestUtils.hpp"
21 using namespace armnn;
25 // Initialising ClContextControl to ensure OpenCL is loaded correctly for each test case.
26 // NOTE: Profiling needs to be enabled in ClContextControl to be able to obtain execution
27 // times from OpenClTimer.
// NOTE(review): the second constructor argument (true) is presumably the profiling flag
// mentioned above -- confirm against the ClContextControl declaration.
28 OpenClFixture() : m_ClContextControl(nullptr, true) {}
// Context control member initialised by the fixture constructor.
31 ClContextControl m_ClContextControl;
// Test suite using OpenClFixture as its fixture. Its single test case builds a batch normalization
// workload through ClWorkloadFactory and checks the name, the number of measurements and the
// measured value reported by an OpenClTimer.
34 BOOST_FIXTURE_TEST_SUITE(OpenClTimerBatchNorm, OpenClFixture)
// Alias for the workload factory type used by this suite.
35 using FactoryType = ClWorkloadFactory;
37 BOOST_AUTO_TEST_CASE(OpenClTimerBatchNorm)
39 ClWorkloadFactory workloadFactory;
// Dimensions used to build the input/output tensor shape {num, channels, height, width}.
41 const unsigned int width = 2;
42 const unsigned int height = 3;
43 const unsigned int channels = 2;
44 const unsigned int num = 1;
48 TensorInfo inputTensorInfo({num, channels, height, width}, GetDataType<float>());
49 TensorInfo outputTensorInfo({num, channels, height, width}, GetDataType<float>());
// One-dimensional info with `channels` elements, shared by mean, variance, beta and gamma below.
50 TensorInfo tensorInfo({channels}, GetDataType<float>());
52 // Set quantization parameters if the requested type is a quantized type.
// NOTE(review): the type tested here is float, so this branch is presumably never taken -- confirm.
53 if(IsQuantizedType<float>())
55 inputTensorInfo.SetQuantizationScale(qScale);
56 inputTensorInfo.SetQuantizationOffset(qOffset);
57 outputTensorInfo.SetQuantizationScale(qScale);
58 outputTensorInfo.SetQuantizationOffset(qOffset);
59 tensorInfo.SetQuantizationScale(qScale);
60 tensorInfo.SetQuantizationOffset(qOffset);
63 auto input = MakeTensor<float, 4>(inputTensorInfo,
64 QuantizedVector<float>(qScale, qOffset,
74 // these values are per-channel of the input
75 auto mean = MakeTensor<float, 1>(tensorInfo, QuantizedVector<float>(qScale, qOffset, {3, -2}));
76 auto variance = MakeTensor<float, 1>(tensorInfo, QuantizedVector<float>(qScale, qOffset, {4, 9}));
77 auto beta = MakeTensor<float, 1>(tensorInfo, QuantizedVector<float>(qScale, qOffset, {3, 2}));
78 auto gamma = MakeTensor<float, 1>(tensorInfo, QuantizedVector<float>(qScale, qOffset, {2, 1}));
// Input and output tensor handles are created by the CL workload factory.
80 std::unique_ptr<ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
81 std::unique_ptr<ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
83 BatchNormalizationQueueDescriptor data;
// Handles holding the batch normalization parameters; each is filled from the tensors built above.
85 ScopedCpuTensorHandle meanTensor(tensorInfo);
86 ScopedCpuTensorHandle varianceTensor(tensorInfo);
87 ScopedCpuTensorHandle betaTensor(tensorInfo);
88 ScopedCpuTensorHandle gammaTensor(tensorInfo);
90 AllocateAndCopyDataToITensorHandle(&meanTensor, &mean[0]);
91 AllocateAndCopyDataToITensorHandle(&varianceTensor, &variance[0]);
92 AllocateAndCopyDataToITensorHandle(&betaTensor, &beta[0]);
93 AllocateAndCopyDataToITensorHandle(&gammaTensor, &gamma[0]);
// Register the input/output handles and the parameter handles with the queue descriptor.
95 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
96 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
97 data.m_Mean = &meanTensor;
98 data.m_Variance = &varianceTensor;
99 data.m_Beta = &betaTensor;
100 data.m_Gamma = &gammaTensor;
// Epsilon is explicitly set to zero for this test.
101 data.m_Parameters.m_Eps = 0.0f;
104 // Batch normalization: subtract mean, divide by standard deviation (with an epsilon to avoid div by 0),
105 // then multiply by gamma and add beta.
106 std::unique_ptr<IWorkload> workload = workloadFactory.CreateBatchNormalization(data, info);
// The handles are allocated only after the workload has been created; the input data is copied in afterwards.
108 inputHandle->Allocate();
109 outputHandle->Allocate();
111 CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
113 OpenClTimer openClTimer;
// The timer is expected to report this fixed name.
115 BOOST_CHECK_EQUAL(openClTimer.GetName(), "OpenClKernelTimer");
120 //Execute the workload
// Exactly one measurement is expected from the timer.
126 BOOST_CHECK_EQUAL(openClTimer.GetMeasurements().size(), 1);
// The first measurement's name must match this exact string.
128 BOOST_CHECK_EQUAL(openClTimer.GetMeasurements().front().m_Name,
129 "OpenClKernelTimer/0: batchnormalization_layer_nchw GWS[1,3,2]");
// The first measurement's value must be strictly positive.
131 BOOST_CHECK(openClTimer.GetMeasurements().front().m_Value > 0);
135 BOOST_AUTO_TEST_SUITE_END()
137 #endif //aarch64 or x86_64