Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / backend / builtin / train / TensorRegistry.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__
18 #define __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__
19
20 #include <backend/train/ITensorRegistry.h>
21
22 #include "../IOTensor.h"
23 #include "../Tensor.h"
24 #include "Tensor.h"
25
26 namespace onert
27 {
28 namespace backend
29 {
30 namespace builtin
31 {
32 namespace train
33 {
34
35 using BaseTensorRegistry =
36   backend::train::PortableTensorRegistryTemplate<Tensor, TrainableTensor, DerivativeTensor,
37                                                  GradientTensor>;
38
39 class TensorRegistry : public backend::train::ITensorRegistry
40 {
41 public:
42   TensorRegistry() : _base_reg{new BaseTensorRegistry} {}
43
44   ITensor *getITensor(const ir::OperandIndex &index) override
45   {
46     auto base_tensor = _base_reg->getITensor(index);
47     if (base_tensor)
48       return base_tensor;
49     return getNativeIOTensor(index);
50   }
51
52   ITensor *getNativeITensor(const ir::OperandIndex &index) override
53   {
54     auto base_tensor = _base_reg->getNativeITensor(index);
55     if (base_tensor)
56       return base_tensor;
57     return getNativeIOTensor(index);
58   }
59
60   IPortableTensor *getPortableTensor(const ir::OperandIndex &index)
61   {
62     auto base_tensor = _base_reg->getPortableTensor(index);
63     if (base_tensor)
64       return base_tensor;
65     return getNativeIOTensor(index);
66   }
67
68   IOTensor *getNativeIOTensor(const ir::OperandIndex &index)
69   {
70     auto tensor = _native_io_tensors.find(index);
71     if (tensor != _native_io_tensors.end())
72       return tensor->second.get();
73     return nullptr;
74   }
75
76   ITensor *getDerivativeITensor(const ir::OperandIndex &index) override
77   {
78     return _base_reg->getDerivativeTensor(index);
79   }
80
81   ITensor *getGradientITensor(const ir::OperandIndex &index) override
82   {
83     return _base_reg->getGradientTensor(index);
84   }
85
86   DerivativeTensor *getDerivativeTensor(const ir::OperandIndex &index)
87   {
88     return _base_reg->getDerivativeTensor(index);
89   }
90
91   bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
92   {
93     assert(tensor);
94     assert(!getITensor(index)); // For the index, tensor is not registered yet
95     _base_reg->setMigrantTensor(index, tensor);
96     return true;
97   }
98
99   void setDerivativeTensor(const ir::OperandIndex &index, std::unique_ptr<DerivativeTensor> tensor)
100   {
101     _base_reg->setDerivativeTensor(index, std::move(tensor));
102   }
103
104   void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr<GradientTensor> tensor)
105   {
106     _base_reg->setGradientTensor(index, std::move(tensor));
107   }
108
109   void setNativeIOTensor(ir::OperandIndex index, std::unique_ptr<IOTensor> &&tensor)
110   {
111     assert(tensor);
112     assert(!getITensor(index)); // For the index, tensor is not registered yet
113     _native_io_tensors[index] = std::move(tensor);
114   }
115
116   const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors()
117   {
118     return _native_io_tensors;
119   }
120   std::shared_ptr<BaseTensorRegistry> base_reg() { return _base_reg; }
121
122 private:
123   std::shared_ptr<BaseTensorRegistry> _base_reg;
124   ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors;
125 };
126
127 } // namespace train
128 } // namespace builtin
129 } // namespace backend
130 } // namespace onert
131
132 #endif // __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__