85551312490ccc18fa9878a8dd5891e7a700d78e
[platform/core/ml/nnfw.git] / runtime / onert / core / include / backend / ITensorRegistry.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_BACKEND_ITENSOR_REGISTRY__
18 #define __ONERT_BACKEND_ITENSOR_REGISTRY__
19
20 #include <memory>
21
22 #include "ir/Index.h"
23 #include "backend/ITensor.h"
24
25 namespace onert
26 {
27 namespace backend
28 {
29
30 struct ITensorRegistry
31 {
32   /**
33    * @brief Deconstruct itself
34    */
35   virtual ~ITensorRegistry() = default;
36
37   /**
38    * @brief Returns pointer of ITensor among native and migrant tensors
39    *
40    * Native Tensor is a tensor that is managed by this backend
41    * Migrant Tensor is a tensor that is imported from another backend
42    *
43    * @note  Return tensor cannot be used longer than dynamic tensor manager
44    */
45   virtual std::shared_ptr<ITensor> getITensor(const ir::OperandIndex &) = 0;
46   /**
47    * @brief Returns pointer of ITensor among native tensors
48    *
49    * Unlike @c getITensor , this function only searches from native tensors
50    *
51    * @note  Returned tensor cannot be used longer than dynamic tensor manager
52    */
53   virtual std::shared_ptr<ITensor> getNativeITensor(const ir::OperandIndex &) = 0;
54 };
55
56 } // namespace backend
57 } // namespace onert
58
59 #include "ir/OperandIndexMap.h"
60 #include "backend/IPortableTensor.h"
61
62 namespace onert
63 {
64 namespace backend
65 {
66
67 /**
68  * @brief  TensorRegistry template class for the convenience of backend implementations
69  *
70  * If a backend uses @c IPortableTensor , and there is no special reason to implement @c
71  * ITensorRegistry on your own, you may just use this default implementation.
72  *
73  * @tparam T_Tensor Tensor type. Must be a subclass of @c onert::backend::IPortableTensor .
74  */
75 template <typename T_Tensor> class PortableTensorRegistryTemplate : public ITensorRegistry
76 {
77 public:
78   std::shared_ptr<ITensor> getITensor(const ir::OperandIndex &ind) override
79   {
80     static_assert(std::is_base_of<ITensor, T_Tensor>::value, "T_Tensor must derive from ITensor.");
81     auto external_tensor = _migrant.find(ind);
82     if (external_tensor != _migrant.end())
83       return external_tensor->second;
84     return getNativeTensor(ind);
85   }
86
87   std::shared_ptr<ITensor> getNativeITensor(const ir::OperandIndex &ind) override
88   {
89     return getNativeTensor(ind);
90   }
91
92   std::shared_ptr<IPortableTensor> getPortableTensor(const ir::OperandIndex &ind)
93   {
94     auto external_tensor = _migrant.find(ind);
95     if (external_tensor != _migrant.end())
96     {
97       if (external_tensor->second)
98         return external_tensor->second;
99     }
100     return getNativeTensor(ind);
101   }
102
103   std::shared_ptr<T_Tensor> getNativeTensor(const ir::OperandIndex &ind)
104   {
105     auto tensor = _native.find(ind);
106     if (tensor != _native.end())
107       return tensor->second;
108     return nullptr;
109   }
110
111   bool setMigrantTensor(const ir::OperandIndex &ind, const std::shared_ptr<IPortableTensor> &tensor)
112   {
113     // TODO Uncomment this as two tensors for an index is not allowed.
114     //      But now it is temporarily allowed as a workaround. External one hides Managed one.
115     // auto itr = _native.find(ind);
116     // if (itr != _native.end() && itr->second != nullptr && tensor != nullptr)
117     //  throw std::runtime_error{
118     //      "Tried to set an migrant tensor but an native tensor already exists."};
119     _migrant[ind] = tensor;
120     return true;
121   }
122
123   void setNativeTensor(const ir::OperandIndex &ind, const std::shared_ptr<T_Tensor> &tensor)
124   {
125     auto itr = _migrant.find(ind);
126     if (itr != _migrant.end() && itr->second != nullptr && tensor != nullptr)
127       throw std::runtime_error{
128           "Tried to set a native tensor but an migrant tensor already exists."};
129     _native[ind] = tensor;
130   }
131
132   const ir::OperandIndexMap<std::shared_ptr<T_Tensor>> &native_tensors() { return _native; }
133
134   const ir::OperandIndexMap<std::shared_ptr<IPortableTensor>> &migrant_tensors()
135   {
136     return _migrant;
137   }
138
139 private:
140   ir::OperandIndexMap<std::shared_ptr<IPortableTensor>> _migrant;
141   ir::OperandIndexMap<std::shared_ptr<T_Tensor>> _native;
142 };
143
144 } // namespace backend
145 } // namespace onert
146
147 #endif // __ONERT_BACKEND_ITENSOR_REGISTRY__