self.assertEquals(
sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
+ def test_pytorch_graph_dict_input(self):
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.l = torch.nn.Linear(3, 5)
+
+ def forward(self, x):
+ return self.l(x)
+
+ class ModelDict(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.l = torch.nn.Linear(3, 5)
+
+ def forward(self, x):
+ return {"out": self.l(x)}
+
+
+ dummy_input = torch.zeros(1, 3)
+
+ with self.createSummaryWriter() as w:
+ w.add_graph(Model(), dummy_input)
+
+ with self.createSummaryWriter() as w:
+ w.add_graph(Model(), dummy_input, use_strict_trace=True)
+
+ # expect error: Encountering a dict at the output of the tracer...
+ with self.assertRaises(RuntimeError):
+ with self.createSummaryWriter() as w:
+ w.add_graph(ModelDict(), dummy_input, use_strict_trace=True)
+
+ with self.createSummaryWriter() as w:
+ w.add_graph(ModelDict(), dummy_input, use_strict_trace=False)
+
+
def test_mlp_graph(self):
dummy_input = (torch.zeros(2, 1, 28, 28),)
return nodes_py.to_proto()
-def graph(model, args, verbose=False):
+def graph(model, args, verbose=False, use_strict_trace=True):
"""
This method processes a PyTorch model and produces a `GraphDef` proto
that can be logged to TensorBoard.
args (tuple): input tensor[s] for the model.
verbose (bool): Whether to print out verbose information while
processing.
+ use_strict_trace (bool): Whether to pass keyword argument `strict` to
+ `torch.jit.trace`. Pass False when you want the tracer to
+ record your mutable container types (list, dict)
"""
with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL): # TODO: move outside of torch.onnx?
try:
- trace = torch.jit.trace(model, args)
+ trace = torch.jit.trace(model, args, strict=use_strict_trace)
graph = trace.graph
torch._C._jit_pass_inline(graph)
except RuntimeError as e:
torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph")
self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt))
- def add_graph(self, model, input_to_model=None, verbose=False):
- # prohibit second call?
- # no, let tensorboard handle it and show its warning message.
+ def add_graph(self, model, input_to_model=None, verbose=False, use_strict_trace=True):
"""Add graph data to summary.
Args:
input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
variables to be fed.
verbose (bool): Whether to print graph structure in console.
+ use_strict_trace (bool): Whether to pass keyword argument `strict` to
+ `torch.jit.trace`. Pass False when you want the tracer to
+ record your mutable container types (list, dict)
"""
torch._C._log_api_usage_once("tensorboard.logging.add_graph")
if hasattr(model, 'forward'):
# A valid PyTorch model should have a 'forward' method
- self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
+ self._get_file_writer().add_graph(graph(model, input_to_model, verbose, use_strict_trace))
else:
# Caffe2 models do not have the 'forward' method
from caffe2.proto import caffe2_pb2