2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__
18 #define __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__
20 #include "backend/cpu_common/TensorRegistry.h"
21 #include "backend/ITensorRegistry.h"
34 * @brief Tensor registry class for controlflow backend
36 * This class contains three types of tensors: two kinds of native tensors (tensors that are
37 * managed by this backend) and one kind of migrant tensor.
39 * - NativeIOTensor - @c IOTensor managed by this backend ( in @c _native_io_tensors )
40 * - NOTE The tensor it actually points to can be from another backend
41 * - NativeOwnTensor - @c cpu_common::Tensor managed by this backend ( in @c _base_reg )
42 * - MigrantTensor - @c IPortableTensor managed by other backends ( in @c _base_reg )
44 * @note @c _base_reg is used in implementation to reuse @c cpu_common::StaticTensorManager
// Tensor registry that combines two stores: a shared cpu_common::TensorRegistry (_base_reg) and a
// map of IOTensor objects owned directly by this class (_native_io_tensors).
// NOTE(review): this listing appears to be missing lines (braces, access specifiers and some
// short statements). The comments below describe only the statements that are visible.
47 class TensorRegistry : public ITensorRegistry
// Allocates the cpu_common::TensorRegistry that _base_reg holds.
50 TensorRegistry() : _base_reg{new cpu_common::TensorRegistry} {}
// Looks up the tensor for `ind`: queries _base_reg first, with getNativeIOTensor() as the
// visible fallback.
// NOTE(review): the statement that consumes base_tensor is not visible here - presumably it is
// returned when non-null and the fallback is used otherwise; confirm against the full file.
52 ITensor *getITensor(const ir::OperandIndex &ind) override
54 auto base_tensor = _base_reg->getITensor(ind);
57 return getNativeIOTensor(ind);
// Same lookup shape as getITensor(), but queries _base_reg->getNativeITensor().
// NOTE(review): the use of base_tensor is not visible here either; confirm.
60 ITensor *getNativeITensor(const ir::OperandIndex &ind) override
62 auto base_tensor = _base_reg->getNativeITensor(ind);
65 return getNativeIOTensor(ind);
// Portable-tensor lookup: queries _base_reg->getPortableTensor(), with the IOTensor map as the
// visible fallback.
// NOTE(review): the use of base_tensor is not visible here; confirm.
68 IPortableTensor *getPortableTensor(const ir::OperandIndex &ind)
70 auto base_tensor = _base_reg->getPortableTensor(ind);
73 return getNativeIOTensor(ind);
// Native-tensor lookup: queries _base_reg->getNativeTensor(), with the IOTensor map as the
// visible fallback.
// NOTE(review): the use of base_tensor is not visible here; confirm.
76 IPortableTensor *getNativeTensor(const ir::OperandIndex &ind)
78 auto base_tensor = _base_reg->getNativeTensor(ind);
81 return getNativeIOTensor(ind);
// Returns the result of _base_reg->getNativeTensor(ind) directly; _native_io_tensors is not
// consulted.
84 Tensor *getNativeOwnTensor(const ir::OperandIndex &ind)
86 return _base_reg->getNativeTensor(ind);
// Searches _native_io_tensors for `ind` and, when an entry exists, returns the raw pointer held
// by its unique_ptr (ownership stays with the map).
// NOTE(review): the return for the not-found case is not visible here - presumably nullptr;
// confirm.
89 IOTensor *getNativeIOTensor(const ir::OperandIndex &ind)
91 auto tensor = _native_io_tensors.find(ind);
92 if (tensor != _native_io_tensors.end())
93 return tensor->second.get();
// Registers a tensor for `ind` by forwarding to _base_reg->setMigrantTensor(), after asserting
// that getITensor(ind) currently yields nothing.
// NOTE(review): the statement producing the bool return value is not visible here; confirm.
97 bool setMigrantTensor(const ir::OperandIndex &ind, IPortableTensor *tensor) override
100 assert(!getITensor(ind)); // For the ind, tensor is not registered yet
101 _base_reg->setMigrantTensor(ind, tensor);
// Hands `tensor` to _base_reg->setNativeTensor() via std::move, after asserting that
// getITensor(ind) currently yields nothing.
105 void setNativeOwnTensor(ir::OperandIndex ind, std::unique_ptr<Tensor> &&tensor)
108 assert(!getITensor(ind)); // For the ind, tensor is not registered yet
109 _base_reg->setNativeTensor(ind, std::move(tensor));
// Moves `tensor` into _native_io_tensors under `ind`, after asserting that getITensor(ind)
// currently yields nothing.
112 void setNativeIOTensor(ir::OperandIndex ind, std::unique_ptr<IOTensor> &&tensor)
115 assert(!getITensor(ind)); // For the ind, tensor is not registered yet
116 _native_io_tensors[ind] = std::move(tensor);
// Const-reference access to the map of IOTensors held by this registry.
119 const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors()
121 return _native_io_tensors;
// Returns a copy of the shared_ptr to the underlying cpu_common::TensorRegistry.
123 std::shared_ptr<cpu_common::TensorRegistry> base_reg() { return _base_reg; }
// Underlying registry that the lookups and the setMigrantTensor/setNativeOwnTensor calls above
// delegate to.
126 std::shared_ptr<cpu_common::TensorRegistry> _base_reg;
// IOTensor objects owned by this registry, keyed by operand index.
127 ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors;
130 } // namespace controlflow
131 } // namespace backend
134 #endif // ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__