Add permute, arange (#63407)
author Elias Ellison <eellison@devfair044.h1.fair>
Wed, 8 Sep 2021 01:19:14 +0000 (18:19 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 01:22:24 +0000 (18:22 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63407

Registers symbolic shape functions for aten::arange (the end, start, and start_step overloads), aten::permute, aten::expand, and prim::NumToTensor, extends the non-tensor peephole pass to drop aten::ceil calls whose input is already an integer, and adds shape-analysis tests for arange plus OpInfo shape checks for expand and permute.

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D30738149

Pulled By: eellison

fbshipit-source-id: 36d572488408d38b0643aa93cb08aab5c45218ad

test/jit/test_symbolic_shape_analysis.py
torch/csrc/jit/passes/peephole_non_tensor.cpp
torch/csrc/jit/runtime/symbolic_shape_registry.cpp
torch/testing/_internal/common_jit.py
torch/testing/_internal/common_methods_invocations.py

diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 6d4e33c..7c067a3 100644
@@ -1,9 +1,10 @@
 import torch
-from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.jit_utils import JitTestCase, execWrapper
 import operator
 
 from torch.testing import FileCheck
 
+from textwrap import dedent
 
 if __name__ == '__main__':
     raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -177,3 +178,42 @@ class TestSymbolicShapeAnalysis(JitTestCase):
             torch._C._jit_pass_peephole(fn.graph)
             torch._C._jit_pass_constant_propagation(fn.graph)
             self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)
+
+    def test_arange_shape(self):
+        # no opinfo for tensor constructors
+        inps = [
+            (10,),
+            (10, 10),
+            (0, 10),
+            (0, 1000),
+            (1, -1, -1),
+            (1, 0, -1),
+            (1, 2, 1),
+            (0.6, 0.89, 0.1),
+            (1, 10, 0.3),
+            (1, 10, 4),
+            (0.6, 0.7, 0.8),
+            (1, 10, 0.3),
+            # (True,),  TODO: https://github.com/pytorch/pytorch/issues/63405
+            # (False,), TODO: https://github.com/pytorch/pytorch/issues/63405
+            (0, 5),
+            (0, 5, 2),
+            (0, 5 + 1e-6),
+            (0, 5 - 1e-6),
+            (10, -1 + 1e-6, -1),
+            (10, -1, -1),
+            (10, -1 - 1e-6, -1),
+        ]
+
+        for inp in inps:
+            funcs_template = dedent('''
+            def func():
+                return torch.arange({args})
+            ''')
+
+            inp_s = str(inp)[1:-1]  # remove tuple parens
+            funcs_str = funcs_template.format(args=inp_s)
+            scope = {}
+            execWrapper(funcs_str, globals(), scope)
+            cu = torch.jit.CompilationUnit(funcs_str)
+            self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False)
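
For reference, the arange shape functions registered below compute the output length with the same ceil-based formula these test inputs exercise, and the test passes constant_prop=False because the generated graphs take no inputs. A minimal standalone sketch of the formula, cross-checked against eager torch.arange (arange_len is a hypothetical helper, and math.ceil stands in for the TorchScript torch.ceil overloads used in the registry):

import math
import torch

def arange_len(start, end, step):
    # mirrors arange_end / arange_start / arange_start_step in the shape registry
    return int(math.ceil((end - start) / step))

for args in [(10,), (0, 10), (0, 5, 2), (0.6, 0.89, 0.1), (1, 10, 0.3), (10, -1, -1)]:
    start, end, step = 0, args[0], 1
    if len(args) >= 2:
        start, end = args[0], args[1]
    if len(args) == 3:
        step = args[2]
    assert torch.arange(*args).numel() == arange_len(start, end, step)
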
diff --git a/torch/csrc/jit/passes/peephole_non_tensor.cpp b/torch/csrc/jit/passes/peephole_non_tensor.cpp
index f93eb9e..cffb4bb 100644
@@ -193,13 +193,14 @@ struct PeepholeOptimizeNonTensorImpl {
           node->output()->replaceAllUsesWith(node->input());
           changed = true;
         }
-      } else if (node->kind() == aten::Int) {
-        if (node->input()->type()->cast<IntType>()) {
-          GRAPH_UPDATE(
-              "Removing ", getHeader(node), " as input is already an integer");
-          node->output()->replaceAllUsesWith(node->input());
-          changed = true;
-        }
+      } else if (
+          (node->kind() == aten::Int || node->kind() == aten::ceil) &&
+          node->inputs().size() == 1 &&
+          node->input()->type()->cast<IntType>()) {
+        GRAPH_UPDATE(
+            "Removing ", getHeader(node), " as input is already an integer");
+        node->output()->replaceAllUsesWith(node->input());
+        changed = true;
       } else if (node->kind() == aten::ne || node->kind() == aten::eq) {
         if (node->inputs().size() != 2 ||
             node->inputs().at(0) != node->inputs().at(1)) {
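
The widened peephole case is what makes the arange shape functions below foldable: they compute int(torch.ceil(...)), and once the argument is known to be an integer both the aten::ceil and the aten::Int calls are identities. A small sketch of the effect, assuming the usual torch._C pass bindings (exact graph text varies by version):

import torch

@torch.jit.script
def f(n: int):
    # ceil of an int is the int itself; the extended peephole replaces the
    # aten::ceil output with its input when the input type is already int
    return torch.ceil(n)

torch._C._jit_pass_peephole(f.graph)
torch._C._jit_pass_dce(f.graph)
print(f.graph)  # should now simply return its input, with no aten::ceil node
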
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index dd2a2e8..ae8ddd1 100644
@@ -80,6 +80,9 @@ const std::string shape_compute_functions =
             out.append(elem)
           return out
 
+        def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool=False):
+          return view(self, sizes)
+
         def mean_dim(self: List[int], dims: List[int], keep_dim: bool, dt : Any):
           out: List[int] = []
           for idx in range(len(self)):
@@ -344,12 +347,47 @@ const std::string shape_compute_functions =
             dim += dim_post_expr
           return dim
 
+        def zero_dim_tensor(input: Any):
+          out: List[int] = []
+          return out
+
         def multiply_integers(li: List[int]):
           out = 1
           for elem in li:
             out = out * elem
           return out
 
+        def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
+          assert end >= 0
+          return [int(torch.ceil(end))]
+
+        def arange_start(start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
+          assert end >= 0
+          assert end >= start
+          return [int(torch.ceil(end - start))]
+
+        def arange_start_step(start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
+          assert step != 0
+          if step < 0:
+            assert start >= end
+          else:
+            assert end >= start
+          return [int(torch.ceil((end - start) / step))]
+
+        def permute(input: List[int], dims: List[int]):
+          assert len(input) == len(dims)
+          ndim = len(dims)
+          seen_dims: List[int] = []
+          newSizes: List[int] = []
+          for i in range(ndim):
+            dim = maybe_wrap_dim(dims[i], ndim)
+            seen_dims.append(dim)
+            newSizes.append(input[dim])
+          for i in range(1, ndim):
+            for j in range(i):
+              assert seen_dims[i] != seen_dims[j]
+          return newSizes
+
         def flatten(input: List[int], start_dim: int, end_dim: int):
           start_dim = maybe_wrap_dim(start_dim, len(input))
           end_dim = maybe_wrap_dim(end_dim, len(input))
@@ -420,8 +458,13 @@ static const OperatorMap<std::string>& get_schema_to_function_graph() {
       {"aten::gelu(Tensor self) -> Tensor", "unary"},
       {"aten::tanh(Tensor self) -> Tensor", "unary"},
       {"aten::erf(Tensor self) -> (Tensor)", "unary"},
+      {"prim::NumToTensor.Scalar(Scalar a) -> Tensor", "zero_dim_tensor"},
+      {"prim::NumToTensor.bool(bool a) -> Tensor", "zero_dim_tensor"},
       {"aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "unary_four_unused_inputs"},
       {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", "unary_four_unused_inputs"},
+      {"aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "arange_end"},
+      {"aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start"},
+      {"aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start_step"},
       {"aten::squeeze(Tensor(a) self) -> Tensor(a)", "squeeze_nodim"},
       {"aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "squeeze"},
       {"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "unsqueeze"},
@@ -443,8 +486,10 @@ static const OperatorMap<std::string>& get_schema_to_function_graph() {
       {"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"},
       {"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
       {"aten::relu(Tensor self) -> Tensor", "unary"},
+      {"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"},
       {"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"},
       {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "view"},
+      {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "view_one_unused"},
       {"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
       {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "addmm"},
 #ifdef USE_XNNPACK
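
The permute entry registered above re-implements the eager size computation in TorchScript: wrap each dim, reject duplicates, and gather the input sizes in the requested order (aten::expand, by contrast, simply reuses the existing view formula via view_one_unused, ignoring the implicit flag). A plain-Python sketch of the permute logic (permute_shape is an illustrative name, not part of the registry, and the modulo only stands in for maybe_wrap_dim on in-range dims):

def permute_shape(sizes, dims):
    ndim = len(sizes)
    assert len(dims) == ndim
    # wrap negative dims, as maybe_wrap_dim does for valid inputs
    wrapped = [d % ndim for d in dims]
    assert len(set(wrapped)) == ndim, "repeated dim in permute"
    return [sizes[d] for d in wrapped]

# a (3, 4, 5) tensor permuted with dims (2, 0, 1) becomes (5, 3, 4)
assert permute_shape([3, 4, 5], [2, 0, 1]) == [5, 3, 4]
assert permute_shape([3, 4, 5], [-1, 0, 1]) == [5, 3, 4]
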
diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py
index 89533a6..c32300e 100644
@@ -281,7 +281,7 @@ class JitCommonTestCase(TestCase):
         self.assertEqual(should_autodiff_node,
                          found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
 
-    def checkShapeAnalysis(self, out_size, traced_graph, assert_propagation):
+    def checkShapeAnalysis(self, out_size, traced_graph, assert_propagation, constant_prop=True):
         # re-propagate the input shapes provided by tracing
         prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
         for enable_test_mode in [True, False]:
@@ -289,7 +289,8 @@ class JitCommonTestCase(TestCase):
             # disallowing constants helps stress test partial eval and substitution pipeline
             torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
             torch._C._jit_erase_non_input_shape_information(traced_graph)
-            torch._C._jit_pass_constant_propagation(traced_graph)
+            if constant_prop:
+                torch._C._jit_pass_constant_propagation(traced_graph)
             torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
             # Add sizes to default tensor type to avoid checking something out of scope
             # and difficulties with tracer leaving in other parts of tensor type
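
The new constant_prop flag appears to exist because the arange test graphs take no inputs: with constant propagation left on, the whole aten::arange call would be folded into a constant before the shape-propagation pass could exercise the registered shape function. A hedged illustration (all_constant is just an example name):

import torch

@torch.jit.script
def all_constant():
    return torch.arange(0, 5, 2)

torch._C._jit_pass_constant_propagation(all_constant.graph)
# the arange node is expected to have been folded to a constant tensor here,
# which is why test_arange_shape passes constant_prop=False to checkShapeAnalysis
print(all_constant.graph)
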
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 086f01c..375aa5f 100644
@@ -6477,10 +6477,8 @@ op_db: List[OpInfo] = [
            op=lambda self, shape: self.expand(shape),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_expand,
-           skips=(
-               # Because expand does not have a function variant.
-               SkipInfo('TestJit', 'test_variant_consistency_jit'),),
            supports_forward_ad=True,
+           assert_jit_shape_analysis=True,
            supports_out=False),
     OpInfo('expand_as',
            op=lambda self, other: self.expand_as(other),
@@ -7768,6 +7766,7 @@ op_db: List[OpInfo] = [
            dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
            assert_autodiffed=True,
+           assert_jit_shape_analysis=True,
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_permute),
     OpInfo('pow',
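
End to end, these registry entries let the shape-propagation pass recover output sizes for graphs that use permute or arange. A rough usage sketch built on the same internal torch._C passes that checkShapeAnalysis drives (internal APIs, so subject to change; the expected sizes are stated tentatively):

import torch

def fn(x):
    return x.permute(2, 0, 1)

traced = torch.jit.trace(fn, torch.rand(3, 4, 5))
graph = traced.graph
# keep only the input shapes, then let the registered shape functions
# recompute the rest, as JitCommonTestCase.checkShapeAnalysis does
torch._C._jit_erase_non_input_shape_information(graph)
torch._C._jit_pass_propagate_shapes_on_graph(graph)
print(graph)  # the permute output type should now carry sizes (5, 3, 4)
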