2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <initializer_list>
25 const index_t MAX_DIMS = 8;
27 /** @brief Shape of Tensor object
29 * This class represents the size of each dimension of a multidimensional table
41 Shape(std::initializer_list<T> data): _dims(data.size())
43 assert(_dims <= MAX_DIMS);
44 index_t *dataPtr = _data;
47 *dataPtr++ = static_cast<index_t>(value);
51 Shape(const Shape &orig): _dims(orig._dims)
53 for (index_t i = 0; i < _dims; ++i)
54 _data[i] = orig._data[i];
57 Shape &operator=(const Shape &orig)
60 for (index_t i = 0; i < _dims; ++i)
61 _data[i] = orig._data[i];
65 /** Returns number of table dimensions*/
66 index_t getDims() const
71 /** Sets number of table dimensions*/
72 void setDims(index_t dims)
74 assert(dims < MAX_DIMS);
78 /** Returns size of selected dimension*/
79 index_t &operator[](index_t dim)
85 /** Returns size of selected dimension, constant version*/
86 index_t operator[](index_t dim) const
92 /** Returns number of elements in table*/
93 index_t getNumElems() const
96 for (index_t i = 0; i < _dims; ++i)
104 index_t _data[MAX_DIMS];
108 /** This class identifies a single cell of a table*/
111 /** @brief Multidimensional table
113 * This class represents a multidimensional table.
114 * It is used both for the NN model interface and for intermediate objects in the inference sequence.
119 Tensor(): Tensor(Shape{}){}
121 Tensor(Tensor &&orig): _shape(orig._shape), _data(orig._data), _managed(orig._managed)
123 orig._managed = false;
126 /** Constructs table, that references external data as its content*/
127 Tensor(const Shape& shape, float *data): _shape(shape), _data(data){}
129 Tensor(const Shape& shape): _shape(shape), _data(new float[shape.getNumElems()]), _managed(true) {}
137 /** Copies data from external source into table*/
138 void fillData(const float *data, const index_t num_elements)
141 std::memcpy(_data, data, num_elements * sizeof(float));
144 Tensor& operator=(const Tensor& t) {
156 // this tensor is not constant so we can write data into it
159 fillData(t._data, _shape.getNumElems());
165 /** Access element in table by index*/
166 float &at(const Index &idx)
168 return *(_data + getOffset(idx));
171 /** Access element in table by index, constant version*/
172 float at(const Index &idx) const
174 return *(_data + getOffset(idx));
177 void reshape(const Shape &shape)
179 index_t oldVolume = _shape.getNumElems();
181 if (_managed && oldVolume != shape.getNumElems())
183 float* new_data = new float[shape.getNumElems()];
185 std::swap(new_data, _data);
189 /** Free memory, set empty shape */
198 /** Returns pointer to raw data*/
204 /** Returns pointer to raw data, constant version*/
205 const float *getData() const
210 /** Returns size object of this table*/
211 const Shape &getShape() const
217 index_t getOffset(const Index &idx) const
219 assert(idx.getDims() == _shape.getDims());
222 for (index_t i = _shape.getDims() - 1; i >= 0; --i)
224 assert(idx[i] < _shape[i]);
225 offset += stride * idx[i];
233 bool _managed = false;