Release 18.02
[platform/upstream/armnn.git] / src / armnn / backends / test / PermuteTestImpl.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include <armnn/ArmNN.hpp>
8 #include <armnn/Tensor.hpp>
9 #include <armnn/TypesUtils.hpp>
10 #include <backends/WorkloadInfo.hpp>
11
12 #include "test/TensorHelpers.hpp"
13 #include "QuantizeHelper.hpp"
14
15 #include "backends/CpuTensorHandle.hpp"
16 #include "backends/WorkloadFactory.hpp"
17
18 template<typename T>
19 LayerTestResult<T, 4> SimplePermuteTestImpl(
20         armnn::IWorkloadFactory& workloadFactory,
21         armnn::PermuteDescriptor descriptor,
22         armnn::TensorInfo inputTensorInfo,
23         armnn::TensorInfo outputTensorInfo,
24         const std::vector<T>& inputData,
25         const std::vector<T>& outputExpectedData)
26 {
27     auto input = MakeTensor<T, 4>(inputTensorInfo, inputData);
28
29     LayerTestResult<T, 4> ret(outputTensorInfo);
30     ret.outputExpected = MakeTensor<T, 4>(outputTensorInfo, outputExpectedData);
31
32     std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
33     std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
34
35     armnn::PermuteQueueDescriptor data;
36     data.m_Parameters = descriptor;
37     armnn::WorkloadInfo info;
38     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
39     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
40
41     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreatePermute(data, info);
42
43     inputHandle->Allocate();
44     outputHandle->Allocate();
45
46     CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
47
48     workload->Execute();
49
50     CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
51
52     return ret;
53 }
54
55 LayerTestResult<float, 4> SimplePermuteFloat32TestCommon(armnn::IWorkloadFactory& workloadFactory)
56 {
57     armnn::TensorInfo inputTensorInfo;
58     armnn::TensorInfo outputTensorInfo;
59
60     unsigned int inputShape[] = { 1, 2, 2, 2 };
61     unsigned int outputShape[] = { 1, 2, 2, 2 };
62
63     armnn::PermuteDescriptor descriptor;
64     descriptor.m_DimMappings = {0U, 3U, 1U, 2U};
65
66     inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32);
67     outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::DataType::Float32);
68
69     std::vector<float> input = std::vector<float>(
70             {
71                     1.0f, 2.0f,
72                     3.0f, 4.0f,
73
74                     5.0f, 6.0f,
75                     7.0f, 8.0f
76             });
77
78     std::vector<float> outputExpected = std::vector<float>(
79             {
80                     1.0f, 5.0f, 2.0f, 6.0f,
81                     3.0f, 7.0f, 4.0f, 8.0f
82             });
83
84     return SimplePermuteTestImpl<float>(workloadFactory, descriptor, inputTensorInfo,
85                                         outputTensorInfo, input, outputExpected);
86 }
87
88 LayerTestResult<uint8_t, 4> SimplePermuteUint8TestCommon(armnn::IWorkloadFactory& workloadFactory)
89 {
90     armnn::TensorInfo inputTensorInfo;
91     armnn::TensorInfo outputTensorInfo;
92
93     unsigned int inputShape[] = { 1, 2, 2, 2 };
94     unsigned int outputShape[] = { 1, 2, 2, 2 };
95
96     armnn::PermuteDescriptor descriptor;
97     descriptor.m_DimMappings = {0U, 3U, 1U, 2U};
98
99     inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::QuantisedAsymm8);
100     inputTensorInfo.SetQuantizationScale(1.0f);
101     outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::DataType::QuantisedAsymm8);
102     outputTensorInfo.SetQuantizationScale(1.0f);
103
104     std::vector<uint8_t> input = std::vector<uint8_t>(
105             {
106                     1, 2,
107                     3, 4,
108
109                     5, 6,
110                     7, 8
111             });
112
113     std::vector<uint8_t> outputExpected = std::vector<uint8_t>(
114             {
115                     1, 5, 2, 6,
116                     3, 7, 4, 8
117             });
118
119     return SimplePermuteTestImpl<uint8_t>(workloadFactory, descriptor, inputTensorInfo,
120                                           outputTensorInfo, input, outputExpected);
121 }