2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_BACKEND_ITENSOR_REGISTRY__
18 #define __ONERT_BACKEND_ITENSOR_REGISTRY__
23 #include "backend/ITensor.h"
30 struct ITensorRegistry
33 * @brief Deconstruct itself
35 virtual ~ITensorRegistry() = default;
38 * @brief Returns pointer of ITensor among native and migrant tensors
40 * Native Tensor is a tensor that is managed by this backend
41 * Migrant Tensor is a tensor that is imported from another backend
43 * @note Return tensor cannot be used longer than dynamic tensor manager
45 virtual std::shared_ptr<ITensor> getITensor(const ir::OperandIndex &) = 0;
47 * @brief Returns pointer of ITensor among native tensors
49 * Unlike @c getITensor , this function only searches from native tensors
51 * @note Returned tensor cannot be used longer than dynamic tensor manager
53 virtual std::shared_ptr<ITensor> getNativeITensor(const ir::OperandIndex &) = 0;
56 } // namespace backend
59 #include "ir/OperandIndexMap.h"
60 #include "backend/IPortableTensor.h"
68 * @brief TensorRegistry template class for the convenience of backend implementations
70 * If a backend uses @c IPortableTensor , and there is no special reason to implement @c
71 * ITensorRegistry on your own, you may just use this default implementation.
73 * @tparam T_Tensor Tensor type. Must be a subclass of @c onert::backend::IPortableTensor .
75 template <typename T_Tensor> class PortableTensorRegistryTemplate : public ITensorRegistry
78 std::shared_ptr<ITensor> getITensor(const ir::OperandIndex &ind) override
80 static_assert(std::is_base_of<ITensor, T_Tensor>::value, "T_Tensor must derive from ITensor.");
81 auto external_tensor = _migrant.find(ind);
82 if (external_tensor != _migrant.end())
83 return external_tensor->second;
84 return getNativeTensor(ind);
87 std::shared_ptr<ITensor> getNativeITensor(const ir::OperandIndex &ind) override
89 return getNativeTensor(ind);
92 std::shared_ptr<IPortableTensor> getPortableTensor(const ir::OperandIndex &ind)
94 auto external_tensor = _migrant.find(ind);
95 if (external_tensor != _migrant.end())
97 if (external_tensor->second)
98 return external_tensor->second;
100 return getNativeTensor(ind);
103 std::shared_ptr<T_Tensor> getNativeTensor(const ir::OperandIndex &ind)
105 auto tensor = _native.find(ind);
106 if (tensor != _native.end())
107 return tensor->second;
111 bool setMigrantTensor(const ir::OperandIndex &ind, const std::shared_ptr<IPortableTensor> &tensor)
113 // TODO Uncomment this as two tensors for an index is not allowed.
114 // But now it is temporarily allowed as a workaround. External one hides Managed one.
115 // auto itr = _native.find(ind);
116 // if (itr != _native.end() && itr->second != nullptr && tensor != nullptr)
117 // throw std::runtime_error{
118 // "Tried to set an migrant tensor but an native tensor already exists."};
119 _migrant[ind] = tensor;
123 void setNativeTensor(const ir::OperandIndex &ind, const std::shared_ptr<T_Tensor> &tensor)
125 auto itr = _migrant.find(ind);
126 if (itr != _migrant.end() && itr->second != nullptr && tensor != nullptr)
127 throw std::runtime_error{
128 "Tried to set a native tensor but an migrant tensor already exists."};
129 _native[ind] = tensor;
132 const ir::OperandIndexMap<std::shared_ptr<T_Tensor>> &native_tensors() { return _native; }
134 const ir::OperandIndexMap<std::shared_ptr<IPortableTensor>> &migrant_tensors()
140 ir::OperandIndexMap<std::shared_ptr<IPortableTensor>> _migrant;
141 ir::OperandIndexMap<std::shared_ptr<T_Tensor>> _native;
144 } // namespace backend
147 #endif // __ONERT_BACKEND_ITENSOR_REGISTRY__