Imported Upstream version 1.21.0
[platform/core/ml/nnfw.git] / compiler / luci-micro / luci-interpreter / src / kernels / TestUtils.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/TestUtils.h"
19
20 #include <stdexcept>
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26 namespace testing
27 {
28
29 using ::testing::FloatNear;
30 using ::testing::Matcher;
31
32 Tensor makeOutputTensor(DataType element_type) { return Tensor(element_type, {}, {}, ""); }
33
34 Tensor makeOutputTensor(DataType element_type, float scale, int32_t zero_point)
35 {
36   return Tensor(element_type, {}, {{scale}, {zero_point}}, "");
37 }
38
39 std::vector<float> dequantizeTensorData(const Tensor &tensor)
40 {
41   if (tensor.element_type() == DataType::U8)
42   {
43     std::vector<uint8_t> data = extractTensorData<uint8_t>(tensor);
44     return dequantize(data.data(), data.size(), tensor.scale(), tensor.zero_point());
45   }
46   if (tensor.element_type() == DataType::S8)
47   {
48     std::vector<int8_t> data = extractTensorData<int8_t>(tensor);
49     return dequantize(data.data(), data.size(), tensor.scale(), tensor.zero_point());
50   }
51   else if (tensor.element_type() == DataType::S16)
52   {
53     // S16 quantization is symmetric, so zero point should be zero.
54     for (auto zp : tensor.zero_points())
55     {
56       (void)zp;
57       assert(zp == 0);
58     }
59
60     std::vector<int16_t> data = extractTensorData<int16_t>(tensor);
61     if (tensor.scales().size() == 1)
62     {
63       return dequantize(data.data(), data.size(), tensor.scale(), 0);
64     }
65
66     // quantize_dimension breaks shape into two parts:
67     // inner dimensions that contains continuous data with one quantization type
68     // outer dimensions that contains other dimensions
69     const Shape shape = tensor.shape();
70     const int32_t quantized_dimension = tensor.quantized_dimension();
71     assert(quantized_dimension < shape.num_dims());
72     size_t outer_dims_size = 1;
73     int32_t quant_dim_size = shape.dim(quantized_dimension);
74     size_t inner_dims_size = 1;
75     assert(quant_dim_size == tensor.scales().size());
76
77     for (int i = 0; i < quantized_dimension; ++i)
78       outer_dims_size *= shape.dim(i);
79     for (int i = quantized_dimension + 1; i < shape.num_dims(); ++i)
80       inner_dims_size *= shape.dim(i);
81
82     assert(shape.num_elements() == outer_dims_size * quant_dim_size * inner_dims_size);
83
84     std::vector<float> dequantized_data;
85     dequantized_data.reserve(shape.num_elements());
86     for (size_t outer_it = 0; outer_it < outer_dims_size; ++outer_it)
87       for (int32_t channel = 0; channel < quant_dim_size; ++channel)
88       {
89         float scale = tensor.scales()[channel];
90         size_t offset = inner_dims_size * (quant_dim_size * outer_it + channel);
91         std::vector<float> part_dequantized_data =
92           dequantize(data.data() + offset, inner_dims_size, scale, 0);
93         dequantized_data.insert(dequantized_data.end(), part_dequantized_data.begin(),
94                                 part_dequantized_data.end());
95       }
96     return dequantized_data;
97   }
98   else
99   {
100     throw std::runtime_error("Unsupported type.");
101   }
102 }
103
104 Matcher<std::vector<float>> FloatArrayNear(const std::vector<float> &values, float max_abs_error)
105 {
106   std::vector<Matcher<float>> matchers;
107   matchers.reserve(values.size());
108   for (const float v : values)
109   {
110     matchers.emplace_back(FloatNear(v, max_abs_error));
111   }
112   return ElementsAreArray(matchers);
113 }
114
115 std::vector<int32_t> extractTensorShape(const Tensor &tensor)
116 {
117   std::vector<int32_t> result;
118   int dims = tensor.shape().num_dims();
119   for (int i = 0; i < dims; i++)
120   {
121     result.push_back(tensor.shape().dim(i));
122   }
123   return result;
124 }
125
126 } // namespace testing
127 } // namespace kernels
128 } // namespace luci_interpreter