From: Yinghai Lu Date: Tue, 12 Feb 2019 22:43:44 +0000 (-0800) Subject: Allow customization of blob node in net_drawer (#16915) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1328 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f435fb8290aeb00fbfa0191d5c42c59c5a772623;p=platform%2Fupstream%2Fpytorch.git Allow customization of blob node in net_drawer (#16915) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16915 TSIA Reviewed By: ipiszy Differential Revision: D14018010 fbshipit-source-id: df5ccc06fa37f08e7a02a8acc466c4ad47afe04e --- diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h b/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h index e279378..2a28004 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h @@ -38,8 +38,7 @@ class DotGenerator { const typename GraphT::SubgraphType& sg, const std::vector& subgraphs) const { std::ostringstream output; - output << "digraph G {\n\ - "; + output << "digraph G {\nrankdir=LR\n"; for (const auto& node : sg.getNodes()) { generateNode(node, sg, output); } @@ -60,8 +59,7 @@ class DotGenerator { // Convert a subgraph to dot. std::string convert(const typename GraphT::SubgraphType& sg) const { std::ostringstream output; - output << "digraph G {\n\ - "; + output << "digraph G {\nrankdir=LR\n"; for (const auto& node : sg.getNodes()) { generateNode(node, sg, output); } @@ -82,7 +80,7 @@ class DotGenerator { */ std::string convertStruct(const typename GraphT::SubgraphType& sg) const { std::ostringstream output; - output << "digraph G {\n"; + output << "digraph G {\nrankdir=LR\n"; // Get input nodes (nodes w/o parents) std::unordered_map diff --git a/caffe2/opt/backend_cutting.cc b/caffe2/opt/backend_cutting.cc index 5715e5a..c4dd792 100644 --- a/caffe2/opt/backend_cutting.cc +++ b/caffe2/opt/backend_cutting.cc @@ -346,7 +346,8 @@ void PruneUnrefereredNodes(NNModule* nn) { caffe2::NetDef OptimizeForBackend( caffe2::NetDef& net, std::function supports, - std::function transform_func) { + std::function transform_func, + bool debug) { auto nn = convertToNNModule(net); auto& dfg = nn.dataFlow; @@ -413,6 +414,10 @@ caffe2::NetDef OptimizeForBackend( // absorbed PruneUnrefereredNodes(&nn); + if (debug) { + DumpGraph(&dfg); + } + auto new_net = convertToCaffe2Proto(nn); new_net.set_name(net.name() + "_opt"); return new_net; diff --git a/caffe2/opt/backend_cutting.h b/caffe2/opt/backend_cutting.h index 8ea1413..cf98c11 100644 --- a/caffe2/opt/backend_cutting.h +++ b/caffe2/opt/backend_cutting.h @@ -12,6 +12,7 @@ namespace opt { CAFFE2_API caffe2::NetDef OptimizeForBackend( caffe2::NetDef& net, std::function supports, - std::function transform_func); + std::function transform_func, + bool debug = false); } } // namespace caffe2 diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 64e801c..a1e0896 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -285,10 +285,10 @@ int64_t GetBlob1stDimSize( return shape_info.shape.dims(0); } -// Generates AdjustBatchOps for external inputs / outputs with type BATCH or +// Generates AdjustBatchOps for external inputs/outputs with type BATCH or // SEQ and adds them to input_ops and output_ops. -// Meanwhile, modifies inputs / outputs of corresponding operators in the -// onnxifi_net to use the new inputs / outputs of AdjustBatchOps. +// Meanwhile, modifies inputs/outputs of corresponding operators in the +// onnxifi_net to use the new inputs/outputs of AdjustBatchOps. std::unordered_map AddAdjustBatchOps( const ShapeInfoMap& shape_hints, NetDef* onnxifi_net, @@ -979,7 +979,8 @@ NetDef OnnxifiTransformer::TransformViaOnnx( net, weights, ws, &exporter2, shape_hints, &shape_hints_onnx); }; - return opt::OptimizeForBackend(*pred_net, onnx_supports, onnx_converter); + return opt::OptimizeForBackend( + *pred_net, onnx_supports, onnx_converter, opts_.debug); } // Cutting off the runnable part and replace with ONNXIFI ops. Asssume the nets diff --git a/caffe2/python/net_drawer.py b/caffe2/python/net_drawer.py index ee124bc..17f6c4b 100644 --- a/caffe2/python/net_drawer.py +++ b/caffe2/python/net_drawer.py @@ -79,31 +79,38 @@ def GetOpNodeProducer(append_output, **kwargs): return ReallyGetOpNode +def GetBlobNodeProducer(**kwargs): + def ReallyGetBlobNode(node_name, label): + return pydot.Node(node_name, label=label, **kwargs) + return ReallyGetBlobNode + def GetPydotGraph( operators_or_net, name=None, rankdir='LR', - node_producer=None + op_node_producer=None, + blob_node_producer=None ): - if node_producer is None: - node_producer = GetOpNodeProducer(False, **OP_STYLE) + if op_node_producer is None: + op_node_producer = GetOpNodeProducer(False, **OP_STYLE) + if blob_node_producer is None: + blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE) operators, name = _rectify_operator_and_name(operators_or_net, name) graph = pydot.Dot(name, rankdir=rankdir) pydot_nodes = {} pydot_node_counts = defaultdict(int) for op_id, op in enumerate(operators): - op_node = node_producer(op, op_id) + op_node = op_node_producer(op, op_id) graph.add_node(op_node) # print 'Op: %s' % op.name # print 'inputs: %s' % str(op.input) # print 'outputs: %s' % str(op.output) for input_name in op.input: if input_name not in pydot_nodes: - input_node = pydot.Node( + input_node = blob_node_producer( _escape_label( input_name + str(pydot_node_counts[input_name])), label=_escape_label(input_name), - **BLOB_STYLE ) pydot_nodes[input_name] = input_node else: @@ -114,11 +121,10 @@ def GetPydotGraph( if output_name in pydot_nodes: # we are overwriting an existing blob. need to updat the count. pydot_node_counts[output_name] += 1 - output_node = pydot.Node( + output_node = blob_node_producer( _escape_label( output_name + str(pydot_node_counts[output_name])), label=_escape_label(output_name), - **BLOB_STYLE ) pydot_nodes[output_name] = output_node graph.add_node(output_node) @@ -131,7 +137,7 @@ def GetPydotGraphMinimal( name=None, rankdir='LR', minimal_dependency=False, - node_producer=None, + op_node_producer=None, ): """Different from GetPydotGraph, hide all blob nodes and only show op nodes. @@ -140,8 +146,8 @@ def GetPydotGraphMinimal( op a and b, and op b depends on a, then only the edge b->c will be drawn because a->c will be implied. """ - if node_producer is None: - node_producer = GetOpNodeProducer(False, **OP_STYLE) + if op_node_producer is None: + op_node_producer = GetOpNodeProducer(False, **OP_STYLE) operators, name = _rectify_operator_and_name(operators_or_net, name) graph = pydot.Dot(name, rankdir=rankdir) # blob_parents maps each blob name to its generating op. @@ -149,7 +155,7 @@ def GetPydotGraphMinimal( # op_ancestry records the ancestors of each op. op_ancestry = defaultdict(set) for op_id, op in enumerate(operators): - op_node = node_producer(op, op_id) + op_node = op_node_producer(op, op_id) graph.add_node(op_node) # Get parents, and set up op ancestry. parents = [