From 0ef32625a8b44124d8cfadb981ff8ef482c5c1c0 Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Wed, 8 Sep 2021 18:20:46 -0700 Subject: [PATCH] [FX] make visualizer produce different formatted output (#64699) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64699 Previously we just hardcode to svg format. We should give folks a choice in terms of what format they want to see. If we give a weird extension like .abc and this will error out and we expect this to be the right behavior. Reviewed By: houseroad Differential Revision: D30718883 fbshipit-source-id: fe8827262f94ea6887999bb225de763d1909eef8 --- torch/fx/experimental/fx_acc/acc_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torch/fx/experimental/fx_acc/acc_utils.py b/torch/fx/experimental/fx_acc/acc_utils.py index a77f3fa..4ae327c 100644 --- a/torch/fx/experimental/fx_acc/acc_utils.py +++ b/torch/fx/experimental/fx_acc/acc_utils.py @@ -1,5 +1,6 @@ import inspect import json +import os from typing import Any, Tuple, Callable, Union, Dict import torch @@ -83,9 +84,10 @@ def build_raw_tensor_meta( def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph"): - if not fname.endswith(".svg"): - fname = fname + ".svg" - print(f"Writing FX graph to file: {fname}") + base, ext = os.path.splitext(fname) + if not ext: + ext = ".svg" + print(f"Writing FX graph to file: {base}{ext}") g = graph_drawer.FxGraphDrawer(traced, figname) x = g.get_main_dot_graph() - x.write_svg(fname) + getattr(x, "write_" + ext.lstrip("."))(fname) -- 2.7.4