arm_compute v17.10
[platform/upstream/armcl.git] / arm_compute / graph / Tensor.h
1 /*
2  * Copyright (c) 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef __ARM_COMPUTE_GRAPH_TENSOR_H__
25 #define __ARM_COMPUTE_GRAPH_TENSOR_H__
26
27 #include "arm_compute/graph/ITensorAccessor.h"
28 #include "arm_compute/graph/Types.h"
29 #include "support/ToolchainSupport.h"
30
31 #include <memory>
32
33 namespace arm_compute
34 {
35 namespace graph
36 {
37 /** Tensor class */
38 class Tensor
39 {
40 public:
41     /** Constructor
42      *
43      * @param[in] info Tensor info to use
44      */
45     Tensor(TensorInfo &&info);
46     /** Constructor
47      *
48      * @param[in] accessor Tensor accessor
49      */
50     template <typename AccessorType>
51     Tensor(std::unique_ptr<AccessorType> accessor)
52         : _target(TargetHint::DONT_CARE), _info(), _accessor(std::move(accessor)), _tensor(nullptr)
53     {
54     }
55     /** Constructor
56      *
57      * @param[in] accessor Tensor accessor
58      */
59     template <typename AccessorType>
60     Tensor(AccessorType &&accessor)
61         : _target(TargetHint::DONT_CARE), _info(), _accessor(arm_compute::support::cpp14::make_unique<AccessorType>(std::forward<AccessorType>(accessor))), _tensor(nullptr)
62     {
63     }
64     /** Constructor
65      *
66      * @param[in] info     Tensor info to use
67      * @param[in] accessor Tensor accessor
68      */
69     template <typename AccessorType>
70     Tensor(TensorInfo &&info, AccessorType &&accessor)
71         : _target(TargetHint::DONT_CARE), _info(info), _accessor(arm_compute::support::cpp14::make_unique<AccessorType>(std::forward<AccessorType>(accessor))), _tensor(nullptr)
72     {
73     }
74     /** Default Destructor */
75     ~Tensor() = default;
76     /** Move Constructor
77      *
78      * @param[in] src Tensor to move
79      */
80     Tensor(Tensor &&src) noexcept;
81
82     /** Sets the given TensorInfo to the tensor
83      *
84      * @param[in] info TensorInfo to set
85      */
86     void set_info(TensorInfo &&info);
87     /** Calls accessor on tensor
88      *
89      * @return True if succeeds else false
90      */
91     bool call_accessor();
92     /** Sets target of the tensor
93      *
94      * @param[in] target Target where the tensor should be pinned in
95      *
96      * @return
97      */
98     ITensor *set_target(TargetHint target);
99     /** Returns tensor's TensorInfo
100      *
101      * @return TensorInfo of the tensor
102      */
103     const TensorInfo &info() const;
104     /** Returns a pointer to the internal tensor
105      *
106      * @return Tensor
107      */
108     ITensor *tensor();
109     /** Allocates and fills the tensor if needed */
110     void allocate_and_fill_if_needed();
111     /** Allocates the tensor */
112     void allocate();
113     /** Return the target that this tensor is pinned on
114      *
115      * @return Target of the tensor
116      */
117     TargetHint target() const;
118
119 private:
120     TargetHint                       _target;   /**< Target that this tensor is pinned on */
121     TensorInfo                       _info;     /**< Tensor metadata */
122     std::unique_ptr<ITensorAccessor> _accessor; /**< Tensor Accessor */
123     std::unique_ptr<ITensor>         _tensor;   /**< Tensor */
124 };
125 } // namespace graph
126 } // namespace arm_compute
127 #endif /* __ARM_COMPUTE_GRAPH_TENSOR_H__ */