2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "loader/GraphLoader.h"
19 namespace luci_interpreter
// Classifies a Circle builtin operator code for the in-place check performed by
// GraphLoader::checkInplaceOps, which calls this helper on every operator.
//
// The case labels visible below are: ABS, LOGISTIC, RESHAPE, ELU, EXPAND_DIMS, EXP,
// TANH, LEAKY_RELU, RELU, RELU6, ADD, MUL, SUB and WHILE.
//
// NOTE(review): the switch header and the return statements are not visible in this
// excerpt. Presumably the listed operators yield true and every other operator yields
// false -- confirm against the full file.
24 bool isInplaceOperation(const circle::BuiltinOperator &op)
28 case circle::BuiltinOperator_ABS:
29 case circle::BuiltinOperator_LOGISTIC:
30 case circle::BuiltinOperator_RESHAPE:
31 case circle::BuiltinOperator_ELU:
32 case circle::BuiltinOperator_EXPAND_DIMS:
33 case circle::BuiltinOperator_EXP:
34 case circle::BuiltinOperator_TANH:
35 case circle::BuiltinOperator_LEAKY_RELU:
36 case circle::BuiltinOperator_RELU:
37 case circle::BuiltinOperator_RELU6:
38 case circle::BuiltinOperator_ADD:
39 case circle::BuiltinOperator_MUL:
40 case circle::BuiltinOperator_SUB:
41 case circle::BuiltinOperator_WHILE:
// Inspects how the tensor with index `tensor_index` is used in the graph exposed by
// `reader`:
//   1. counts how many times the index appears among the inputs of all operators
//      returned by reader->operators();
//   2. looks the index up in reader->outputs(), i.e. the graph output list.
//
// NOTE(review): none of the return statements, nor any condition guarding the graph
// output lookup, are visible in this excerpt. Presumably the function returns false as
// soon as a second use is found or when the tensor is a graph output -- confirm against
// the full file.
48 bool isSingleUsageOfTensor(CircleReader *reader, const int32_t tensor_index)
// Number of operator inputs seen so far that refer to tensor_index.
50 uint32_t usage_count = 0;
52 const auto operators = reader->operators();
// Outer loop: every operator of the graph.
53 for (uint32_t i = 0; i < operators.size(); ++i)
55 const auto *op = operators.at(i);
56 assert(op != nullptr);
58 const auto *op_inputs = op->inputs();
// Inner loop: every input index of the current operator.
// NOTE(review): j is a signed int32_t compared against op_inputs->size(); if size()
// is unsigned this is a signed/unsigned comparison -- confirm the return type.
59 for (int32_t j = 0; j < op_inputs->size(); ++j)
61 const auto input_index = op_inputs->operator[](j);
62 if (input_index == tensor_index)
// The counter is incremented first, so this condition holds from the second match on.
// The statement it guards is not visible here (presumably an early exit).
64 if (++usage_count > 1)
70 // Let's check that it is not graph output
73 const auto &outputs_indexes = reader->outputs();
// True when tensor_index is found in the list of graph output indexes.
74 bool is_graph_output = (std::find(outputs_indexes.begin(), outputs_indexes.end(),
75 tensor_index) != outputs_indexes.end());
// Walks every operator returned by reader->operators() and, for those accepted by
// isInplaceOperation(), runs a series of checks on the operator's non-constant inputs
// and the matching outputs. Operators are reported to the runtime graph through
// runtime_graph->addInplaceOpIndex(op).
//
// Checks visible in this excerpt, in order:
//   - the non-constant input tensor must pass isSingleUsageOfTensor();
//   - the output tensor at the same position must pass isSingleUsageOfTensor();
//   - input and output tensors are compared by Tensor::num_elements();
//   - the output index is looked up in the graph output list.
//
// NOTE(review): the loop that iterates over the non-constant inputs, the bodies of the
// failing branches (presumably they clear `is_inplace` and stop), and the condition
// guarding addInplaceOpIndex() are not visible here -- confirm against the full file.
85 void GraphLoader::checkInplaceOps(CircleReader *reader, RuntimeGraph *runtime_graph)
87 const auto operators = reader->operators();
// Graph output indexes, fetched once and searched further below for each candidate.
88 const auto graph_outputs = reader->outputs();
89 for (uint32_t i = 0; i < operators.size(); ++i)
91 const auto *op = operators.at(i);
92 assert(op != nullptr);
94 // Check inplace optimization for operation with single input and single output
95 if (isInplaceOperation(reader->builtin_code(op)))
97 const auto *op_inputs = op->inputs();
98 const auto *op_outputs = op->outputs();
// Starts as true; the lines that would reset it are not visible in this excerpt.
100 bool is_inplace = true;
// Search cursor over the operator's input indexes, starting at the first input.
101 auto non_const_input_it = op_inputs->begin();
// Searches from the current cursor for the next input for which the predicate holds.
// NOTE(review): the left-hand side receiving the find_if result is not visible here;
// presumably it is assigned back to non_const_input_it.
105 std::find_if(non_const_input_it, op_inputs->end(), [&reader](const auto input_idx) {
// Predicate result: true when the tensor at input_idx is not a constant tensor.
// NOTE(review): earlier lines of this lambda are missing from the excerpt.
109 return not Tensor::is_constant_tensor(reader, reader->tensors()[input_idx]);
// No (further) non-constant input was found.
112 if (non_const_input_it == op_inputs->end())
// Position of the found input within the input list; used below as an output position.
115 auto dist = std::distance(op_inputs->begin(), non_const_input_it);
117 const auto non_const_input_idx = *non_const_input_it;
119 // Check single usage of input tensor
120 if (not isSingleUsageOfTensor(reader, non_const_input_idx))
126 // Let's check single usage of output tensor
// Case: the operator has exactly one output and the input position lies past it.
// The guarded statement is not visible; given the assert that follows, it presumably
// moves `dist` back into range -- confirm.
// NOTE(review): `dist` is a signed difference compared against size().
127 if (dist >= op_outputs->size() and op_outputs->size() == 1)
129 assert(dist < op_outputs->size());
130 const auto output_index = op_outputs->operator[](dist);
131 if (not isSingleUsageOfTensor(reader, output_index))
137 // Check that num elements are equal
139 const auto *input_non_const_tensor = reader->tensors().at(non_const_input_idx);
140 const auto *output_tensor = reader->tensors().at(output_index);
141 if (Tensor::num_elements(input_non_const_tensor) != Tensor::num_elements(output_tensor))
148 // Let's check that output is not a graph output tensor
149 // TODO: check this statement
// NOTE(review): the right-hand side of this comparison is on a line that is not
// visible here (presumably graph_outputs.end()).
151 if (std::find(graph_outputs.begin(), graph_outputs.end(), output_index) !=
// Step past the input just examined so the next search resumes after it.
159 non_const_input_it++;
// Registers the operator with the runtime graph as an in-place operation.
163 runtime_graph->addInplaceOpIndex(op);
168 } // namespace luci_interpreter