From 96fb1a56ea894b098df48b45adc053b8bf7f9a37 Mon Sep 17 00:00:00 2001 From: "Edward Wang (EcoF)" Date: Thu, 12 Aug 2021 12:10:50 -0700 Subject: [PATCH] add use_strict_trace to tensorboard add_graph method (#63120) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63120 FAIM returns dictionaries as the model output, which throws an error when trying to trace using add_graph. Pass in `strict` to the tracer to make this user configurable. User post: https://fb.workplace.com/groups/pytorchLightning/permalink/1510194972650369/?comment_id=1510252919311241&reply_comment_id=1510281112641755 Test Plan: unit test Reviewed By: Reubend Differential Revision: D30265890 fbshipit-source-id: 58b25d9500b875a29a664aa9ef4c1e7f13631fa1 --- test/test_tensorboard.py | 35 +++++++++++++++++++++++++++++++ torch/utils/tensorboard/_pytorch_graph.py | 7 +++++-- torch/utils/tensorboard/writer.py | 9 ++++---- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/test/test_tensorboard.py b/test/test_tensorboard.py index fb3dd00..eb11e3e 100644 --- a/test/test_tensorboard.py +++ b/test/test_tensorboard.py @@ -569,6 +569,41 @@ class TestTensorBoardPytorchGraph(BaseTestCase): self.assertEquals( sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys())) + def test_pytorch_graph_dict_input(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.l = torch.nn.Linear(3, 5) + + def forward(self, x): + return self.l(x) + + class ModelDict(torch.nn.Module): + def __init__(self): + super().__init__() + self.l = torch.nn.Linear(3, 5) + + def forward(self, x): + return {"out": self.l(x)} + + + dummy_input = torch.zeros(1, 3) + + with self.createSummaryWriter() as w: + w.add_graph(Model(), dummy_input) + + with self.createSummaryWriter() as w: + w.add_graph(Model(), dummy_input, use_strict_trace=True) + + # expect error: Encountering a dict at the output of the tracer... + with self.assertRaises(RuntimeError): + with self.createSummaryWriter() as w: + w.add_graph(ModelDict(), dummy_input, use_strict_trace=True) + + with self.createSummaryWriter() as w: + w.add_graph(ModelDict(), dummy_input, use_strict_trace=False) + + def test_mlp_graph(self): dummy_input = (torch.zeros(2, 1, 28, 28),) diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index ecc4b7b..1ab3e81 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -270,7 +270,7 @@ def parse(graph, trace, args=None, omit_useless_nodes=True): return nodes_py.to_proto() -def graph(model, args, verbose=False): +def graph(model, args, verbose=False, use_strict_trace=True): """ This method processes a PyTorch model and produces a `GraphDef` proto that can be logged to TensorBoard. @@ -280,10 +280,13 @@ def graph(model, args, verbose=False): args (tuple): input tensor[s] for the model. verbose (bool): Whether to print out verbose information while processing. + use_strict_trace (bool): Whether to pass keyword argument `strict` to + `torch.jit.trace`. Pass False when you want the tracer to + record your mutable container types (list, dict) """ with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL): # TODO: move outside of torch.onnx? try: - trace = torch.jit.trace(model, args) + trace = torch.jit.trace(model, args, strict=use_strict_trace) graph = trace.graph torch._C._jit_pass_inline(graph) except RuntimeError as e: diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index a37af38..cd7f751 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -718,9 +718,7 @@ class SummaryWriter(object): torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph") self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt)) - def add_graph(self, model, input_to_model=None, verbose=False): - # prohibit second call? - # no, let tensorboard handle it and show its warning message. + def add_graph(self, model, input_to_model=None, verbose=False, use_strict_trace=True): """Add graph data to summary. Args: @@ -728,11 +726,14 @@ class SummaryWriter(object): input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of variables to be fed. verbose (bool): Whether to print graph structure in console. + use_strict_trace (bool): Whether to pass keyword argument `strict` to + `torch.jit.trace`. Pass False when you want the tracer to + record your mutable container types (list, dict) """ torch._C._log_api_usage_once("tensorboard.logging.add_graph") if hasattr(model, 'forward'): # A valid PyTorch model should have a 'forward' method - self._get_file_writer().add_graph(graph(model, input_to_model, verbose)) + self._get_file_writer().add_graph(graph(model, input_to_model, verbose, use_strict_trace)) else: # Caffe2 models do not have the 'forward' method from caffe2.proto import caffe2_pb2 -- 2.7.4