cba4f1baf70b649384a12614fbf920f5619f0335
[platform/core/ml/nnfw.git] / runtime / libs / misc / include / misc / tensor / Object.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * @file Object.h
19  * @ingroup COM_AI_RUNTIME
20  * @brief This file contains nnfw::misc::tensor::Object class
21  */
22
23 #ifndef __NNFW_MISC_TENSOR_OBJECT_H__
24 #define __NNFW_MISC_TENSOR_OBJECT_H__
25
26 #include "misc/tensor/Shape.h"
27 #include "misc/tensor/Index.h"
28 #include "misc/tensor/IndexIterator.h"
29 #include "misc/tensor/NonIncreasingStride.h"
30 #include "misc/tensor/Reader.h"
31
32 #include <vector>
33
34 namespace nnfw
35 {
36 namespace misc
37 {
38 namespace tensor
39 {
40
41 /**
42  * @brief Class to build a tensor using specific generator
43  * @tparam T  Type of tensor element
44  */
45
46 template <typename T> class Object final : public Reader<T>
47 {
48 public:
49   /**
50    * @brief Function to generate tensor element
51    */
52   using Generator = std::function<T(const Shape &shape, const Index &index)>;
53
54 public:
55   /**
56    * @brief Construct a new @c Object object
57    * @param[in] shape   Tensor shape
58    * @param[in] fn      Function to generate tensor elements
59    */
60   Object(const Shape &shape, const Generator &fn) : _shape{shape}
61   {
62     // Set 'stride'
63     _stride.init(shape);
64
65     // Handle scalar object
66     if (shape.rank() == 0)
67     {
68       _values.resize(1);
69       _values.at(0) = fn(_shape, 0);
70     }
71     else
72     {
73       // Pre-allocate buffer
74       _values.resize(_shape.dim(0) * _stride.at(0));
75
76       // Set 'value'
77       iterate(_shape) << [this, &fn](const Index &index) {
78         _values.at(_stride.offset(index)) = fn(_shape, index);
79       };
80     }
81   }
82
83 public:
84   /**
85    * @brief Get reference of shape
86    * @return Reference of shape
87    */
88   const Shape &shape(void) const { return _shape; }
89
90 public:
91   /**
92    * @brief Get and element of tensor
93    * @param[in] index   Index of a tensor element
94    * @return Value of tensor element
95    */
96   T at(const Index &index) const override { return _values.at(_stride.offset(index)); }
97
98 private:
99   Shape _shape;
100   NonIncreasingStride _stride;
101
102 private:
103   std::vector<T> _values;
104 };
105
106 } // namespace tensor
107 } // namespace misc
108 } // namespace nnfw
109
110 #endif // __NNFW_MISC_FEATURE_OBJECT_H__