b8e521ce3330d8e1f2ea4364d26fd13495a81292
[platform/core/ml/nnfw.git] / runtime / onert / core / include / backend / ITensorRegister.h
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_BACKEND_ITENSOR_REGISTER_H__
18 #define __ONERT_BACKEND_ITENSOR_REGISTER_H__
19
20 #include "ir/LowerInfoMap.h"
21 #include "ITensorBuilder.h"
22 #include "ir/Layout.h"
23 #include "ir/OperandIndexSequence.h"
24 #include "ir/OperandInfo.h"
25 #include "ir/Operands.h"
26 #include "ir/OperationVisitor.h"
27
28 namespace onert
29 {
30 namespace backend
31 {
32
33 class ITensorRegister : public ir::OperationVisitor
34 {
35 public:
36   virtual ~ITensorRegister() = default;
37
38 public:
39   void registerTensors(const ir::OpSequence &op_seq, const ir::LowerInfoMap *lower_info_map)
40   {
41     _current_op_seq_layout = op_seq.getLayout();
42     _lower_info_map = lower_info_map;
43     assert(_lower_info_map != nullptr);
44     assert(tensor_builder().get() != nullptr);
45     op_seq.accept(*this);
46   }
47
48 protected:
49   virtual const ir::Operands &operands() const = 0;
50   virtual std::shared_ptr<ITensorBuilder> tensor_builder() const = 0;
51
52 protected:
53 #define OP(InternalName)                                                                   \
54   void visit(const ir::operation::InternalName &node) override                             \
55   {                                                                                        \
56     for (const auto &ind : (node.getInputs() | ir::Remove::UNDEFINED) + node.getOutputs()) \
57     {                                                                                      \
58       defaultRegisterTensorInfo(ind);                                                      \
59     }                                                                                      \
60   }
61 #include "ir/Operations.lst"
62 #undef OP
63
64 protected:
65   void defaultRegisterTensorInfo(const ir::OperandIndex &index) const
66   {
67     if (tensor_builder()->isRegistered(index))
68     {
69       return;
70     }
71
72     const auto &obj = operands().at(index);
73     const auto frontend_layout = frontendLayout();
74     const auto backend_layout = backendLayout(index);
75     ir::OperandInfo backend_info{permuteShape(obj.shape(), frontend_layout, backend_layout),
76                                  obj.typeInfo(), obj.info().memAllocType(), obj.isConstant()};
77     tensor_builder()->registerTensorInfo(index, backend_info, backend_layout);
78   }
79
80 protected:
81   ir::Layout frontendLayout() const { return _current_op_seq_layout; }
82   ir::Layout backendLayout(const ir::OperandIndex &index) const
83   {
84     assert(_lower_info_map != nullptr);
85     const auto lower_info = _lower_info_map->operand.at(index).get();
86     return lower_info->def_factors().getOnlyElement().layout();
87   }
88
89 private:
90   ir::Layout _current_op_seq_layout;
91   const ir::LowerInfoMap *_lower_info_map{nullptr};
92 };
93
94 } // namespace backend
95 } // namespace onert
96
97 #endif // __ONERT_BACKEND_ITENSOR_REGISTER_H__