x = torch.randn(3, 4)
self.assertEqual(traced(x), imported(x))
+ def test_onnx_transpose_incomplete_tensor_type(self):
+ # Smoke test to get us into the state where we are attempting to export
+ # a transpose op, where the input is a TensorType rather than a
+ # CompleteTensorType. This would previously not work, since we would
+ # take the sizes of the input and use their length as the number of
+ # dimensions in the permutation.
+ class Foo(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ return x.contiguous().transpose(0, 1).sum()
+
+ class TraceMe(torch.nn.Module):
+ def __init__(self):
+ super(TraceMe, self).__init__()
+ self.foo = Foo()
+
+ def forward(self, x):
+ return self.foo(x)
+
+ tm = TraceMe()
+ tm = torch.jit.trace(tm, torch.rand(3, 4))
+ example_outputs = (tm(torch.rand(3, 4)),)
+ f = io.BytesIO()
+ torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
+
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_cuda_export_restore(self):
class Sub(torch.jit.ScriptModule):
# - Looking for inplace ops? They're detected by the trailing underscore, and
# transparently dispatched to their non inplace versions in
# 'run_symbolic_function'. See Note [Export inplace]
+#
+# ---------------------------------------------------------------------
+# A note on Tensor types
+# ---------------------------------------------------------------------
+#
+# In general, we should avoid depending on the type of Tensor Values contained
+# within the trace graph. However, this is sometimes unavoidable (due to ONNX
+# spec requirements, etc). If you are implementing a symbolic and need Tensor
+# type information, note that there are several levels of Tensor types, defined
+# in aten/src/ATen/core/jit_type.h:
+#
+# DynamicType - This is a Tensor, but we don't know anything about its
+# properties (e.g. scalar type, # dims, shapes).
+# Appears as `Tensor` in graph print-outs.
+# UndefinedTensorType <: DynamicType - Denotes an undefined Tensor
+# TensorType <: DynamicType - Denotes a Tensor for which we know the scalar
+# type and number of dimensions, but not the concrete
+# shapes. For example, appears as 'Float(*, *)' in
+# graph print-outs. Useful accessor methods include
+# dim() and scalarType()
+# CompleteTensorType <: TensorType - Denotes a Tensor for which we know the
+# concrete sizes in addition to the information
+# contained in TensorType. This adds a sizes()
+# method which can be used to retrieve the
+# concrete sizes.
+#
+# In general, we should prefer to rely on the least specific information possible.
+# For example, not relying on tensor properties at all is better than relying
+# on the number of dimensions (TensorType), which is better than relying on
+# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
+# more robust to different graphs.
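+#
+# For example, a symbolic that only needs the rank of its input can be
+# written against TensorType rather than CompleteTensorType. A minimal,
+# illustrative sketch (`my_symbolic` is a hypothetical name, not an op
+# defined in this file):
+#
+#   def my_symbolic(g, input):
+#       ndim = input.type().dim()   # ok for TensorType and CompleteTensorType,
+#                                   # but not for DynamicType
+#       axes = list(range(ndim))
+#       return g.op("Transpose", input, perm_i=axes)
+#
+# whereas calling input.type().sizes() would additionally require a
+# CompleteTensorType.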
# ---------------------------------------------------------------------
# Helper functions
return self
# NB: Transpose in ONNX is actually a Permute
- axes = list(range(len(self.type().sizes())))
+ axes = list(range(self.type().dim()))
axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
return g.op("Transpose", self, perm_i=axes)
# So only when dim and axis both equal to ndim - 1 (the last dimension),
# their semantics are equivalent.
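+# For instance (illustrative numbers): for an input of shape (2, 3, 4),
+# PyTorch's softmax(dim=1) normalizes over the 3 values along dim 1, while
+# ONNX Softmax with axis=1 flattens dims 1 and 2 and normalizes over all
+# 12 values together, so the two agree only for dim=2, the last dimension.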
if dim < 0:
- dim = len(input.type().sizes()) + dim
- if len(input.type().sizes()) != dim + 1:
+ dim = input.type().dim() + dim
+ if input.type().dim() != dim + 1:
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
return g.op('Softmax', input, axis_i=dim)
def constant_pad_nd(g, input, padding, value):
from torch.autograd._functions.utils import prepare_onnx_paddings
mode = "constant"
- paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
+ paddings = prepare_onnx_paddings(input.type().dim(), padding)
return g.op("Pad", input, pads_i=paddings, mode_s=mode, value_f=value)
def reflection_pad(g, input, padding):
from torch.autograd._functions.utils import prepare_onnx_paddings
mode = "reflect"
- paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
+ paddings = prepare_onnx_paddings(input.type().dim(), padding)
return g.op("Pad", input, pads_i=paddings, mode_s=mode)
def replication_pad(g, input, padding):
from torch.autograd._functions.utils import prepare_onnx_paddings
mode = "edge"
- paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
+ paddings = prepare_onnx_paddings(input.type().dim(), padding)
return g.op("Pad", input, pads_i=paddings, mode_s=mode)
# PyTorch dim and ONNX axis have different meanings.
# See Softmax comment for details.
if dim < 0:
- dim = len(input.type().sizes()) + dim
- if len(input.type().sizes()) != dim + 1:
+ dim = input.type().dim() + dim
+ if input.type().dim() != dim + 1:
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
return g.op("LogSoftmax", input, axis_i=dim)
args = [input, weight]
# ONNX only supports 1D bias
- if bias.node().kind() != "prim::Undefined" and len(bias.type().sizes()) == 1:
+ if bias.node().kind() != "prim::Undefined" and bias.type().dim() == 1:
args.append(bias)
kwargs = {"kernel_shape_i": weight_size[2:],
n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
- if bias.node().kind() != "prim::Undefined" and len(bias.type().sizes()) != 1:
+ if bias.node().kind() != "prim::Undefined" and bias.type().dim() != 1:
return g.op("Add", n, bias)
else:
return n