/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
19 #include "OperationValidator.h"
20 #include "verifier/Verifier.h"
29 Graph::Graph() = default;
31 Graph::Graph(const Graph &) = default;
33 Graph::~Graph(void) = default;
35 OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type)
37 return _operands.emplace(shape, type);
40 OperandIndex Graph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand)
42 return _operands.push(std::move(operand), index);
45 bool Graph::checkOperandsForOperation(const IOperation &operation)
47 auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
48 auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
49 for (auto &&input : inputs)
50 if (!operands().exist(input))
52 for (auto &&input : outputs)
53 if (!operands().exist(input))
58 void Graph::linkOperandToOperation(OperationIndex index, const IOperation &operation)
60 auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
61 auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
63 for (auto &&input : inputs)
64 operands().at(input).insertUse(index);
65 for (auto &&output : outputs)
66 operands().at(output).setDef(index);
69 OperationIndex Graph::addOperation(std::unique_ptr<IOperation> &&operation)
71 const IOperation &op_ref = *operation;
72 if (!checkOperandsForOperation(op_ref))
73 return OperationIndex{};
74 auto ind = _operations.push(std::move(operation));
76 linkOperandToOperation(ind, op_ref);
80 OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation)
82 const IOperation &op_ref = *operation;
83 if (!checkOperandsForOperation(op_ref))
84 return OperationIndex{};
85 auto ind_gen = _operations.push(std::move(operation), index);
88 assert(ind_gen == index);
89 linkOperandToOperation(index, op_ref);
94 OperationIndex Graph::replaceOperation(OperationIndex index,
95 std::unique_ptr<IOperation> &&operation)
97 const IOperation &op_ref = *operation;
98 if (!checkOperandsForOperation(op_ref) || !_operations.exist(index))
99 return OperationIndex{};
101 // Check the new operation has the same inputs/outputs as the existing operation
102 const auto &old_op = _operations.at(index);
103 if (!(old_op.getInputs() == op_ref.getInputs() && old_op.getOutputs() == op_ref.getOutputs()))
105 return OperationIndex{};
108 return _operations.set(index, std::move(operation));
111 void Graph::setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data)
113 assert(_operands.exist(ind));
114 _operands.at(ind).data(std::move(data));
117 void Graph::changeShape(const OperandIndex &ind, const ir::Shape &new_shape)
119 assert(_operands.exist(ind));
120 _operands.at(ind).info().shape(new_shape);
123 void Graph::addInput(const OperandIndex &ind, const std::string &name)
126 _name_to_input.emplace(name, IOIndex{_inputs.size()});
130 void Graph::addOutput(const OperandIndex &ind, const std::string &name)
133 _name_to_output.emplace(name, IOIndex{_outputs.size()});
134 _outputs.append(ind);
137 IOIndex Graph::getInputIndex(const std::string &name) const
139 auto itr = _name_to_input.find(name);
140 return (itr == _name_to_input.end()) ? IOIndex{} : itr->second;
143 IOIndex Graph::getOutputIndex(const std::string &name) const
145 auto itr = _name_to_output.find(name);
146 return (itr == _name_to_output.end()) ? IOIndex{} : itr->second;
149 void Graph::verify(void) const
151 // Call graph verifications for the MODEL phase
153 // Except for edge consistency, the user might have been given a bad model
154 // so here it throws an execption rather than assertion.
155 if (!verifier::InputOutputChecker().verify(*this))
156 throw std::runtime_error{"One of model input and output operands does not exist."};
157 if (!verifier::DAGChecker().verify(*this))
158 throw std::runtime_error{"The graph is cyclic."};
159 assert(verifier::EdgeChecker().verify(*this));
162 // Check shape independent operation feature
164 // - Shape independent parameter
165 OperationValidator{*this}();
168 void Graph::initializeUseDef()
170 operations().iterate([&](const OperationIndex &index, const IOperation &node) -> void {
171 auto outputs = node.getOutputs();
172 for (auto &&output : outputs | ir::Remove::UNDEFINED)
174 operands().at(output).setDef(index);
177 for (auto &&input : node.getInputs() | ir::Remove::UNDEFINED)
179 operands().at(input).insertUse(index);
184 std::vector<ir::OperationIndex> Graph::topolSortOperations() const
186 std::vector<ir::OperationIndex> ret;
187 util::Set<ir::OperationIndex> unvisited;
188 operations().iterate(
189 [&](const ir::OperationIndex &index, const ir::IOperation &) { unvisited.add(index); });
191 std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
192 [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
193 if (!unvisited.contains(index))
195 unvisited.remove(index);
197 for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
199 const auto &operand = operands().at(output);
200 for (const auto &use : operand.getUses())
202 dfs(use, operations().at(use));
205 ret.push_back(index);
207 operations().iterate(dfs);
209 assert(unvisited.empty()); // All of the nodes must have been visited
210 // Reversing Postorder DFS result to make it sorted in topoligical order
211 std::reverse(ret.begin(), ret.end());