From 29514bfcdbe460b15900b762576ed9bb1eea45d5 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Wed, 15 Sep 2021 13:43:12 -0700
Subject: [PATCH] Max Pool with indices (#64121)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64121

Add symbolic shape analysis support for aten operators which return multiple outputs

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D30738142

Pulled By: eellison

fbshipit-source-id: 0d7e51187bd5e3e9b43f0fdb5178366a97aec943
---
 test/test_ops.py                                    | 13 ++++-
 torch/csrc/jit/passes/symbolic_shape_analysis.cpp   | 67 ++++++++++++++--------
 torch/csrc/jit/runtime/symbolic_shape_registry.cpp  | 22 +++++++
 torch/testing/_internal/common_jit.py               | 31 ++++++----
 .../_internal/common_methods_invocations.py         |  9 ++-
 5 files changed, 103 insertions(+), 39 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 4ee0cbc..8c93dba 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -784,9 +784,16 @@ class TestJit(JitCommonTestCase):
 
             if supports_tracing:
                 out = variant(get_sample(), *sample.args, **sample.kwargs)
-                # TODO: handle multiple outputs
-                if isinstance(out, torch.Tensor):
-                    self.checkShapeAnalysis(out.size(), traced_fn.graph, op.assert_jit_shape_analysis)
+                # right now, a tensor output and a tuple of tensor outputs are supported
+                # TODO: list of tensor outputs
+                tuple_of_tensors = isinstance(out, tuple) and all([isinstance(elem, torch.Tensor) for elem in out])
+
+                if isinstance(out, torch.Tensor) or tuple_of_tensors:
+                    if tuple_of_tensors:
+                        sizes = [elem.size() for elem in out]
+                    else:
+                        sizes = out.size()
+                    self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
                 checked_shape_analysis = True
             if op.assert_jit_shape_analysis:
                 self.assertTrue(checked_shape_analysis)
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index e2162e1..d00a717 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -135,6 +135,13 @@ bool symbolicShapeAnalysisTestModeEnabled() {
   return symbolic_shape_analysis_test_mode;
 }
+namespace {
+
+bool isListOfInts(const TypePtr& type) {
+  return type->cast<ListType>() &&
+      type->cast<ListType>()->getElementType()->cast<IntType>();
+}
+
 c10::optional<int64_t> normIndex(int64_t index, size_t len) {
   if (index < 0) {
     index = index + len;
   }
@@ -151,6 +158,8 @@ void replaceWithIValue(Value* v, IValue val) {
   v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
 }
 
+} // namespace
+
 // Symbolic Shape Analysis works through iteratively partially evaluating
 // a TorchScript shape compute graph by inputing properties from input
 // Tensors. We can substitute in properties like `len(x)` and `x[1]`
@@ -269,7 +278,7 @@ struct SymbolicShapeAnalyzer {
     }
   }
 
-  c10::SymbolicShape run() {
+  void run() {
     bool made_change = true;
     constexpr size_t MAX_ATTEMPTS = 8;
     size_t curr_attempt = 0;
@@ -294,7 +303,7 @@ struct SymbolicShapeAnalyzer {
 
     substituteInputTensorProperties(&symbolic_shape_values);
     GRAPH_DUMP("Done with partial evaluation", graph_);
-    return extractOutputShape(symbolic_shape_values);
+    extractOutputShape(symbolic_shape_values);
   }
 
  private:
@@ -449,28 +458,21 @@ struct SymbolicShapeAnalyzer {
     }
   }
 
-  c10::SymbolicShape extractOutputShape(
-      std::unordered_map<Value*, int64_t>& symbolic_shape_values) {
-    TORCH_INTERNAL_ASSERT(graph_->outputs().size() == 1);
-    auto output = graph_->outputs().at(0);
-    TORCH_INTERNAL_ASSERT(
-        output->type()->cast<ListType>() &&
-        output->type()->cast<ListType>()->getElementType()->cast<IntType>());
-    if (output->node()->kind() == prim::Constant) {
-      auto int_list = toIValue(output)->toIntVector();
+  c10::SymbolicShape extractListShape(
+      Value* list,
+      std::unordered_map<Value*, int64_t>& symbolic_shape_values,
+      const AliasDb& db) {
+    if (list->node()->kind() == prim::Constant) {
+      auto int_list = toIValue(list)->toIntVector();
       return c10::SymbolicShape(int_list);
     }
-    // TODO: would be nice if there were easy facility to look at uses and see
-    // if they are all pure instead of instanting db.
-    AliasDb db(graph_);
-    // If it is not a single list construct or constant, bail,
-    // otherwise we cannot analyze its output and it might be modified
-    if (output->node()->kind() != prim::ListConstruct ||
-        db.hasWriters(output)) {
+    // We need a list construct or a constant output
+    // that is not written to in order to analyze the output shape
+    if (list->node()->kind() != prim::ListConstruct || db.hasWriters(list)) {
       GRAPH_DEBUG("Could not extract shape ", getHeader(node_));
       return c10::SymbolicShape();
     }
-    Node* list_construct = output->node();
+    Node* list_construct = list->node();
     std::vector<c10::optional<int64_t>> output_shape;
     for (Value* input : list_construct->inputs()) {
       if (symbolic_shape_values.count(input)) {
@@ -482,6 +484,23 @@ struct SymbolicShapeAnalyzer {
     return c10::SymbolicShape(output_shape);
   }
 
+  void extractOutputShape(
+      std::unordered_map<Value*, int64_t>& symbolic_shape_values) {
+    TORCH_INTERNAL_ASSERT(graph_->outputs().size() == node_->outputs().size());
+    // TODO: it would be nice if there were an easy facility to look at uses
+    // and see if they are all pure instead of instantiating db.
+    AliasDb db(graph_);
+    for (size_t i = 0; i < graph_->outputs().size(); ++i) {
+      auto output = graph_->outputs().at(i);
+      auto type = output->type();
+      TORCH_INTERNAL_ASSERT(isListOfInts(type));
+      auto ss = extractListShape(output, symbolic_shape_values, db);
+      node_->output(i)->setType(
+          node_->output(i)->type()->expect<TensorType>()->withSymbolicShapes(
+              ss));
+    }
+  }
+
   // node input indices that are TensorType and we need to iteratively
   // substitute properties of. We only substitute properties
   // of TensorTypes with a fixed dimension but not a complete shape,
@@ -498,10 +517,7 @@ void PropagateShapesWithShapeFunction(
     Node* n,
     std::shared_ptr<Graph>& shape_compute_graph,
    const AliasDb& db) {
-  c10::SymbolicShape out =
-      SymbolicShapeAnalyzer(n, shape_compute_graph, db).run();
-  n->output()->setType(
-      n->output()->type()->expect<TensorType>()->withSymbolicShapes(out));
+  SymbolicShapeAnalyzer(n, shape_compute_graph, db).run();
 }
 
 void PropagateShapesOnBlock(Block* b, const AliasDb& db) {
@@ -516,6 +532,11 @@ void PropagateShapesOnBlock(Block* b, const AliasDb& db) {
       if (auto maybe_graph = shapeComputeGraphForSchema(n->schema())) {
         PropagateShapesWithShapeFunction(n, *maybe_graph, db);
       }
+    } else if (n->kind() == prim::TupleConstruct) {
+      auto orig_type = n->output()->type()->expect<TupleType>();
+      auto new_types = fmap(n->inputs(), [](Value* v) { return v->type(); });
+      n->output()->setType(
+          orig_type->createWithContained(std::move(new_types)));
     }
   }
 }
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 7841631..3751eb6 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -204,8 +204,13 @@ const std::string shape_compute_functions =
           return [nInputPlane, outputHeight, outputWidth]
         else:
          return [nbatch, nInputPlane, outputHeight, outputWidth]
+
+    def max_pool2d_with_indices(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool):
+      out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
+      return (out, out)
   )"
   R"(
+
     def mm(self: List[int] , mat2: List[int]):
       assert len(self) == 2, "self must be a matrix"
       assert len(mat2) == 2, "mat2 must be a matrix"
@@ -587,6 +592,7 @@ static const OperatorMap<std::string>& get_schema_to_function_graph() {
       {"aten::matmul(Tensor self, Tensor other) -> Tensor", "matmul"},
       {"aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", "linear"},
       {"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "max_pool2d"},
+      {"aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "max_pool2d_with_indices"},
       {"aten::t(Tensor(a) self) -> Tensor(a)", "t"},
       {"aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "transpose"},
       {"aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", "conv1d"},
@@ -634,6 +640,22 @@ void loadModule(const CompilationUnit& module) {
     std::shared_ptr<Graph> graph = shape_compute_function.graph();
     Inline(*graph);
 
+    // ATen operators can return multiple unboxed values, in contrast to
+    // functions defined in TorchScript or user-registered operators,
+    // which must pack multiple outputs into a Tuple. Here, modify the
+    // shape graph of aten operators with multiple outputs so that its
+    // outputs correspond to the operator's outputs.
+    if (pair.first->schema().returns().size() > 1) {
+      TORCH_INTERNAL_ASSERT(
+          graph->outputs().size() == 1 &&
+          graph->outputs().at(0)->node()->kind() == prim::TupleConstruct);
+      auto tuple_node = graph->outputs().at(0)->node();
+      graph->eraseOutput(0);
+      for (Value* v : tuple_node->inputs()) {
+        graph->registerOutput(v);
+      }
+    }
+
     cached_schema_to_graph[schema_string] = graph;
     reused_functions[shape_compute_function_name] = graph;
   }
diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py
index c32300e..8154963 100644
--- a/torch/testing/_internal/common_jit.py
+++ b/torch/testing/_internal/common_jit.py
@@ -281,7 +281,8 @@ class JitCommonTestCase(TestCase):
         self.assertEqual(should_autodiff_node,
                          found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
 
-    def checkShapeAnalysis(self, out_size, traced_graph, assert_propagation, constant_prop=True):
+    def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
+                           traced_graph, assert_propagation, constant_prop=True):
         # repropagte input shapes provided by tracing,
         prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
         for enable_test_mode in [True, False]:
@@ -294,16 +295,26 @@ class JitCommonTestCase(TestCase):
             torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
             # Add sizes to default tensor type to avoid checking something out of scope
             # and difficulties with tracer leaving in other parts of tensor type
-            sizes = next(traced_graph.outputs()).type().symbolic_sizes()
-            out_type = TensorType.get().with_sizes(sizes)
-            actual_type = TensorType.get().with_sizes(out_size)
+            output = next(traced_graph.outputs()).type()
 
-            # always check actual shape is a subtype of the output
-            self.assertTrue(actual_type.isSubtypeOf(out_type))
+            def test_type(type, actual_size):
+                sizes = type.symbolic_sizes()
+                out_type = TensorType.get().with_sizes(sizes)
+                actual_type = TensorType.get().with_sizes(actual_size)
 
-            # and then if assertion flag is provided, check shape analysis
-            # is successful
-            if assert_propagation:
-                self.assertEqual(out_type.sizes(), out_size)
+                # always check actual shape is a subtype of the output
+                self.assertTrue(actual_type.isSubtypeOf(out_type))
+
+                # and then if assertion flag is provided, check shape analysis
+                # is successful
+                if assert_propagation:
+                    self.assertEqual(out_type.sizes(), actual_size)
+
+            if output.isSubtypeOf(torch._C.TensorType.get()):
+                test_type(output, out_sizes)
+            else:
+                tuple_elements = output.elements()
+                for i in range(len(tuple_elements)):
+                    test_type(tuple_elements[i], out_sizes[i])
 
         torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 18f7431..8189dce 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2598,11 +2598,14 @@ def sample_inputs_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
     ceil_modei = [True, False]
     paddingi = [0, 1]
     dilationi = [1, (1, 2)]
-    products = product(kerneli, stridei, Ni, Ci, Hi, Wi, ceil_modei, paddingi, dilationi)
+    return_indicesi = [True, False]
+
+    products = product(kerneli, stridei, Ni, Ci, Hi, Wi, ceil_modei, paddingi, dilationi, return_indicesi)
 
     def generator():
-        for kernel, stride, N, C, H, W, ceil_mode, padding, dilation in products:
-            max_pool = torch.nn.MaxPool2d(kernel, stride, ceil_mode=ceil_mode, padding=padding, dilation=dilation)
+        for kernel, stride, N, C, H, W, ceil_mode, padding, dilation, return_indices in products:
+            max_pool = torch.nn.MaxPool2d(kernel, stride, ceil_mode=ceil_mode, padding=padding,
+                                          dilation=dilation, return_indices=return_indices)
             kwargs = {
                 "kernel_size": max_pool.kernel_size,
                 "stride": max_pool.stride,
-- 
2.7.4
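
Below is a minimal, illustrative sketch (not part of the patch) of how the multi-output shape propagation added above might be exercised by hand. It relies only on the internal binding already used in common_jit.py (torch._C._jit_pass_propagate_shapes_on_graph) plus standard tracing APIs; the pooling configuration and input shape are arbitrary example values, and the exact printed types may vary between builds.

import torch

# Trace a pooling module that returns (output, indices); the traced graph ends
# in a prim::TupleConstruct over two tensors, which the updated pass can now
# re-type using the aten::max_pool2d_with_indices shape function.
pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
traced = torch.jit.trace(pool, torch.rand(1, 3, 8, 8))

# Re-run shape propagation on the traced graph and inspect the result; both
# tuple elements should now carry propagated shapes.
torch._C._jit_pass_propagate_shapes_on_graph(traced.graph)
print(traced.graph)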