2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_BACKEND_CONTEXT_HELPERS_H__
18 #define __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_BACKEND_CONTEXT_HELPERS_H__
20 #include "backend/basic/BackendContextHelpers.h"
21 #include "backend/train/TrainableBackendContext.h"
32 // TODO Unify with the `genTensors()` function in `BackendContextHelpers.h`
/**
 * @brief Register tensor info for operands of the trainable graph and allocate them
 *
 * Iterates over the operands of ctx.trainable_graph(), registers tensor info for them
 * through tensor_builder (always with NHWC layout), calls notifyFirstUse() for every
 * operand the builder reports as registered, and finally calls tensor_builder->allocate().
 *
 * @param ctx            Trainable backend context; supplies the trainable graph, the set of
 *                       external operands and the tensor registry
 * @param tensor_builder Builder that must provide registerTensorInfo(), isRegistered(),
 *                       notifyFirstUse() and allocate()
 * @return Pointer obtained from ctx.tensor_registry().get()
 *         (presumably still owned by ctx - TODO confirm)
 */
33 template <typename TensorBuilder>
34 ITensorRegistry *genTensors(backend::train::TrainableBackendContext &ctx,
35 const std::shared_ptr<TensorBuilder> &tensor_builder)
37 const auto &tgraph = *ctx.trainable_graph();
// Graph inputs and outputs combined, with undefined and duplicated indices removed.
// NOTE(review): no use of this value is visible in this excerpt - confirm it is still needed.
40 (tgraph.getInputs() + tgraph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
// First pass: register tensor info for the operands of the graph.
41 tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
// Operands listed in ctx.external_operands() are special-cased here.
// NOTE(review): the guarded statement is not visible in this excerpt - presumably such
// operands are skipped; confirm.
42 if (ctx.external_operands().contains(ind))
44 // NOTE Assuming there are no layout changes (Always assume NHWC or UNKNOWN)
45 assert(tgraph.layout() != ir::Layout::NCHW);
// Build the backend-side operand info from the operand's shape, type info and
// memory allocation type.
46 ir::OperandInfo backend_info{obj.shape(), obj.typeInfo(), obj.info().memAllocType(),
48 tensor_builder->registerTensorInfo(ind, backend_info, ir::Layout::NHWC);
51 // For executors that do not have a fixed linear execution order:
52 // To make sure tensors are never deallocated, this is a workaround that uses the static memory planner
// Second pass: notify first use for every operand the builder has registered.
53 tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
54 if (tensor_builder->isRegistered(ind))
55 tensor_builder->notifyFirstUse(ind);
// Allocate once all registrations and first-use notifications are done.
58 tensor_builder->allocate();
60 return ctx.tensor_registry().get();
65 } // namespace backend
68 #endif // __ONERT_BACKEND_BASIC_TRAIN_TRAINABLE_BACKEND_CONTEXT_HELPERS_H__