Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / test / BatchNormTestImpl.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include <armnn/ArmNN.hpp>
8 #include <armnn/Tensor.hpp>
9 #include <backends/WorkloadInfo.hpp>
10
11 #include "test/TensorHelpers.hpp"
12
13 #include "backends/CpuTensorHandle.hpp"
14 #include "backends/WorkloadFactory.hpp"
15
16 #include "backends/test/QuantizeHelper.hpp"
17
18
19 template<typename T>
20 LayerTestResult<T,4> BatchNormTestImpl(armnn::IWorkloadFactory& workloadFactory,
21                                    float qScale,
22                                    int32_t qOffset)
23 {
24     const unsigned int width    = 2;
25     const unsigned int height   = 3;
26     const unsigned int channels = 2;
27     const unsigned int num      = 1;
28
29     armnn::TensorInfo inputTensorInfo({num, channels, height, width}, armnn::GetDataType<T>());
30     armnn::TensorInfo outputTensorInfo({num, channels, height, width}, armnn::GetDataType<T>());
31     armnn::TensorInfo tensorInfo({channels}, armnn::GetDataType<T>());
32
33     // Set quantization parameters if the requested type is a quantized type.
34     if(armnn::IsQuantizedType<T>())
35     {
36         inputTensorInfo.SetQuantizationScale(qScale);
37         inputTensorInfo.SetQuantizationOffset(qOffset);
38         outputTensorInfo.SetQuantizationScale(qScale);
39         outputTensorInfo.SetQuantizationOffset(qOffset);
40         tensorInfo.SetQuantizationScale(qScale);
41         tensorInfo.SetQuantizationOffset(qOffset);
42     }
43
44     auto input = MakeTensor<T, 4>(inputTensorInfo,
45         QuantizedVector<T>(qScale, qOffset,
46         {
47             1.f, 4.f,
48             4.f, 2.f,
49             1.f, 6.f,
50
51             1.f, 1.f,
52             4.f, 1.f,
53             -2.f, 4.f
54         }));
55     // These values are per-channel of the input.
56     auto mean     = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {3, -2}));
57     auto variance = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {4, 9}));
58     auto beta     = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {3, 2}));
59     auto gamma    = MakeTensor<T, 1>(tensorInfo, QuantizedVector<T>(qScale, qOffset, {2, 1}));
60     LayerTestResult<T,4> ret(outputTensorInfo);
61
62     std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
63     std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
64
65     armnn::BatchNormalizationQueueDescriptor data;
66     armnn::WorkloadInfo info;
67     armnn::ScopedCpuTensorHandle meanTensor(tensorInfo);
68     armnn::ScopedCpuTensorHandle varianceTensor(tensorInfo);
69     armnn::ScopedCpuTensorHandle betaTensor(tensorInfo);
70     armnn::ScopedCpuTensorHandle gammaTensor(tensorInfo);
71
72     AllocateAndCopyDataToITensorHandle(&meanTensor, &mean[0]);
73     AllocateAndCopyDataToITensorHandle(&varianceTensor, &variance[0]);
74     AllocateAndCopyDataToITensorHandle(&betaTensor, &beta[0]);
75     AllocateAndCopyDataToITensorHandle(&gammaTensor, &gamma[0]);
76
77     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
78     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
79     data.m_Mean             = &meanTensor;
80     data.m_Variance         = &varianceTensor;
81     data.m_Beta             = &betaTensor;
82     data.m_Gamma            = &gammaTensor;
83     data.m_Parameters.m_Eps = 0.0f;
84
85     // For each channel:
86     // substract mean, divide by standard deviation (with an epsilon to avoid div by 0),
87     // multiply by gamma and add beta
88     ret.outputExpected = MakeTensor<T, 4>(outputTensorInfo,
89         QuantizedVector<T>(qScale, qOffset,
90         {
91             1.f, 4.f,
92             4.f, 2.f,
93             1.f, 6.f,
94
95             3.f, 3.f,
96             4.f, 3.f,
97             2.f, 4.f
98         }));
99
100     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateBatchNormalization(data, info);
101
102     inputHandle->Allocate();
103     outputHandle->Allocate();
104
105     CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
106
107     workload->Execute();
108
109     CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
110
111     return ret;
112 }