/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernels/TestUtils.h"

#include <cassert>
#include <stdexcept>
20 namespace luci_interpreter
27 using ::testing::FloatNear;
28 using ::testing::Matcher;
30 Tensor makeOutputTensor(DataType element_type) { return Tensor(element_type, {}, {}, ""); }
32 Tensor makeOutputTensor(DataType element_type, float scale, int32_t zero_point)
34 return Tensor(element_type, {}, {{scale}, {zero_point}}, "");
37 std::vector<float> dequantizeTensorData(const Tensor &tensor)
39 if (tensor.element_type() == DataType::U8)
41 std::vector<uint8_t> data = extractTensorData<uint8_t>(tensor);
42 return dequantize(data.data(), data.size(), tensor.scale(), tensor.zero_point());
44 if (tensor.element_type() == DataType::S8)
46 std::vector<int8_t> data = extractTensorData<int8_t>(tensor);
47 return dequantize(data.data(), data.size(), tensor.scale(), tensor.zero_point());
49 else if (tensor.element_type() == DataType::S16)
51 // S16 quantization is symmetric, so zero point should be zero.
52 for (auto zp : tensor.zero_points())
58 std::vector<int16_t> data = extractTensorData<int16_t>(tensor);
59 if (tensor.scales().size() == 1)
61 return dequantize(data.data(), data.size(), tensor.scale(), 0);
64 // quantize_dimension breaks shape into two parts:
65 // inner dimensions that contains continuous data with one quantization type
66 // outer dimensions that contains other dimensions
67 const Shape shape = tensor.shape();
68 const int32_t quantized_dimension = tensor.quantized_dimension();
69 assert(quantized_dimension < shape.num_dims());
70 size_t outer_dims_size = 1;
71 int32_t quant_dim_size = shape.dim(quantized_dimension);
72 size_t inner_dims_size = 1;
73 assert(quant_dim_size == tensor.scales().size());
75 for (int i = 0; i < quantized_dimension; ++i)
76 outer_dims_size *= shape.dim(i);
77 for (int i = quantized_dimension + 1; i < shape.num_dims(); ++i)
78 inner_dims_size *= shape.dim(i);
80 assert(shape.num_elements() == outer_dims_size * quant_dim_size * inner_dims_size);
82 std::vector<float> dequantized_data;
83 dequantized_data.reserve(shape.num_elements());
84 for (size_t outer_it = 0; outer_it < outer_dims_size; ++outer_it)
85 for (int32_t channel = 0; channel < quant_dim_size; ++channel)
87 float scale = tensor.scales()[channel];
88 size_t offset = inner_dims_size * (quant_dim_size * outer_it + channel);
89 std::vector<float> part_dequantized_data =
90 dequantize(data.data() + offset, inner_dims_size, scale, 0);
91 dequantized_data.insert(dequantized_data.end(), part_dequantized_data.begin(),
92 part_dequantized_data.end());
94 return dequantized_data;
98 assert(false && "Unsupported type.");
102 Matcher<std::vector<float>> FloatArrayNear(const std::vector<float> &values, float max_abs_error)
104 std::vector<Matcher<float>> matchers;
105 matchers.reserve(values.size());
106 for (const float v : values)
108 matchers.emplace_back(FloatNear(v, max_abs_error));
110 return ElementsAreArray(matchers);
113 std::vector<int32_t> extractTensorShape(const Tensor &tensor)
115 std::vector<int32_t> result;
116 int dims = tensor.shape().num_dims();
117 for (int i = 0; i < dims; i++)
119 result.push_back(tensor.shape().dim(i));
124 } // namespace testing
125 } // namespace kernels
126 } // namespace luci_interpreter