add use_strict_trace to tensorboard add_graph method (#63120)
authorEdward Wang (EcoF) <hack@fb.com>
Thu, 12 Aug 2021 19:10:50 +0000 (12:10 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 12 Aug 2021 19:12:12 +0000 (12:12 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63120

FAIM returns dictionaries as the model output, which throws an error when trying to trace using add_graph. Pass in `strict` to the tracer to make this user configurable.

User post: https://fb.workplace.com/groups/pytorchLightning/permalink/1510194972650369/?comment_id=1510252919311241&reply_comment_id=1510281112641755

Test Plan: unit test

Reviewed By: Reubend

Differential Revision: D30265890

fbshipit-source-id: 58b25d9500b875a29a664aa9ef4c1e7f13631fa1

test/test_tensorboard.py
torch/utils/tensorboard/_pytorch_graph.py
torch/utils/tensorboard/writer.py

index fb3dd00..eb11e3e 100644 (file)
@@ -569,6 +569,41 @@ class TestTensorBoardPytorchGraph(BaseTestCase):
             self.assertEquals(
                 sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
 
+    def test_pytorch_graph_dict_input(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l = torch.nn.Linear(3, 5)
+
+            def forward(self, x):
+                return self.l(x)
+
+        class ModelDict(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l = torch.nn.Linear(3, 5)
+
+            def forward(self, x):
+                return {"out": self.l(x)}
+
+
+        dummy_input = torch.zeros(1, 3)
+
+        with self.createSummaryWriter() as w:
+            w.add_graph(Model(), dummy_input)
+
+        with self.createSummaryWriter() as w:
+            w.add_graph(Model(), dummy_input, use_strict_trace=True)
+
+        # expect error: Encountering a dict at the output of the tracer...
+        with self.assertRaises(RuntimeError):
+            with self.createSummaryWriter() as w:
+                w.add_graph(ModelDict(), dummy_input, use_strict_trace=True)
+
+        with self.createSummaryWriter() as w:
+            w.add_graph(ModelDict(), dummy_input, use_strict_trace=False)
+
+
     def test_mlp_graph(self):
         dummy_input = (torch.zeros(2, 1, 28, 28),)
 
index ecc4b7b..1ab3e81 100644 (file)
@@ -270,7 +270,7 @@ def parse(graph, trace, args=None, omit_useless_nodes=True):
     return nodes_py.to_proto()
 
 
-def graph(model, args, verbose=False):
+def graph(model, args, verbose=False, use_strict_trace=True):
     """
     This method processes a PyTorch model and produces a `GraphDef` proto
     that can be logged to TensorBoard.
@@ -280,10 +280,13 @@ def graph(model, args, verbose=False):
       args (tuple): input tensor[s] for the model.
       verbose (bool): Whether to print out verbose information while
         processing.
+      use_strict_trace (bool): Whether to pass keyword argument `strict` to
+        `torch.jit.trace`. Pass False when you want the tracer to
+        record your mutable container types (list, dict)
     """
     with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL):  # TODO: move outside of torch.onnx?
         try:
-            trace = torch.jit.trace(model, args)
+            trace = torch.jit.trace(model, args, strict=use_strict_trace)
             graph = trace.graph
             torch._C._jit_pass_inline(graph)
         except RuntimeError as e:
index a37af38..cd7f751 100644 (file)
@@ -718,9 +718,7 @@ class SummaryWriter(object):
         torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph")
         self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt))
 
-    def add_graph(self, model, input_to_model=None, verbose=False):
-        # prohibit second call?
-        # no, let tensorboard handle it and show its warning message.
+    def add_graph(self, model, input_to_model=None, verbose=False, use_strict_trace=True):
         """Add graph data to summary.
 
         Args:
@@ -728,11 +726,14 @@ class SummaryWriter(object):
             input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
                 variables to be fed.
             verbose (bool): Whether to print graph structure in console.
+            use_strict_trace (bool): Whether to pass keyword argument `strict` to
+                `torch.jit.trace`. Pass False when you want the tracer to
+                record your mutable container types (list, dict)
         """
         torch._C._log_api_usage_once("tensorboard.logging.add_graph")
         if hasattr(model, 'forward'):
             # A valid PyTorch model should have a 'forward' method
-            self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
+            self._get_file_writer().add_graph(graph(model, input_to_model, verbose, use_strict_trace))
         else:
             # Caffe2 models do not have the 'forward' method
             from caffe2.proto import caffe2_pb2