From 221edddd18e2c434a67e97689f882e42862e7ada Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Wed, 27 Feb 2019 18:59:19 -0800
Subject: [PATCH] disallow shape analysis with resize ops (#17518)

Summary:
resize_ and resize_as_ resize the input tensor. Because our shape analysis
is flow invariant, we don't do shape analysis on any op that relies on a
Tensor that may alias a resized Tensor. E.g. in the following graph, the `x`
in `x += 10` may have been resized:
```
@torch.jit.script
def test(x, y):
    for i in range(10):
        x += 10
        x.resize_as_(y)
    return x
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17518

Differential Revision: D14249835

Pulled By: eellison

fbshipit-source-id: f281b468ccb8c29eeb0f68ca5458cc7246a166d9
---
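A rough sketch of that behavior (assuming a build that includes this patch and
the `torch._C._jit_pass_shape_analysis` binding the new tests below rely on):
once `b` is resized, every value in `b`'s alias set falls back to the base
`Tensor` type with no size information. `fn` is only an illustrative name.
```
import torch
from torch._C import TensorType

@torch.jit.script
def fn(x):
    b = x + 1          # b could otherwise be given a sized tensor type
    b.resize_([4, 4])  # resize_ puts b's alias set into the resized set
    return b.add_(1)   # the returned alias keeps only the base Tensor type

# Run the same internal pass the new tests use; the graph output comes back
# as the unsized base Tensor type instead of a dimensioned tensor type.
torch._C._jit_pass_shape_analysis(fn.graph, (torch.zeros(2, 2),), False)
print(next(fn.graph.outputs()).type() == TensorType.get())  # expected: True
```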
 aten/src/ATen/core/aten_interned_strings.h |  2 +
 test/test_jit.py                           | 75 +++++++++++++++++++++++++++++-
 torch/csrc/jit/passes/shape_analysis.cpp   | 67 ++++++++++++++++++++++++--
 torch/csrc/jit/python_ir.cpp               |  4 +-
 torch/jit/annotations.py                   |  6 +--
 torch/jit/supported_ops.py                 |  2 +-
 torch/onnx/symbolic.py                     | 12 ++---
 7 files changed, 149 insertions(+), 19 deletions(-)

diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index dae1933..a739996 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -561,7 +561,9 @@ _(aten, replication_pad3d_forward) \
 _(aten, reshape) \
 _(aten, reshape_as) \
 _(aten, resize) \
+_(aten, resize_) \
 _(aten, resize_as) \
+_(aten, resize_as_) \
 _(aten, rfft) \
 _(aten, rnn_relu) \
 _(aten, rnn_relu_cell) \
diff --git a/test/test_jit.py b/test/test_jit.py
index 3ee0ff4..29d9522 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -40,6 +40,8 @@ from common_methods_invocations import method_tests as autograd_method_tests
 from common_methods_invocations import create_input, unpack_variables, \
     exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
 from torch.testing import FileCheck
+from torch._C import TensorType, TupleType, FloatType, IntType, \
+    ListType, StringType, DictType
 from copy import deepcopy
 import random
 from typing import List, Dict, Optional
@@ -4272,6 +4274,74 @@ a")
         # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
         self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
 
+    def test_resize_input_ops(self):
+        # resize_ and resize_as resize the input tensor. because our shape analysis
+        # is flow invariant, we set any Tensor that can alias a resized Tensor
+        # to the base Tensor Type, without size information.
+
+        # testing that value which is an input of a graph gets handled
+        def out_op_graph_input():
+            @torch.jit.script
+            def test(x, y, z):
+                torch.mul(x, y, out=z)
+                return z
+
+            torch._C._jit_pass_shape_analysis(
+                test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
+            self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
+        out_op_graph_input()
+
+        def test_resize():
+            @torch.jit.script
+            def test(x):
+                after_resize_alias = torch.zeros([2])
+                for _i in range(5):
+                    b = x + 1
+                    f = [1]
+                    before_resize_alias = b.sub_(1)
+                    # for i in range(10):
+                    f.append(1)
+                    b.resize_(f)
+                    after_resize_alias = b.add_(1)
+                return after_resize_alias
+
+            g = test.graph
+            self.run_pass('constant_propagation', g)
+            torch._C._jit_pass_shape_analysis(
+                g, (torch.zeros(1, 1),), False)
+            resize_node = g.findNode("aten::resize_")
+            # first input and output of b.resize_ is b
+            self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
+            self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())
+
+            # correctly propagates to b alias set
+            before_resize = g.findNode("aten::sub_")
+            self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())
+
+            after_resize = g.findNode("aten::add_")
+            self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())
+
+        test_resize()
+
+        def test_resize_as():
+            @torch.jit.script
+            def test(x):
+                b = torch.zeros([2, 2])
+                b.resize_as_(x)
+                return b
+
+            g = test.graph
+            self.run_pass('constant_propagation', g)
+            torch._C._jit_pass_shape_analysis(
+                g, (torch.zeros(1, 1),), False)
+
+            # x doesn't alias a resized op so it shouldn't be set to base Tensor type
+            self.assertTrue(next(g.inputs()).type() != TensorType.get())
+            # return is resized
+            self.assertTrue(next(g.outputs()).type() == TensorType.get())
+
+        test_resize_as()
+
     def test_view_shape_prop(self):
         cu = torch.jit.CompilationUnit('''
         def test_view_shape_prop(a):
@@ -4293,7 +4363,8 @@ a")
         x = torch.randn(3, 1, 5, requires_grad=True)
         graph = torch.jit.script(fn).graph
         torch._C._jit_pass_shape_analysis(graph, (x,), False)
-        self.assertTrue(next(graph.outputs()).type().kind() != 'DynamicType')
+        a = next(graph.outputs()).type().kind()
+        self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
 
     def test_integral_shape_inference(self):
         cu = torch.jit.CompilationUnit('''
@@ -11082,7 +11153,7 @@ EXCLUDE_TRACED = {
     'test_split_dim_neg0',
 
     # The following fail due to #12024.
-    # A prim::ListConstruct is involved and the indices get traced as DynamicType,
+    # A prim::ListConstruct is involved and the indices get traced as TensorType,
     # which always require_grad. This causes a crash in autodiff.
     'test___getitem___adv_index',
     'test___getitem___adv_index_beg',
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 1742128..a2ce91d 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -50,8 +50,9 @@ bool isValidReturnForRunning(Value* v) {
 
 class ShapePropagator {
  public:
-  explicit ShapePropagator(std::shared_ptr<Graph> graph)
-      : aliasDb_(std::move(graph)) {}
+  explicit ShapePropagator(std::shared_ptr<Graph> graph) : aliasDb_(graph) {
+    collectResizeSet(std::move(graph)->block());
+  }
 
   void PropagateShapeOnBlock(Block* block, bool insert_expands = true) {
     for (Node* node : block->nodes()) {
@@ -70,11 +71,51 @@ class ShapePropagator {
   }
 
  private:
+  ValueSet resized_alias_set;
   const AliasDb aliasDb_;
 
+  bool resizesInput(Node* n) {
+    static std::unordered_set<Symbol> resize_ops{
+        aten::resize_,
+        aten::resize_as_,
+    };
+
+    if (resize_ops.count(n->kind()))
+      return true;
+
+    if (!n->maybeSchema())
+      return false;
+
+    // ops which take the result and write to input "out"
+    if (auto out_arg_index = n->schema().argumentIndexWithName("out")) {
+      auto arg = n->schema().arguments().at(*out_arg_index);
+      return arg.kwarg_only() && arg.type()->isSubtypeOf(TensorType::get());
+    }
+    return false;
+  }
+
+  void collectResizeSet(Block* block) {
+    for (Node* n : block->nodes()) {
+      for (Block* b : n->blocks()) {
+        collectResizeSet(b);
+      }
+      if (resizesInput(n)) {
+        for (const auto input : n->inputs()) {
+          if (aliasDb_.writesToAlias(n, {input}, /*recurseBlocks*/ false)) {
+            resized_alias_set.insert(input);
+          }
+        }
+      }
+    }
+  }
+
+  void setUnshapedType(Value* o) {
+    o->setType(unshapedType(o->type()));
+  }
+
   void setUnshapedType(Node* node) {
     for (auto o : node->outputs()) {
-      o->setType(unshapedType(o->type()));
+      setUnshapedType(o);
     }
   }
 
@@ -348,7 +389,25 @@ class ShapePropagator {
     setUnshapedType(cat_node);
   }
 
+  bool mayAliasResizedSet(at::ArrayRef<Value*> vs) {
+    bool in_resize = false;
+    for (auto v : vs) {
+      if (aliasDb_.mayAlias({v}, resized_alias_set)) {
+        setUnshapedType(v);
+        in_resize = true;
+      }
+    }
+    return in_resize;
+  }
+
   void PropagateShapeOnNode(Node* node, bool insert_expands = true) {
+    // Certain ops like resize_ change the input tensors size. Because our
+    // analysis is flow invariant, we set any Tensor that can alias a resized
+    // Tensor to the base Tensor Type without size information.
+    if (mayAliasResizedSet(node->inputs())) {
+      return setUnshapedType(node);
+    }
+
     // These don't require the types, and have complicated schema. Return early
     // after we process them.
     switch (node->kind()) {
@@ -1631,12 +1690,10 @@ void EraseShapeInformation(Block* b) {
     }
   }
 }
-
 } // anonymous namespace
 
 void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
   EraseShapeInformation(graph->block());
 }
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 87ce018..04e60a3 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -551,7 +551,7 @@ void initPythonIRBindings(PyObject* module_) {
            std::vector<autograd::Variable> variables;
            variables.reserve(tensors.size());
            for (auto& tensor : tensors) {
-             variables.push_back(std::move(tensor));
+             variables.emplace_back(std::move(tensor));
            }
            return variables;
          })
@@ -644,7 +644,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def_static("get", &IntType::get);
   py::class_<FloatType, Type, std::shared_ptr<FloatType>>(m, "FloatType")
       .def_static("get", &FloatType::get);
-  py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m, "DynamicType")
+  py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m, "TensorType")
       .def_static("get", &TensorType::get);
   py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType")
       .def_static("get", &BoolType::get);
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index 22e5c3f..b318539 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -5,7 +5,7 @@ import inspect
 import torch
 from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
     BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict
-from torch._C import DynamicType, TupleType, FloatType, IntType, \
+from torch._C import TensorType, TupleType, FloatType, IntType, \
     ListType, StringType, DictType
 from textwrap import dedent
 
@@ -163,9 +163,9 @@ def try_real_annotations(fn):
 
 def ann_to_type(ann):
     if ann is None:
-        return DynamicType.get()
+        return TensorType.get()
     elif ann is torch.Tensor:
-        return DynamicType.get()
+        return TensorType.get()
     elif is_tuple(ann):
         return TupleType([ann_to_type(a) for a in ann.__args__])
     elif is_list(ann):
diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py
index 6b629dd..acfe8d5 100644
--- a/torch/jit/supported_ops.py
+++ b/torch/jit/supported_ops.py
@@ -73,7 +73,7 @@ def _list_supported_ops():
         self = schema.arguments[0]
         if self.name != 'self':
             return False
-        if not self.type.isSubtypeOf(torch._C.DynamicType.get()):
+        if not self.type.isSubtypeOf(torch._C.TensorType.get()):
             return False
         return True
 
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index c21732f..fe20db9 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -1,7 +1,7 @@
 import numbers
 
 import torch
-from torch._C import DynamicType, ListType, OptionalType
+from torch._C import TensorType, ListType, OptionalType
 from torch.nn.modules.utils import _single, _pair, _triple
 from torch.nn.utils.rnn import PackedSequence
 import warnings
@@ -38,11 +38,11 @@ import itertools
 # type information, note that there are several levels of Tensor types, defined
 # in aten/src/ATen/core/jit_type.h:
 #
-# DynamicType - This is a Tensor, but we don't know anything about its
+# TensorType - This is a Tensor, but we don't know anything about its
 #               properties (e.g. scalar type, # dims, shapes).
 #               Appears as `Tensor` in graph print-outs.
-# UndefinedTensorType <: DynamicType - Denotes an undefined Tensor
-# TensorType <: DynamicType - Denotes a Tensor for which we know the scalar
+# UndefinedTensorType <: TensorType - Denotes an undefined Tensor
+# DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
 #                             type and number of dimensions, but not the concrete
 #                             shapes. For example, appears as 'Float(*, *)' in
 #                             graph print-outs. Useful accessor methods include
@@ -55,7 +55,7 @@ import itertools
 #
 # In general, we should prefer to rely on the least specific information possible.
 # For example, not relying on tensor properties at all is better than relying
-# on the number of dimensions (TensorType) which is better than relying on
+# on the number of dimensions (DimensionedTensorType) which is better than relying on
 # concrete shapes (CompleteTensorType). Doing so will make the export symbolics
 # more robust to different graphs.
 
@@ -1546,7 +1546,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first):
     # PackPadded operators cannot be optimized out.
     if batch_first:
         input = g.op('Transpose', input, perm_i=[1, 0, 2])
-    if not lengths.type().isSubtypeOf(torch._C.DynamicType.get()):
+    if not lengths.type().isSubtypeOf(torch._C.TensorType.get()):
         raise RuntimeError("Lengths must be a Tensor for ONNX export")
     # We know it's a TensorType so this check is now safe.
     # It's really only necessary beacuse those operators expand to something that
-- 
2.7.4
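Beyond `aten::resize_` and `aten::resize_as_`, the `resizesInput()` check in
shape_analysis.cpp above also treats any schema with a keyword-only Tensor
`out` argument as resizing, since writing into `out` may reallocate it. A
rough sketch of the observable effect, mirroring the `out_op_graph_input`
test added in this patch (and assuming the same
`torch._C._jit_pass_shape_analysis` binding; `fn` is just an illustrative
name):
```
import torch
from torch._C import TensorType

@torch.jit.script
def fn(x, y, z):
    # z is written through the keyword-only out= argument, which may resize it
    torch.mul(x, y, out=z)
    return z

torch._C._jit_pass_shape_analysis(
    fn.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
# z's alias set counts as resized, so the graph output keeps only the base Tensor type
print(next(fn.graph.outputs()).type() == TensorType.get())  # expected: True
```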