Move graph util to fx2trt (#64064)
authorShirong Wu <shirong@fb.com>
Thu, 2 Sep 2021 05:09:42 +0000 (22:09 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 05:34:11 +0000 (22:34 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64064

Move original util in torch2trt to fx2trt dir since torch2trt is gonne be deprecated. This is a follow up diff for D30379124

Test Plan: manual

Reviewed By: yinghai, mikekgfb

Differential Revision: D30591687

fbshipit-source-id: ae0e59dfbc2d2e2aa4f3ccea7cff2291c7deb388

torch/fx/experimental/fx2trt/tools/graph_util.py [new file with mode: 0644]

diff --git a/torch/fx/experimental/fx2trt/tools/graph_util.py b/torch/fx/experimental/fx2trt/tools/graph_util.py
new file mode 100644 (file)
index 0000000..96c8b12
--- /dev/null
@@ -0,0 +1,64 @@
+import graphviz  # type: ignore[import]
+
+def get_layer_name_type(layer):
+    return "\n".join(f"{i}" for i in [layer.name, layer.type])
+
+def trt_network_to_dot_graph(network):
+    dot = graphviz.Digraph(comment="Network")
+
+    # add nodes (layers)
+    for i in range(network.num_layers):
+        layer = network.get_layer(i)
+        dot.node(get_layer_name_type(layer))
+
+    # add nodes (inputs)
+    for i in range(network.num_inputs):
+        dot.node(network.get_input(i).name)
+
+    # add nodes (outputs)
+    for i in range(network.num_outputs):
+        dot.node(network.get_output(i).name)
+
+    # add layer->layer edges
+    for a in range(network.num_layers):
+        layer_a = network.get_layer(a)
+
+        for b in range(network.num_layers):
+            layer_b = network.get_layer(b)
+
+            for i in range(layer_a.num_outputs):
+                output_i = layer_a.get_output(i)
+
+                for j in range(layer_b.num_inputs):
+                    input_j = layer_b.get_input(j)
+
+                    if output_i == input_j:
+                        dot.edge(get_layer_name_type(layer_a), get_layer_name_type(layer_b), label=str(input_j.shape))
+
+    # add input->layer edges
+    for i in range(network.num_inputs):
+        input_i = network.get_input(i)
+
+        for b in range(network.num_layers):
+            layer_b = network.get_layer(b)
+
+            for j in range(layer_b.num_inputs):
+                input_j = layer_b.get_input(j)
+
+                if input_i == input_j:
+                    dot.edge(input_i.name, get_layer_name_type(layer_b), label=str(input_j.shape))
+
+    # add layer->output edges
+    for i in range(network.num_outputs):
+        input_i = network.get_output(i)
+
+        for b in range(network.num_layers):
+            layer_b = network.get_layer(b)
+
+            for j in range(layer_b.num_outputs):
+                input_j = layer_b.get_output(j)
+
+                if input_i == input_j:
+                    dot.edge(get_layer_name_type(layer_b), input_i.name, label=str(input_j.shape))
+
+    return dot