Expose dim() on type and use it in ONNX symbolics (#15933)
authorJames Reed <jamesreed@fb.com>
Fri, 11 Jan 2019 22:51:17 +0000 (14:51 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 22:54:19 +0000 (14:54 -0800)
Summary:
While integrating fork/join into production translation, we found that trying to export `transpose()` where the input is of `TensorType` (rather than `CompleteTensorType`) failed. This is not ideal, since `TensorType` still contains the number of dimensions of the tensor, and that's all the `transpose` symbolic needs.

This PR introduces a pybind binding for `dim()` on `TensorType` (and `CompleteTensorType` by inheritance). We now use this in places where it logically makes sense in the symbolics: those symbolics which only require knowledge of the number of dimensions rather than concrete sizes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15933

Differential Revision: D13639657

Pulled By: jamesr66a

fbshipit-source-id: 6e50e407e93060085fd00a686a928764d0ec888d

test/test_jit.py
torch/csrc/jit/python_ir.cpp
torch/onnx/symbolic.py

index 8e7918b..fe56294 100644 (file)
@@ -1777,6 +1777,31 @@ class TestJit(JitTestCase):
         x = torch.randn(3, 4)
         self.assertEqual(traced(x), imported(x))
 
+    def test_onnx_transpose_incomplete_tensor_type(self):
+        # Smoke test to get us into the state where we are attempting to export
+        # a transpose op, where the input is a TensorType rather than a
+        # CompleteTensorType. This would previously not work, since we would
+        # take the size of the input and use the length of its sizes as the
+        # number of dimensions in the permutation.
+        class Foo(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                return x.contiguous().transpose(0, 1).sum()
+
+        class TraceMe(torch.nn.Module):
+            def __init__(self):
+                super(TraceMe, self).__init__()
+                self.foo = Foo()
+
+            def forward(self, x):
+                return self.foo(x)
+
+        tm = TraceMe()
+        tm = torch.jit.trace(tm, torch.rand(3, 4))
+        example_outputs = (tm(torch.rand(3, 4)),)
+        f = io.BytesIO()
+        torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
+
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     def test_cuda_export_restore(self):
         class Sub(torch.jit.ScriptModule):
index 6d3ab9a..fadc596 100644 (file)
@@ -491,6 +491,7 @@ void initPythonIRBindings(PyObject* module_) {
             return s.str();
           })
       .def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
+      .def("dim", [](const Type& t) { return t.expect<TensorType>()->dim(); })
       .def(
           "sizes",
           [](Type& t) { return t.expect<CompleteTensorType>()->sizes(); })
index 82e469e..8655d33 100644 (file)
@@ -27,6 +27,37 @@ import itertools
 # - Looking for inplace ops?  They're detected by the trailing underscore, and
 #   transparently dispatched to their non inplace versions in
 #   'run_symbolic_function'.   See Note [Export inplace]
+#
+# ---------------------------------------------------------------------
+# A note on Tensor types
+# ---------------------------------------------------------------------
+#
+# In general, we should avoid depending on the type of Tensor Values contained
+# within the trace graph. However, this is sometimes unavoidable (due to ONNX
+# spec requirements, etc). If you are implementing a symbolic and need Tensor
+# type information, note that there are several levels of Tensor types, defined
+# in aten/src/ATen/core/jit_type.h:
+#
+# DynamicType - This is a Tensor, but we don't know anything about its
+#               properties (e.g. scalar type, # dims, shapes).
+#               Appears as `Tensor` in graph print-outs.
+# UndefinedTensorType <: DynamicType - Denotes an undefined Tensor
+# TensorType <: DynamicType - Denotes a Tensor for which we know the scalar
+#                             type and number of dimensions, but not the concrete
+#                             shapes. For example, appears as 'Float(*, *)' in
+#                             graph print-outs. Useful accessor methods include
+#                             dim() and scalarType()
+# CompleteTensorType <: TensorType - Denotes a Tensor for which we know the
+#                                    concrete sizes in addition to the information
+#                                    contained in TensorTyper. This adds a sizes()
+#                                    method which can be used to retrieve the
+#                                    concrete sizes.
+#
+# In general, we should prefer to rely on the least specific information possible.
+# For example, not relying on tensor properties at all is better than relying
+# on the number of dimensions (TensorType) which is better than relying on
+# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
+# more robust to different graphs.
 
 # ---------------------------------------------------------------------
 # Helper functions
@@ -394,7 +425,7 @@ def transpose(g, self, dim0, dim1):
         return self
 
     # NB: Transpose in ONNX is actually a Permute
-    axes = list(range(len(self.type().sizes())))
+    axes = list(range(self.type().dim()))
     axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
     return g.op("Transpose", self, perm_i=axes)
 
@@ -528,8 +559,8 @@ def softmax(g, input, dim):
     # So only when dim and axis both equal to ndim - 1 (the last dimension),
     # their semantics are equivalent.
     if dim < 0:
-        dim = len(input.type().sizes()) + dim
-    if len(input.type().sizes()) != dim + 1:
+        dim = input.type().dim() + dim
+    if input.type().dim() != dim + 1:
         return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
     return g.op('Softmax', input, axis_i=dim)
 
@@ -630,7 +661,7 @@ def adaptive_max_pool2d(g, input, output_size):
 def constant_pad_nd(g, input, padding, value):
     from torch.autograd._functions.utils import prepare_onnx_paddings
     mode = "constant"
-    paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
+    paddings = prepare_onnx_paddings(input.type().dim(), padding)
     return g.op("Pad", input, pads_i=paddings, mode_s=mode, value_f=value)
 
 
@@ -638,7 +669,7 @@ def constant_pad_nd(g, input, padding, value):
 def reflection_pad(g, input, padding):
     from torch.autograd._functions.utils import prepare_onnx_paddings
     mode = "reflect"
-    paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
+    paddings = prepare_onnx_paddings(input.type().dim(), padding)
     return g.op("Pad", input, pads_i=paddings, mode_s=mode)
 
 
@@ -646,7 +677,7 @@ def reflection_pad(g, input, padding):
 def replication_pad(g, input, padding):
     from torch.autograd._functions.utils import prepare_onnx_paddings
     mode = "edge"
-    paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
+    paddings = prepare_onnx_paddings(input.type().dim(), padding)
     return g.op("Pad", input, pads_i=paddings, mode_s=mode)
 
 
@@ -747,8 +778,8 @@ def log_softmax(g, input, dim=None):
     # PyTorch dim and ONNX axis have different meanings.
     # See Softmax comment for details.
     if dim < 0:
-        dim = len(input.type().sizes()) + dim
-    if len(input.type().sizes()) != dim + 1:
+        dim = input.type().dim() + dim
+    if input.type().dim() != dim + 1:
         return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
     return g.op("LogSoftmax", input, axis_i=dim)
 
@@ -760,7 +791,7 @@ def _convolution(g, input, weight, bias, stride, padding, dilation,
 
     args = [input, weight]
     # ONNX only supports 1D bias
-    if bias.node().kind() != "prim::Undefined" and len(bias.type().sizes()) == 1:
+    if bias.node().kind() != "prim::Undefined" and bias.type().dim() == 1:
         args.append(bias)
 
     kwargs = {"kernel_shape_i": weight_size[2:],
@@ -781,7 +812,7 @@ def _convolution(g, input, weight, bias, stride, padding, dilation,
 
     n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
 
-    if bias.node().kind() != "prim::Undefined" and len(bias.type().sizes()) != 1:
+    if bias.node().kind() != "prim::Undefined" and bias.type().dim() != 1:
         return g.op("Add", n, bias)
     else:
         return n