[FX] make visualizer produce different formatted output (#64699)
authorYinghai Lu <yinghai@fb.com>
Thu, 9 Sep 2021 01:20:46 +0000 (18:20 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 01:22:12 +0000 (18:22 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64699

Previously we just hardcode to svg format. We should give folks a choice in terms of what format they want to see. If we give a weird extension like .abc and this will error out and we expect this to be the right behavior.

Reviewed By: houseroad

Differential Revision: D30718883

fbshipit-source-id: fe8827262f94ea6887999bb225de763d1909eef8

torch/fx/experimental/fx_acc/acc_utils.py

index a77f3fa..4ae327c 100644 (file)
@@ -1,5 +1,6 @@
 import inspect
 import json
+import os
 from typing import Any, Tuple, Callable, Union, Dict
 
 import torch
@@ -83,9 +84,10 @@ def build_raw_tensor_meta(
 
 
 def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph"):
-    if not fname.endswith(".svg"):
-        fname = fname + ".svg"
-    print(f"Writing FX graph to file: {fname}")
+    base, ext = os.path.splitext(fname)
+    if not ext:
+        ext = ".svg"
+    print(f"Writing FX graph to file: {base}{ext}")
     g = graph_drawer.FxGraphDrawer(traced, figname)
     x = g.get_main_dot_graph()
-    x.write_svg(fname)
+    getattr(x, "write_" + ext.lstrip("."))(fname)