Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / libs / misc / include / misc / tensor / Object.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * @file Object.h
19  * @ingroup COM_AI_RUNTIME
20  * @brief This file contains nnfw::misc::tensor::Object class
21  */
22
23 #ifndef __NNFW_MISC_TENSOR_OBJECT_H__
24 #define __NNFW_MISC_TENSOR_OBJECT_H__
25
26 #include "misc/tensor/Shape.h"
27 #include "misc/tensor/Index.h"
28 #include "misc/tensor/IndexIterator.h"
29 #include "misc/tensor/NonIncreasingStride.h"
30 #include "misc/tensor/Reader.h"
31
32 #include <vector>
33
34 namespace nnfw
35 {
36 namespace misc
37 {
38 namespace tensor
39 {
40
41 /**
42  * @brief Class to build a tensor using specific generator
43  * @tparam T  Type of tensor element
44  */
45
46 template <typename T> class Object final : public Reader<T>
47 {
48 public:
49   /**
50    * @brief Function to generate tensor element
51    */
52   using Generator = std::function<T(const Shape &shape, const Index &index)>;
53
54 public:
55   /**
56    * @brief Construct a new @c Object object
57    * @param[in] shape   Tensor shape
58    * @param[in] fn      Function to generate tensor elements
59    */
60   Object(const Shape &shape, const Generator &fn) : _shape{shape}
61   {
62     // Set 'stride'
63     _stride.init(shape);
64
65     // Handle scalar object
66     if (shape.rank() == 0)
67     {
68       _values.resize(1);
69       _values.at(0) = fn(_shape, 0);
70     }
71     else
72     {
73       // Pre-allocate buffer
74       _values.resize(_shape.dim(0) * _stride.at(0));
75
76       // Set 'value'
77       iterate(_shape) <<
78         [this, &fn](const Index &index) { _values.at(_stride.offset(index)) = fn(_shape, index); };
79     }
80   }
81
82 public:
83   /**
84    * @brief Get reference of shape
85    * @return Reference of shape
86    */
87   const Shape &shape(void) const { return _shape; }
88
89 public:
90   /**
91    * @brief Get and element of tensor
92    * @param[in] index   Index of a tensor element
93    * @return Value of tensor element
94    */
95   T at(const Index &index) const override { return _values.at(_stride.offset(index)); }
96
97 private:
98   Shape _shape;
99   NonIncreasingStride _stride;
100
101 private:
102   std::vector<T> _values;
103 };
104
105 } // namespace tensor
106 } // namespace misc
107 } // namespace nnfw
108
109 #endif // __NNFW_MISC_FEATURE_OBJECT_H__