2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "locop/FormattedGraph.h"
18 #include "locop/FormattedTensorShape.h"
19 #include "locop/GenericNodeSummaryBuilder.h"
21 #include <loco/Service/TypeInference.h>
22 #include <loco/Service/ShapeInference.h>
24 #include <pp/Format.h>
26 #include <stdex/Memory.h>
33 using locop::SymbolTable;
// Returns a std::string for the given loco::DataType value.
//
// The cases listed are: Unknown, the unsigned integers U8/U16/U32/U64,
// the signed integers S8/S16/S32/S64, FLOAT16/FLOAT32/FLOAT64 and BOOL.
//
// The last statement throws std::invalid_argument with the message dtype.
// NOTE(review): presumably that throw is reached only for values that no case
// handles - confirm against the case bodies.
38 std::string str(const loco::DataType &dtype)
42 case loco::DataType::Unknown:
45 case loco::DataType::U8:
47 case loco::DataType::U16:
49 case loco::DataType::U32:
51 case loco::DataType::U64:
54 case loco::DataType::S8:
56 case loco::DataType::S16:
58 case loco::DataType::S32:
60 case loco::DataType::S64:
63 case loco::DataType::FLOAT16:
65 case loco::DataType::FLOAT32:
67 case loco::DataType::FLOAT64:
70 case loco::DataType::BOOL:
77 throw std::invalid_argument{"dtype"};
// Returns a std::string for the given loco::Domain value.
//
// The cases listed are: Unknown, Tensor, Feature, Filter, DepthwiseFilter
// and Bias.
//
// The last statement throws std::invalid_argument with the message domain.
// NOTE(review): presumably that throw is reached only for values that no case
// handles - confirm against the case bodies.
80 std::string str(const loco::Domain &domain)
85 case loco::Domain::Unknown:
87 case loco::Domain::Tensor:
89 case loco::Domain::Feature:
91 case loco::Domain::Filter:
93 case loco::Domain::DepthwiseFilter:
95 case loco::Domain::Bias:
101 throw std::invalid_argument{"domain"};
// Returns a std::string for the given loco::NodeShape, dispatching on
// node_shape.domain().
//
// The last statement throws std::invalid_argument with the message domain.
104 std::string str(const loco::NodeShape &node_shape)
// Brings TensorShapeFormat (used unqualified below) into scope.
106 using namespace locop;
108 switch (node_shape.domain())
110 case loco::Domain::Tensor:
// Tensor domain: view the shape as a loco::TensorShape and format it with
// the Plain tensor shape format.
112 auto tensor_shape = node_shape.as<loco::TensorShape>();
113 return pp::fmt(locop::fmt<TensorShapeFormat::Plain>(&tensor_shape));
// NOTE(review): the four domains below are listed back to back and
// presumably share a single result - confirm.
116 case loco::Domain::Feature:
117 case loco::Domain::Filter:
118 case loco::Domain::DepthwiseFilter:
119 case loco::Domain::Bias:
126 throw std::invalid_argument{"domain"};
// Wraps a pointer to a loco::TensorShape in a
// locop::FormattedTensorShape using the Bracket format.
//
// The returned object is constructed directly from ptr; this function does
// not dereference or copy the shape itself.
129 // TODO Use locop::fmt<TensorShapeFormat ...>
130 locop::FormattedTensorShape<locop::TensorShapeFormat::Bracket>
131 formatted_tensor_shape(const loco::TensorShape *ptr)
133 return locop::FormattedTensorShape<locop::TensorShapeFormat::Bracket>{ptr};
// Thin extension of locop::NodeDesc that adds shorthand accessors on top of
// the base class opname() and args() members.
141 struct NodeDesc : public locop::NodeDesc
144 NodeDesc() = default;
// Forwards the op name to the locop::NodeDesc constructor.
145 NodeDesc(const locop::OpName &opname) : locop::NodeDesc{opname}
// Returns the op name; same value as opname().
152 const locop::OpName &name(void) const { return opname(); }
// Returns the number of recorded arguments, i.e. args().count().
155 uint32_t arg_size(void) const { return args().count(); }
// Returns the n-th recorded argument, i.e. args().at(n).
157 const locop::ArgElem &arg(uint32_t n) const { return args().at(n); }
// Appends a (name, value) argument to args().
159 void arg(const locop::ArgName &name, const locop::ArgValue &value) { args().append(name, value); }
164 // TODO Remove this workaround
// Writes a textual form of a NodeDesc to os.
//
// Each argument becomes one entry made of its first element, a colon
// separator and its second element. A description in the PartiallyKnown state
// gets an extra trailing entry of three dots. Entries after the first are
// each preceded by a comma separator.
168 std::ostream &operator<<(std::ostream &os, const NodeDesc &d)
// A description in the Invalid state must never be printed.
170 assert(d.state() != NodeDesc::State::Invalid);
172 std::vector<std::string> values;
// Collect one entry per argument.
174 for (uint32_t n = 0; n < d.args().count(); ++n)
176 values.emplace_back(d.args().at(n).first + ": " + d.args().at(n).second);
// Mark descriptions whose argument list is not complete.
179 if (d.state() == NodeDesc::State::PartiallyKnown)
181 values.emplace_back("...");
// NOTE(review): the loop below starts at index 1, so the entry at index 0 is
// presumably written separately without a leading separator - confirm.
186 if (values.size() > 0)
189 for (uint32_t n = 1; n < values.size(); ++n)
191 os << ", " << values.at(n);
// Stream insertion operator for FormattedGraph.
// NOTE(review): presumably delegates to the dump member of fmt and returns
// os - confirm against the body.
204 std::ostream &operator<<(std::ostream &os, const FormattedGraph &fmt)
// Writes the LinearV1 textual form of _graph to os.
//
// Steps, in order:
//  1. Give every node in _graph->nodes() a symbol built from its index.
//  2. Group the nodes into clusters keyed by a root node.
//  3. Create a NodeSummaryBuilder (from _factory, or the generic built-in).
//  4. Print one line per graph input and per graph output.
//  5. For each cluster, visit nodes in post-order starting from the nodes
//     that have no successors, and print each node summary together with
//     its comments, shape and data type.
//
// Throws std::runtime_error when the summary builder fails on a node.
215 void FormattedGraphImpl<Formatter::LinearV1>::dump(std::ostream &os) const
// Local SymbolTable backed by a std::map from node pointer to symbol text.
217 struct SymbolTableImpl final : public SymbolTable
219 std::string lookup(const loco::Node *node) const final
// std::map::at throws std::out_of_range for a node with no entry.
226 return _content.at(node);
229 std::map<const loco::Node *, std::string> _content;
232 SymbolTableImpl symbols;
// Shorthand for symbols.lookup(node).
234 auto symbol = [&symbols](const loco::Node *node) { return symbols.lookup(node); };
// The symbol of the n-th node is a percent sign followed by n.
236 for (uint32_t n = 0; n < _graph->nodes()->size(); ++n)
238 symbols._content[_graph->nodes()->at(n)] = pp::fmt("%", n);
241 // Find the disjoint node clusters
243 // TODO Move this implementation into loco Algorithms.h
// parents maps each node to a parent node; nullptr marks a root.
244 std::map<loco::Node *, loco::Node *> parents;
// Start with every node as its own root.
246 for (auto node : loco::all_nodes(_graph))
248 parents[node] = nullptr;
// Walk every non-null argument of every node.
// NOTE(review): presumably this is where parents entries get linked so that
// connected nodes end up under one root - confirm.
251 for (auto node : loco::all_nodes(_graph))
253 for (uint32_t n = 0; n < node->arity(); ++n)
255 if (auto arg = node->arg(n))
// find follows parents links from node until it reaches a node whose
// parent is nullptr.
262 auto find = [&parents](loco::Node *node) {
263 loco::Node *cur = node;
265 while (parents.at(cur) != nullptr)
267 cur = parents.at(cur);
// Collect the distinct results of find over all nodes.
273 std::set<loco::Node *> roots;
275 for (auto node : loco::all_nodes(_graph))
277 roots.insert(find(node));
// clusters maps each root to the set of nodes for which find yields it.
280 std::map<loco::Node *, std::set<loco::Node *>> clusters;
// Create an empty set for every root first, so that the at() lookup in the
// next loop always finds an entry.
283 for (auto root : roots)
285 clusters[root] = std::set<loco::Node *>{};
288 for (auto node : loco::all_nodes(_graph))
290 clusters.at(find(node)).insert(node);
293 std::unique_ptr<locop::NodeSummaryBuilder> node_summary_builder;
297 // Use User-defined NodeSummaryBuilder if NodeSummaryBuilderFactory is present
298 node_summary_builder = _factory->create(&symbols);
302 // Use Built-in NodeSummaryBuilder otherwise
303 node_summary_builder = stdex::make_unique<GenericNodeSummaryBuilder>(&symbols);
306 // Print Graph Input(s)
307 for (uint32_t n = 0; n < _graph->inputs()->size(); ++n)
309 auto input = _graph->inputs()->at(n);
311 std::string name = input->name();
// The shape text stays a question mark when the input has no shape.
313 std::string shape = "?";
314 if (input->shape() != nullptr)
316 shape = pp::fmt(formatted_tensor_shape(input->shape()));
320 os << pp::fmt("In #", n, " { name: ", name, ", shape: ", shape, " }") << std::endl;
323 // Print Graph Output(s)
324 for (uint32_t n = 0; n < _graph->outputs()->size(); ++n)
326 auto output = _graph->outputs()->at(n);
328 std::string name = output->name();
// The shape text stays a question mark when the output has no shape.
330 std::string shape = "?";
331 if (output->shape() != nullptr)
333 shape = pp::fmt(formatted_tensor_shape(output->shape()));
337 os << pp::fmt("Out #", n, " { name: ", name, ", shape: ", shape, " }") << std::endl;
// True when the graph has at least one input or output.
// NOTE(review): presumably guards a separator line before the node listing - confirm.
340 if (_graph->inputs()->size() + _graph->outputs()->size() != 0)
// Print the nodes cluster by cluster.
345 for (auto it = clusters.begin(); it != clusters.end(); ++it)
// Nodes of this cluster that have no successors; these seed the traversal.
347 std::vector<loco::Node *> cluster_outputs;
349 for (auto node : it->second)
351 // NOTE This is inefficient but anyway working :)
352 if (loco::succs(node).empty())
354 cluster_outputs.emplace_back(node);
// Visit the nodes in the order produced by loco::postorder_traversal.
358 for (auto node : loco::postorder_traversal(cluster_outputs))
360 locop::NodeSummary node_summary;
362 // Build a node summary
363 if (!node_summary_builder->build(node, node_summary))
365 throw std::runtime_error{"Fail to build a node summary"};
// Each comment of the summary goes on its own line after a semicolon prefix.
368 for (uint32_t n = 0; n < node_summary.comments().count(); ++n)
370 os << "; " << node_summary.comments().at(n) << std::endl;
// Shape domain and shape are printed only when the shape of the node is known.
375 if (loco::shape_known(node))
377 auto node_shape = loco::shape_get(node);
378 os << " : " << str(node_shape.domain());
380 os << str(node_shape);
// The data type is printed as a question mark when it is not known.
383 os << (loco::dtype_known(node) ? str(loco::dtype_get(node)) : std::string{"?"});
387 os << " = " << node_summary << std::endl;