Allow customization of blob node in net_drawer (#16915)
authorYinghai Lu <yinghai@fb.com>
Tue, 12 Feb 2019 22:43:44 +0000 (14:43 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Feb 2019 23:02:50 +0000 (15:02 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16915

TSIA

Reviewed By: ipiszy

Differential Revision: D14018010

fbshipit-source-id: df5ccc06fa37f08e7a02a8acc466c4ad47afe04e

caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h
caffe2/opt/backend_cutting.cc
caffe2/opt/backend_cutting.h
caffe2/opt/onnxifi_transformer.cc
caffe2/python/net_drawer.py

index e279378..2a28004 100644 (file)
@@ -38,8 +38,7 @@ class DotGenerator {
       const typename GraphT::SubgraphType& sg,
       const std::vector<typename GraphT::SubgraphType*>& subgraphs) const {
     std::ostringstream output;
-    output << "digraph G {\n\
-      ";
+    output << "digraph G {\nrankdir=LR\n";
     for (const auto& node : sg.getNodes()) {
       generateNode(node, sg, output);
     }
@@ -60,8 +59,7 @@ class DotGenerator {
   // Convert a subgraph to dot.
   std::string convert(const typename GraphT::SubgraphType& sg) const {
     std::ostringstream output;
-    output << "digraph G {\n\
-      ";
+    output << "digraph G {\nrankdir=LR\n";
     for (const auto& node : sg.getNodes()) {
       generateNode(node, sg, output);
     }
@@ -82,7 +80,7 @@ class DotGenerator {
    */
   std::string convertStruct(const typename GraphT::SubgraphType& sg) const {
     std::ostringstream output;
-    output << "digraph G {\n";
+    output << "digraph G {\nrankdir=LR\n";
 
     // Get input nodes (nodes w/o parents)
     std::unordered_map<typename GraphT::NodeRef, int>
index 5715e5a..c4dd792 100644 (file)
@@ -346,7 +346,8 @@ void PruneUnrefereredNodes(NNModule* nn) {
 caffe2::NetDef OptimizeForBackend(
     caffe2::NetDef& net,
     std::function<bool(const caffe2::OperatorDef&)> supports,
-    std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func) {
+    std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func,
+    bool debug) {
   auto nn = convertToNNModule(net);
   auto& dfg = nn.dataFlow;
 
@@ -413,6 +414,10 @@ caffe2::NetDef OptimizeForBackend(
   // absorbed
   PruneUnrefereredNodes(&nn);
 
+  if (debug) {
+    DumpGraph(&dfg);
+  }
+
   auto new_net = convertToCaffe2Proto(nn);
   new_net.set_name(net.name() + "_opt");
   return new_net;
index 8ea1413..cf98c11 100644 (file)
@@ -12,6 +12,7 @@ namespace opt {
 CAFFE2_API caffe2::NetDef OptimizeForBackend(
     caffe2::NetDef& net,
     std::function<bool(const caffe2::OperatorDef&)> supports,
-    std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func);
+    std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func,
+    bool debug = false);
 }
 } // namespace caffe2
index 64e801c..a1e0896 100644 (file)
@@ -285,10 +285,10 @@ int64_t GetBlob1stDimSize(
   return shape_info.shape.dims(0);
 }
 
-// Generates AdjustBatchOps for external inputs / outputs with type BATCH or
+// Generates AdjustBatchOps for external inputs/outputs with type BATCH or
 // SEQ and adds them to input_ops and output_ops.
-// Meanwhile, modifies inputs / outputs of corresponding operators in the
-// onnxifi_net to use the new inputs / outputs of AdjustBatchOps.
+// Meanwhile, modifies inputs/outputs of corresponding operators in the
+// onnxifi_net to use the new inputs/outputs of AdjustBatchOps.
 std::unordered_map<std::string, std::string> AddAdjustBatchOps(
     const ShapeInfoMap& shape_hints,
     NetDef* onnxifi_net,
@@ -979,7 +979,8 @@ NetDef OnnxifiTransformer::TransformViaOnnx(
             net, weights, ws, &exporter2, shape_hints, &shape_hints_onnx);
       };
 
-  return opt::OptimizeForBackend(*pred_net, onnx_supports, onnx_converter);
+  return opt::OptimizeForBackend(
+      *pred_net, onnx_supports, onnx_converter, opts_.debug);
 }
 
 // Cutting off the runnable part and replace with ONNXIFI ops. Asssume the nets
index ee124bc..17f6c4b 100644 (file)
@@ -79,31 +79,38 @@ def GetOpNodeProducer(append_output, **kwargs):
     return ReallyGetOpNode
 
 
+def GetBlobNodeProducer(**kwargs):
+    def ReallyGetBlobNode(node_name, label):
+        return pydot.Node(node_name, label=label, **kwargs)
+    return ReallyGetBlobNode
+
 def GetPydotGraph(
     operators_or_net,
     name=None,
     rankdir='LR',
-    node_producer=None
+    op_node_producer=None,
+    blob_node_producer=None
 ):
-    if node_producer is None:
-        node_producer = GetOpNodeProducer(False, **OP_STYLE)
+    if op_node_producer is None:
+        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
+    if blob_node_producer is None:
+        blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE)
     operators, name = _rectify_operator_and_name(operators_or_net, name)
     graph = pydot.Dot(name, rankdir=rankdir)
     pydot_nodes = {}
     pydot_node_counts = defaultdict(int)
     for op_id, op in enumerate(operators):
-        op_node = node_producer(op, op_id)
+        op_node = op_node_producer(op, op_id)
         graph.add_node(op_node)
         # print 'Op: %s' % op.name
         # print 'inputs: %s' % str(op.input)
         # print 'outputs: %s' % str(op.output)
         for input_name in op.input:
             if input_name not in pydot_nodes:
-                input_node = pydot.Node(
+                input_node = blob_node_producer(
                     _escape_label(
                         input_name + str(pydot_node_counts[input_name])),
                     label=_escape_label(input_name),
-                    **BLOB_STYLE
                 )
                 pydot_nodes[input_name] = input_node
             else:
@@ -114,11 +121,10 @@ def GetPydotGraph(
             if output_name in pydot_nodes:
                 # we are overwriting an existing blob. need to updat the count.
                 pydot_node_counts[output_name] += 1
-            output_node = pydot.Node(
+            output_node = blob_node_producer(
                 _escape_label(
                     output_name + str(pydot_node_counts[output_name])),
                 label=_escape_label(output_name),
-                **BLOB_STYLE
             )
             pydot_nodes[output_name] = output_node
             graph.add_node(output_node)
@@ -131,7 +137,7 @@ def GetPydotGraphMinimal(
     name=None,
     rankdir='LR',
     minimal_dependency=False,
-    node_producer=None,
+    op_node_producer=None,
 ):
     """Different from GetPydotGraph, hide all blob nodes and only show op nodes.
 
@@ -140,8 +146,8 @@ def GetPydotGraphMinimal(
     op a and b, and op b depends on a, then only the edge b->c will be drawn
     because a->c will be implied.
     """
-    if node_producer is None:
-        node_producer = GetOpNodeProducer(False, **OP_STYLE)
+    if op_node_producer is None:
+        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
     operators, name = _rectify_operator_and_name(operators_or_net, name)
     graph = pydot.Dot(name, rankdir=rankdir)
     # blob_parents maps each blob name to its generating op.
@@ -149,7 +155,7 @@ def GetPydotGraphMinimal(
     # op_ancestry records the ancestors of each op.
     op_ancestry = defaultdict(set)
     for op_id, op in enumerate(operators):
-        op_node = node_producer(op, op_id)
+        op_node = op_node_producer(op, op_id)
         graph.add_node(op_node)
         # Get parents, and set up op ancestry.
         parents = [