78caa373886d59cbdc5ed3fc209b072adf3cab60
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / TestUtils.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/TestUtils.h"
19
20 namespace luci_interpreter
21 {
22 namespace kernels
23 {
24 namespace testing
25 {
26
27 using ::testing::FloatNear;
28 using ::testing::Matcher;
29
30 Tensor makeOutputTensor(DataType element_type) { return Tensor(element_type, {}, {}, ""); }
31
32 Tensor makeOutputTensor(DataType element_type, float scale, int32_t zero_point)
33 {
34   return Tensor(element_type, {}, {{scale}, {zero_point}}, "");
35 }
36
37 std::vector<float> dequantizeTensorData(const Tensor &tensor)
38 {
39   if (tensor.element_type() == DataType::U8)
40   {
41     std::vector<uint8_t> data = extractTensorData<uint8_t>(tensor);
42     return dequantize(data.data(), data.size(), tensor.scale(), tensor.zero_point());
43   }
44   if (tensor.element_type() == DataType::S8)
45   {
46     std::vector<int8_t> data = extractTensorData<int8_t>(tensor);
47     return dequantize(data.data(), data.size(), tensor.scale(), tensor.zero_point());
48   }
49   else if (tensor.element_type() == DataType::S16)
50   {
51     // S16 quantization is symmetric, so zero point should be zero.
52     for (auto zp : tensor.zero_points())
53     {
54       (void)zp;
55       assert(zp == 0);
56     }
57
58     std::vector<int16_t> data = extractTensorData<int16_t>(tensor);
59     if (tensor.scales().size() == 1)
60     {
61       return dequantize(data.data(), data.size(), tensor.scale(), 0);
62     }
63
64     // quantize_dimension breaks shape into two parts:
65     // inner dimensions that contains continuous data with one quantization type
66     // outer dimensions that contains other dimensions
67     const Shape shape = tensor.shape();
68     const int32_t quantized_dimension = tensor.quantized_dimension();
69     assert(quantized_dimension < shape.num_dims());
70     size_t outer_dims_size = 1;
71     int32_t quant_dim_size = shape.dim(quantized_dimension);
72     size_t inner_dims_size = 1;
73     assert(quant_dim_size == tensor.scales().size());
74
75     for (int i = 0; i < quantized_dimension; ++i)
76       outer_dims_size *= shape.dim(i);
77     for (int i = quantized_dimension + 1; i < shape.num_dims(); ++i)
78       inner_dims_size *= shape.dim(i);
79
80     assert(shape.num_elements() == outer_dims_size * quant_dim_size * inner_dims_size);
81
82     std::vector<float> dequantized_data;
83     dequantized_data.reserve(shape.num_elements());
84     for (size_t outer_it = 0; outer_it < outer_dims_size; ++outer_it)
85       for (int32_t channel = 0; channel < quant_dim_size; ++channel)
86       {
87         float scale = tensor.scales()[channel];
88         size_t offset = inner_dims_size * (quant_dim_size * outer_it + channel);
89         std::vector<float> part_dequantized_data =
90           dequantize(data.data() + offset, inner_dims_size, scale, 0);
91         dequantized_data.insert(dequantized_data.end(), part_dequantized_data.begin(),
92                                 part_dequantized_data.end());
93       }
94     return dequantized_data;
95   }
96   else
97   {
98     assert(false && "Unsupported type.");
99   }
100 }
101
102 Matcher<std::vector<float>> FloatArrayNear(const std::vector<float> &values, float max_abs_error)
103 {
104   std::vector<Matcher<float>> matchers;
105   matchers.reserve(values.size());
106   for (const float v : values)
107   {
108     matchers.emplace_back(FloatNear(v, max_abs_error));
109   }
110   return ElementsAreArray(matchers);
111 }
112
113 std::vector<int32_t> extractTensorShape(const Tensor &tensor)
114 {
115   std::vector<int32_t> result;
116   int dims = tensor.shape().num_dims();
117   for (int i = 0; i < dims; i++)
118   {
119     result.push_back(tensor.shape().dim(i));
120   }
121   return result;
122 }
123
124 } // namespace testing
125 } // namespace kernels
126 } // namespace luci_interpreter