import torch
-from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.jit_utils import JitTestCase, execWrapper
import operator
from torch.testing import FileCheck
+from textwrap import dedent
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")
        torch._C._jit_pass_peephole(fn.graph)
        torch._C._jit_pass_constant_propagation(fn.graph)
        self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)
+
+    def test_arange_shape(self):
+        # tensor constructors have no OpInfo entries, so exercise arange inputs directly
+        inps = [
+            (10,),
+            (10, 10),
+            (0, 10),
+            (0, 1000),
+            (1, -1, -1),
+            (1, 0, -1),
+            (1, 2, 1),
+            (0.6, 0.89, 0.1),
+            (1, 10, 0.3),
+            (1, 10, 4),
+            (0.6, 0.7, 0.8),
+            # (True,),  TODO: https://github.com/pytorch/pytorch/issues/63405
+            # (False,),  TODO: https://github.com/pytorch/pytorch/issues/63405
+            (0, 5),
+            (0, 5, 2),
+            (0, 5 + 1e-6),
+            (0, 5 - 1e-6),
+            (10, -1 + 1e-6, -1),
+            (10, -1, -1),
+            (10, -1 - 1e-6, -1),
+        ]
+
+        for inp in inps:
+            funcs_template = dedent('''
+            def func():
+                return torch.arange({args})
+            ''')
+
+            inp_s = str(inp)[1:-1]  # strip the tuple parentheses
+            funcs_str = funcs_template.format(args=inp_s)
+            scope = {}
+            execWrapper(funcs_str, globals(), scope)
+            cu = torch.jit.CompilationUnit(funcs_str)
+            # constant_prop=False keeps the arange call from being folded away
+            # before shape propagation runs (its arguments are all constants)
+            self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph,
+                                    assert_propagation=True, constant_prop=False)
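+
+    def test_arange_length_closed_form(self):
+        # Hypothetical companion check, not part of the original change: the shape
+        # functions registered below compute ceil((end - start) / step), so the
+        # eager arange length should agree with that closed form. Integer inputs
+        # only, to stay clear of floating-point rounding at the boundary.
+        import math
+        for start, end, step in [(0, 5, 2), (0, 10, 1), (10, -1, -1)]:
+            expected = int(math.ceil((end - start) / step))
+            self.assertEqual(torch.arange(start, end, step).numel(), expected)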
        out.append(elem)
    return out
+def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool=False):
+    # shape-wise identical to view; the extra `implicit` kwarg never affects the result
+    return view(self, sizes)
+
def mean_dim(self: List[int], dims: List[int], keep_dim: bool, dt : Any):
    out: List[int] = []
    for idx in range(len(self)):
        dim += dim_post_expr
    return dim
+def zero_dim_tensor(input: Any):
+    # a scalar becomes a 0-dim tensor, whose size is the empty list
+    out: List[int] = []
+    return out
+
def multiply_integers(li: List[int]):
    out = 1
    for elem in li:
        out = out * elem
    return out
+def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
+    # arange(end) has ceil(end) elements; the trailing Any args stand in for
+    # dtype/layout/device/pin_memory, none of which affect the output shape
+    assert end >= 0
+    return [int(torch.ceil(end))]
+
+def arange_start(start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
+    assert end >= 0
+    assert end >= start
+    return [int(torch.ceil(end - start))]
+
+def arange_start_step(start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
+    # step must move start toward end; the length is ceil((end - start) / step)
+    assert step != 0
+    if step < 0:
+        assert start >= end
+    else:
+        assert end >= start
+    return [int(torch.ceil((end - start) / step))]
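+# Worked examples (illustrative only, not executed): arange_end(5, ...) -> [5];
+# arange_start_step(0, 5, 2, ...) -> [ceil(5 / 2)] = [3]; and
+# arange_start_step(10, -1, -1, ...) -> [ceil(-11 / -1)] = [11], matching
+# torch.arange(...).numel() in eager mode.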
+
+def permute(input: List[int], dims: List[int]):
+    assert len(input) == len(dims)
+    ndim = len(dims)
+    seen_dims: List[int] = []
+    newSizes: List[int] = []
+    for i in range(ndim):
+        # wrap negative dims, then pick up the input size at each permuted position
+        dim = maybe_wrap_dim(dims[i], ndim)
+        seen_dims.append(dim)
+        newSizes.append(input[dim])
+    for i in range(1, ndim):
+        for j in range(i):
+            # every dim may appear at most once in the permutation
+            assert seen_dims[i] != seen_dims[j]
+    return newSizes
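+# Illustrative example (not executed): permute([2, 3, 5], [2, 0, 1]) -> [5, 2, 3];
+# a repeated dim such as [0, 0, 1] trips the duplicate-dim assert above.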
+
def flatten(input: List[int], start_dim: int, end_dim: int):
    start_dim = maybe_wrap_dim(start_dim, len(input))
    end_dim = maybe_wrap_dim(end_dim, len(input))
{"aten::gelu(Tensor self) -> Tensor", "unary"},
{"aten::tanh(Tensor self) -> Tensor", "unary"},
{"aten::erf(Tensor self) -> (Tensor)", "unary"},
+ {"prim::NumToTensor.Scalar(Scalar a) -> Tensor", "zero_dim_tensor"},
+ {"prim::NumToTensor.bool(bool a) -> Tensor", "zero_dim_tensor"},
{"aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "unary_four_unused_inputs"},
{"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", "unary_four_unused_inputs"},
+ {"aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "arange_end"},
+ {"aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start"},
+ {"aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start_step"},
{"aten::squeeze(Tensor(a) self) -> Tensor(a)", "squeeze_nodim"},
{"aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "squeeze"},
{"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "unsqueeze"},
{"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"},
{"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
{"aten::relu(Tensor self) -> Tensor", "unary"},
+ {"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"},
{"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"},
{"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "view"},
+ {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "view_one_unused"},
{"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
{"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "addmm"},
#ifdef USE_XNNPACK
        self.assertEqual(should_autodiff_node,
                         found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
-    def checkShapeAnalysis(self, out_size, traced_graph, assert_propagation):
+    def checkShapeAnalysis(self, out_size, traced_graph, assert_propagation, constant_prop=True):
        # repropagate input shapes provided by tracing
        prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
        for enable_test_mode in [True, False]:
            # disallowing constants helps stress test partial eval and substitution pipeline
            torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
            torch._C._jit_erase_non_input_shape_information(traced_graph)
-            torch._C._jit_pass_constant_propagation(traced_graph)
+            # with constant inputs (e.g. torch.arange(10)), constant propagation would
+            # fold the op away before shape propagation ran; constant_prop=False keeps it
+            if constant_prop:
+                torch._C._jit_pass_constant_propagation(traced_graph)
            torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
            # Add sizes to default tensor type to avoid checking something out of scope
            # and difficulties with tracer leaving in other parts of tensor type