2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
19 * @ingroup COM_AI_RUNTIME
20 * @brief This file contains nnfw::misc::tensor::Object class
23 #ifndef __NNFW_MISC_TENSOR_OBJECT_H__
24 #define __NNFW_MISC_TENSOR_OBJECT_H__
26 #include "misc/tensor/Shape.h"
27 #include "misc/tensor/Index.h"
28 #include "misc/tensor/IndexIterator.h"
29 #include "misc/tensor/NonIncreasingStride.h"
30 #include "misc/tensor/Reader.h"
/**
 * @brief Class to build a tensor using a specific generator
 * @tparam T Type of tensor element
 *
 * Element values are produced by the generator passed to the constructor and
 * are read back through at().
 */
template <typename T> class Object final : public Reader<T>
50 * @brief Function to generate tensor element
52 using Generator = std::function<T(const Shape &shape, const Index &index)>;
/**
 * @brief Construct a new @c Object object
 * @param[in] shape Tensor shape
 * @param[in] fn Function to generate tensor elements; called as fn(shape, index)
 *               and its result is stored as the element at that index
 */
Object(const Shape &shape, const Generator &fn) : _shape{shape}
// NOTE(review): _stride is read below (_stride.at(0), _stride.offset()) but no
// visible line initializes it from shape; confirm that happens before this point.
// Handle scalar object: a rank-0 shape has a single element, generated at index 0
if (shape.rank() == 0)
// NOTE(review): std::vector::at throws std::out_of_range unless _values already
// holds an element here; confirm it is resized to 1 before this assignment.
_values.at(0) = fn(_shape, 0);
// Pre-allocate buffer: dim(0) * stride(0) is presumably the total element
// count (assuming stride(0) is the product of the remaining dims) — confirm
_values.resize(_shape.dim(0) * _stride.at(0));
// Visit each index of the shape and store the generated value at the offset
// computed by the stride
iterate(_shape) << [this, &fn](const Index &index) {
_values.at(_stride.offset(index)) = fn(_shape, index);
85 * @brief Get reference of shape
86 * @return Reference of shape
88 const Shape &shape(void) const { return _shape; }
92 * @brief Get and element of tensor
93 * @param[in] index Index of a tensor element
94 * @return Value of tensor element
96 T at(const Index &index) const override { return _values.at(_stride.offset(index)); }
// Maps a multi-dimensional Index to an offset into _values (see at()).
NonIncreasingStride _stride;
// Element storage, filled in the constructor with values returned by the generator.
std::vector<T> _values;
106 } // namespace tensor
#endif // __NNFW_MISC_TENSOR_OBJECT_H__