Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / backend / controlflow / TensorRegistry.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__
18 #define __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__
19
20 #include "backend/cpu_common/TensorRegistry.h"
21 #include "backend/ITensorRegistry.h"
22 #include "Tensor.h"
23 #include "IOTensor.h"
24 #include <assert.h>
25
26 namespace onert
27 {
28 namespace backend
29 {
30 namespace controlflow
31 {
32
33 /**
34  * @brief Tensor registry class for controlflow backend
35  *
36  * This class contains three types of tensors. Two native tensors(tensors that are managed by this
37  * backend) and the other is migrant tensor.
38  *
39  * - NativeIOTensor  - @c IOTensor managed by this backend ( in @c _base_reg )
40  *     - NOTE The tensor it actually points to can be from another backend
41  * - NativeOwnTensor - @c cpu_common::Tensor managed by this backend ( in @c _base_reg )
42  * - MigrantTensor   - @c IPortableTensor managed by other backends
43  *
44  * @note @c _base_reg is used in implementation to reuse @c cpu_common::StaticTensorManager
45  *
46  */
47 class TensorRegistry : public ITensorRegistry
48 {
49 public:
50   TensorRegistry() : _base_reg{new cpu_common::TensorRegistry} {}
51
52   ITensor *getITensor(const ir::OperandIndex &ind) override
53   {
54     auto base_tensor = _base_reg->getITensor(ind);
55     if (base_tensor)
56       return base_tensor;
57     return getNativeIOTensor(ind);
58   }
59
60   ITensor *getNativeITensor(const ir::OperandIndex &ind) override
61   {
62     auto base_tensor = _base_reg->getNativeITensor(ind);
63     if (base_tensor)
64       return base_tensor;
65     return getNativeIOTensor(ind);
66   }
67
68   IPortableTensor *getPortableTensor(const ir::OperandIndex &ind)
69   {
70     auto base_tensor = _base_reg->getPortableTensor(ind);
71     if (base_tensor)
72       return base_tensor;
73     return getNativeIOTensor(ind);
74   }
75
76   IPortableTensor *getNativeTensor(const ir::OperandIndex &ind)
77   {
78     auto base_tensor = _base_reg->getNativeTensor(ind);
79     if (base_tensor)
80       return base_tensor;
81     return getNativeIOTensor(ind);
82   }
83
84   Tensor *getNativeOwnTensor(const ir::OperandIndex &ind)
85   {
86     return _base_reg->getNativeTensor(ind);
87   }
88
89   IOTensor *getNativeIOTensor(const ir::OperandIndex &ind)
90   {
91     auto tensor = _native_io_tensors.find(ind);
92     if (tensor != _native_io_tensors.end())
93       return tensor->second.get();
94     return nullptr;
95   }
96
97   bool setMigrantTensor(const ir::OperandIndex &ind, IPortableTensor *tensor) override
98   {
99     assert(tensor);
100     assert(!getITensor(ind)); // For the ind, tensor is not registered yet
101     _base_reg->setMigrantTensor(ind, tensor);
102     return true;
103   }
104
105   void setNativeOwnTensor(ir::OperandIndex ind, std::unique_ptr<Tensor> &&tensor)
106   {
107     assert(tensor);
108     assert(!getITensor(ind)); // For the ind, tensor is not registered yet
109     _base_reg->setNativeTensor(ind, std::move(tensor));
110   }
111
112   void setNativeIOTensor(ir::OperandIndex ind, std::unique_ptr<IOTensor> &&tensor)
113   {
114     assert(tensor);
115     assert(!getITensor(ind)); // For the ind, tensor is not registered yet
116     _native_io_tensors[ind] = std::move(tensor);
117   }
118
119   const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors()
120   {
121     return _native_io_tensors;
122   }
123   std::shared_ptr<cpu_common::TensorRegistry> base_reg() { return _base_reg; }
124
125 private:
126   std::shared_ptr<cpu_common::TensorRegistry> _base_reg;
127   ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors;
128 };
129
130 } // namespace controlflow
131 } // namespace backend
132 } // namespace onert
133
134 #endif // ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__