State _state = State::Invalid;
};
+using NodeSummary = NodeDesc;
+
+/**
+ * @brief Build a summary from loco Node
+ */
+struct NodeSummaryBuilder
+{
+ virtual ~NodeSummaryBuilder() = default;
+
+ virtual bool build(const loco::Node *, NodeSummary &) const = 0;
+};
+
+struct NodeSummaryBuilderFactory
+{
+ virtual ~NodeSummaryBuilderFactory() = default;
+
+ virtual std::unique_ptr<NodeSummaryBuilder> create(const SymbolTable *) const = 0;
+};
+
struct FormattedGraph
{
virtual ~FormattedGraph() = default;
public:
void dump(std::ostream &os) const final;
+public:
+ FormattedGraphImpl<Formatter::LinearV1> &with(std::unique_ptr<NodeSummaryBuilderFactory> &&f)
+ {
+ _factory = std::move(f);
+ return (*this);
+ }
+
private:
loco::Graph *_graph;
+
+ /**
+ * @brief User-provided NodeSummaryBuilderFactory
+ */
+ std::unique_ptr<NodeSummaryBuilderFactory> _factory = nullptr;
};
template <Formatter F> FormattedGraphImpl<F> fmt(loco::Graph *g)
} // namespace locop
-namespace
+// TODO Remove this workaround
+namespace locop
{
std::ostream &operator<<(std::ostream &os, const NodeDesc &d)
std::vector<std::string> values;
- for (uint32_t n = 0; n < d.arg_size(); ++n)
+ for (uint32_t n = 0; n < d.args().count(); ++n)
{
- values.emplace_back(d.arg(n).first + ": " + d.arg(n).second);
+ values.emplace_back(d.args().at(n).first + ": " + d.args().at(n).second);
}
if (d.state() == NodeDesc::State::PartiallyKnown)
values.emplace_back("...");
}
- os << d.name();
+ os << d.opname();
os << "(";
if (values.size() > 0)
{
return os;
}
+} // namespace locop
+
+namespace
+{
+
NodeDesc default_node_desc(const SymbolTable &tbl, const loco::Node *node)
{
NodeDesc res{opname(node)};
return default_node_desc(tbl, node);
}
+struct BuiltinNodeSummaryBuilder final : public locop::NodeSummaryBuilder
+{
+public:
+ BuiltinNodeSummaryBuilder(const locop::SymbolTable *symtbl) : _symtbl{symtbl}
+ {
+ // DO NOTHING
+ }
+
+public:
+ bool build(const loco::Node *node, locop::NodeSummary &summary) const final
+ {
+ summary = node_desc(*_symtbl, node);
+ return true;
+ }
+
+private:
+ const locop::SymbolTable *_symtbl;
+};
+
} // namespace
namespace locop
clusters.at(find(node)).insert(node);
}
+ std::unique_ptr<locop::NodeSummaryBuilder> node_summary_builder;
+
+ if (_factory)
+ {
+ // Use User-defined NodeSummaryBuilder if NodeSummaryBuilderFactory is present
+ node_summary_builder = _factory->create(&symbols);
+ }
+ else
+ {
+ // Use Built-in NodeSummaryBuilder otherwise
+ node_summary_builder = stdex::make_unique<BuiltinNodeSummaryBuilder>(&symbols);
+ }
+
for (auto it = clusters.begin(); it != clusters.end(); ++it)
{
std::vector<loco::Node *> cluster_outputs;
for (auto node : loco::postorder_traversal(cluster_outputs))
{
- os << symbol(node) << " = " << node_desc(symbols, node) << std::endl;
+ locop::NodeSummary node_summary;
+
+ // Build a node summary
+ if (!node_summary_builder->build(node, node_summary))
+ {
+ throw std::runtime_error{"Fail to build a node summary"};
+ }
+
+ os << symbol(node) << " = " << node_summary << std::endl;
}
os << std::endl;
}
#include "locop/FormattedGraph.h"
+#include <stdex/Memory.h>
+
#include <gtest/gtest.h>
TEST(LinearV1FormatterTest, simple)
// TODO Validate the output (when the implementation becomes stable)
std::cout << locop::fmt<locop::LinearV1>(g) << std::endl;
}
+
+TEST(LinearV1FormatterTest, user_defined_node_summary_builder)
+{
+ auto g = loco::make_graph();
+ {
+ auto pull = g->nodes()->create<loco::Pull>();
+
+ pull->rank(2);
+ pull->dim(0) = loco::make_dimension(); // Mark dim 0 as unknown
+ pull->dim(1) = 4;
+
+ auto push = g->nodes()->create<loco::Push>();
+
+ push->from(pull);
+ }
+
+ struct MyBuilder final : public locop::NodeSummaryBuilder
+ {
+ bool build(const loco::Node *, locop::NodeSummary &s) const final
+ {
+ s.opname("my.op");
+ s.state(locop::NodeSummary::State::PartiallyKnown);
+ return true;
+ }
+ };
+
+ struct MyFactory final : public locop::NodeSummaryBuilderFactory
+ {
+ std::unique_ptr<locop::NodeSummaryBuilder> create(const locop::SymbolTable *) const final
+ {
+ return stdex::make_unique<MyBuilder>();
+ }
+ };
+
+ std::cout << locop::fmt<locop::LinearV1>(g).with(stdex::make_unique<MyFactory>()) << std::endl;
+
+ // TODO Check whether MyBuilder actually sees all the nodes in a graph
+ SUCCEED();
+}