import inspect
import json
+import os
from typing import Any, Tuple, Callable, Union, Dict
import torch
def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph"):
- if not fname.endswith(".svg"):
- fname = fname + ".svg"
- print(f"Writing FX graph to file: {fname}")
+ base, ext = os.path.splitext(fname)
+ if not ext:
+ ext = ".svg"
+ print(f"Writing FX graph to file: {base}{ext}")
g = graph_drawer.FxGraphDrawer(traced, figname)
x = g.get_main_dot_graph()
- x.write_svg(fname)
+ getattr(x, "write_" + ext.lstrip("."))(fname)