From 9a607594534f27396c9ea703f7e9c818df662195 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Tue, 17 Aug 2021 11:21:50 -0700
Subject: [PATCH] Propagate symbolic dimensions through idioms like
 x.view(y.size()) (#61975)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61975

Propagate symbolic dimensions through size calls. We do this by
associating symbolic sizes with integer and integer-list inputs,
looking through their constructing nodes for `x.size(1)` or `x.size()`
calls.

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D30196948

Pulled By: eellison

fbshipit-source-id: 377fc1d2f6d396c52dc0e87fa814b15720f1414e
---
 test/jit/test_symbolic_shape_analysis.py          | 17 ++++++
 torch/csrc/jit/passes/symbolic_shape_analysis.cpp | 65 +++++++++++++++++++----
 2 files changed, 72 insertions(+), 10 deletions(-)

diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 5205845..4e42146 100644
--- a/test/jit/test_symbolic_shape_analysis.py
+++ b/test/jit/test_symbolic_shape_analysis.py
@@ -116,3 +116,20 @@ class TestSymbolicShapeAnalysis(JitTestCase):
         inputs[1].setType(inputs[1].type().with_sizes(size_2))
         torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
         self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8])
+
+    def test_size_and_sizes(self):
+        @torch.jit.script
+        def foo(x, y):
+            return x.view(y.size(0), 8, y.size(-1))
+
+        @torch.jit.script
+        def foo2(x, y):
+            return x.view(y.size())
+
+        for graph in [foo.graph, foo2.graph]:
+            inputs = list(graph.inputs())
+            sym1 = torch._C._new_symbolic_shape_symbol()
+
+            inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
+            torch._C._jit_pass_propagate_shapes_on_graph(graph)
+            self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1])
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index 2ff0e96..fa76058 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -103,15 +103,60 @@ struct SymbolicShapeAnalyzer {
         }
         // we can't optimize a tensor without fixed rank
         if (symbolic_shapes.rank()) {
-          node_input_tensor_indices.push_back(i);
+          node_symbolic_input_indices.emplace_back(i, symbolic_shapes);
         }
       } else if (
           type->cast<ListType>() &&
          type->cast<ListType>()->getElementType()->cast<TensorType>()) {
         TORCH_INTERNAL_ASSERT(false); // not handled yet
-      } else {
-        if (auto ival = toIValue(node_->input(i))) {
-          replaceWithIValue(graph_->inputs().at(i), *ival);
+      } else if (auto ival = toIValue(node_->input(i))) {
+        replaceWithIValue(graph_->inputs().at(i), *ival);
+      } else if (
+          type->cast<ListType>() &&
+          type->cast<ListType>()->getElementType()->cast<IntType>()) {
+        if (node_->input(i)->node()->kind() == prim::ListConstruct &&
+            node_->input(i)->uses().size() == 1) {
+          // it is very common in graphs to see patterns like:
+          // z = x.view(y.size())
+          // or:
+          // z = x.view(1, 10, y.size(0), y.size(1))
+          // We want to propagate symbolic dimensions and concrete sizes
+          // from y to z. To do this we try to associate symbolic dimensions
+          // or concrete sizes with the integer list inputs that have a
+          // constructor taken from constants or y.size() or y.size(0)
+          auto list_construct = node_->input(i)->node();
+          std::vector<c10::optional<int64_t>> shape;
+          for (Value* v : list_construct->inputs()) {
+            if (auto constant = constant_as<int64_t>(v)) {
+              shape.emplace_back(*constant);
+            } else if (
+                v->node()->kind() == aten::size &&
+                v->node()->input(1)->node()->kind() == prim::Constant) {
+              auto tt = v->node()->input(0)->type()->expect<TensorType>();
+              auto ss = tt->symbolic_sizes();
+              if (!ss.rank()) {
+                shape.emplace_back(c10::nullopt);
+                continue;
+              }
+              auto norm_index = normIndex(
+                  *constant_as<int64_t>(v->node()->input(1)), *ss.rank());
+              if (!norm_index) {
+                shape.emplace_back(c10::nullopt);
+                continue;
+              }
+              shape.emplace_back(ss[*norm_index].value());
+            } else {
+              shape.emplace_back(c10::nullopt);
+            }
+          }
+          node_symbolic_input_indices.emplace_back(
+              i, c10::SymbolicShape(shape));
+        } else if (
+            node_->input(i)->node()->kind() == aten::size &&
+            node_->input(i)->uses().size() == 1) {
+          auto ten_inp = node_->input(i)->node()->input();
+          auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
+          node_symbolic_input_indices.emplace_back(i, ss);
         }
       }
     }
@@ -175,13 +220,12 @@ struct SymbolicShapeAnalyzer {
 
     std::unordered_map<int64_t, std::vector<Value*>> symbolic_shape_map;
 
-    for (auto tensor_index : node_input_tensor_indices) {
-      auto tensor_value = node_->inputs().at(tensor_index);
-      auto tensor_shape =
-          tensor_value->type()->expect<TensorType>()->symbolic_sizes();
+    for (const auto& index_symbolic_shape : node_symbolic_input_indices) {
+      auto index = index_symbolic_shape.first;
+      auto tensor_shape = index_symbolic_shape.second;
       TORCH_INTERNAL_ASSERT(tensor_shape.rank().has_value());
 
-      for (const auto& use : graph_->inputs().at(tensor_index)->uses()) {
+      for (const auto& use : graph_->inputs().at(index)->uses()) {
         // TODO: either decompose composite ops like slice or add handling here
         switch (use.user->kind()) {
           case aten::len: {
@@ -285,7 +329,8 @@ struct SymbolicShapeAnalyzer {
   // and non-fixed dimensions we cannot reason about at all
   // TODO: might be cleaner to store as a pair of index -> symbolic shape
   // but there were weird lifetime issues
-  std::vector<int64_t> node_input_tensor_indices;
+  std::vector<std::pair<int64_t, c10::SymbolicShape>>
+      node_symbolic_input_indices;
   std::shared_ptr<Graph> graph_;
   Node* node_;
 };
-- 
2.7.4
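
A minimal usage sketch (mirroring the test_size_and_sizes test added above; the
torch._C helpers are internal JIT APIs and may change between releases, so this
is illustrative rather than a stable recipe):

    import torch

    @torch.jit.script
    def foo(x, y):
        # the idiom this patch teaches shape analysis to reason about:
        # building a view of x from another tensor's size() call
        return x.view(y.size())

    graph = foo.graph
    inputs = list(graph.inputs())

    # give y a partially symbolic shape: two concrete dims and one symbol
    sym = torch._C._new_symbolic_shape_symbol()
    inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym]))

    # run shape propagation; the graph output type should now carry [5, 8, sym]
    torch._C._jit_pass_propagate_shapes_on_graph(graph)
    print(next(graph.outputs()).type().symbolic_sizes())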