2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef LUCI_INTERPRETER_CORE_TENSOR_H
18 #define LUCI_INTERPRETER_CORE_TENSOR_H
20 #include "luci_interpreter/core/DataType.h"
29 namespace luci_interpreter
35 explicit Shape(int rank) : _dims(rank, 0) {}
37 Shape(std::initializer_list<int32_t> dims) : _dims(dims.begin(), dims.end()) {}
39 int num_dims() const { return _dims.size(); }
41 int32_t dim(int i) const
43 assert(i >= 0 && i < static_cast<int>(_dims.size()));
49 assert(i >= 0 && i < static_cast<int>(_dims.size()));
53 int32_t num_elements() const
56 for (const int32_t dim : _dims)
63 bool operator==(const Shape &other) const { return _dims == other._dims; }
65 bool operator!=(const Shape &other) const { return !operator==(other); }
68 std::vector<int32_t> _dims;
71 // Tensor affine quantization parameters.
73 // The relationship between real and quantized values:
74 // real_value = (quantized_value - zero_point) * scale
76 // In per-tensor case, 'scale' and 'zero_point' are one element each.
77 // In per-channel case, 'scale' and 'zero_point' are N elements each, where N is the size
78 // of the quantized dimension.
80 // Note that due to historical and performance reasons, per-tensor quantization uses unsigned
81 // integer types, while per-channel uses signed types assuming 'zero_point' == 0.
struct AffineQuantization
{
  /// Scale factor(s): one element in the per-tensor case, N in the per-channel case.
  std::vector<float> scale;
  /// Zero point(s): one element in the per-tensor case, N in the per-channel case.
  std::vector<int32_t> zero_point;
  /// Identifies the quantized dimension used in the per-channel case.
  /// Initialised to 0 so that a default-initialised object never carries an indeterminate
  /// value that would later be read back through the tensor's accessor.
  int32_t quantized_dimension = 0;
};
92 Tensor(DataType element_type, Shape shape, AffineQuantization quantization, std::string name);
94 DataType element_type() const { return _element_type; }
96 const Shape &shape() const { return _shape; }
100 assert(_quantization.scale.size() == 1);
101 return _quantization.scale[0];
104 float zero_point() const
106 assert(_quantization.zero_point.size() == 1);
107 return _quantization.zero_point[0];
110 const std::vector<float> &scales() const { return _quantization.scale; }
112 const std::vector<int32_t> &zero_points() const { return _quantization.zero_point; }
114 int32_t quantized_dimension() const { return _quantization.quantized_dimension; }
116 template <typename T> const T *data() const { return reinterpret_cast<const T *>(_data.get()); }
118 template <typename T> T *data() { return reinterpret_cast<T *>(_data.get()); }
120 const std::string &name() const { return _name; }
122 void readData(void *data_ptr, size_t data_size) const;
/// Non-const member taking a read-only source pointer and a size.
/// NOTE(review): defined out of line; presumably copies `data_size` bytes from `data_ptr`
/// into the tensor's buffer. Confirm against the implementation.
void writeData(const void *data_ptr, size_t data_size);
126 void resize(const Shape &new_shape);
129 DataType _element_type;
131 AffineQuantization _quantization;
132 std::unique_ptr<uint8_t[]> _data;
136 } // namespace luci_interpreter
138 #endif // LUCI_INTERPRETER_CORE_TENSOR_H