Propagate symbolic dimensions through idioms like x.view(y.size()) (#61975)
author Elias Ellison <eellison@fb.com>
Tue, 17 Aug 2021 18:21:50 +0000 (11:21 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 18:25:22 +0000 (11:25 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61975

Propagate symbolic dimensions through size calls. We do this by associating a `SymbolicShape` with integer-list inputs, looking through their constructors for `x.size(1)` or `x.size()` nodes (and constants).
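
For illustration, a minimal sketch of the idiom this handles, driven through the same internal testing hooks used in the new test below (`torch._C._new_symbolic_shape_symbol` and `torch._C._jit_pass_propagate_shapes_on_graph`); the concrete sizes `[5, 8]` are arbitrary:

```python
import torch

@torch.jit.script
def foo(x, y):
    # y.size() produces an int[] whose entries are now tied back to y's
    # (possibly symbolic) dimensions, so the output inherits them via view().
    return x.view(y.size())

# Give y one symbolic dimension alongside two concrete ones, then run the
# shape propagation pass over the scripted graph.
inputs = list(foo.graph.inputs())
sym = torch._C._new_symbolic_shape_symbol()
inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym]))
torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)

# The output's symbolic sizes now mirror y's: [5, 8, sym]
print(next(foo.graph.outputs()).type().symbolic_sizes())
```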

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D30196948

Pulled By: eellison

fbshipit-source-id: 377fc1d2f6d396c52dc0e87fa814b15720f1414e

test/jit/test_symbolic_shape_analysis.py
torch/csrc/jit/passes/symbolic_shape_analysis.cpp

test/jit/test_symbolic_shape_analysis.py
index 5205845..4e42146 100644 (file)
@@ -116,3 +116,20 @@ class TestSymbolicShapeAnalysis(JitTestCase):
             inputs[1].setType(inputs[1].type().with_sizes(size_2))
             torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
             self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8])
+
+    def test_size_and_sizes(self):
+        @torch.jit.script
+        def foo(x, y):
+            return x.view(y.size(0), 8, y.size(-1))
+
+        @torch.jit.script
+        def foo2(x, y):
+            return x.view(y.size())
+
+        for graph in [foo.graph, foo2.graph]:
+            inputs = list(graph.inputs())
+            sym1 = torch._C._new_symbolic_shape_symbol()
+
+            inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
+            torch._C._jit_pass_propagate_shapes_on_graph(graph)
+            self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1])
torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index 2ff0e96..fa76058 100644 (file)
@@ -103,15 +103,60 @@ struct SymbolicShapeAnalyzer {
         }
         // we can't optimize a tensor without fixed rank
         if (symbolic_shapes.rank()) {
-          node_input_tensor_indices.push_back(i);
+          node_symbolic_input_indices.emplace_back(i, symbolic_shapes);
         }
       } else if (
           type->cast<ListType>() &&
           type->cast<ListType>()->getElementType()->cast<TensorType>()) {
         TORCH_INTERNAL_ASSERT(false); // not handled yet
-      } else {
-        if (auto ival = toIValue(node_->input(i))) {
-          replaceWithIValue(graph_->inputs().at(i), *ival);
+      } else if (auto ival = toIValue(node_->input(i))) {
+        replaceWithIValue(graph_->inputs().at(i), *ival);
+      } else if (
+          type->cast<ListType>() &&
+          type->cast<ListType>()->getElementType()->cast<IntType>()) {
+        if (node_->input(i)->node()->kind() == prim::ListConstruct &&
+            node_->input(i)->uses().size() == 1) {
+          // It is very common in graphs to see patterns like:
+          // z = x.view(y.size())
+          // or:
+          // z = x.view(1, 10, y.size(0), y.size(1))
+          // We want to propagate symbolic dimensions and concrete sizes
+          // from y to z. To do this we try to associate symbolic dimensions
+          // or concrete sizes with the integer list inputs whose
+          // constructor is built from constants, y.size(), or y.size(0).
+          auto list_construct = node_->input(i)->node();
+          std::vector<c10::optional<int64_t>> shape;
+          for (Value* v : list_construct->inputs()) {
+            if (auto constant = constant_as<int64_t>(v)) {
+              shape.emplace_back(*constant);
+            } else if (
+                v->node()->kind() == aten::size &&
+                v->node()->input(1)->node()->kind() == prim::Constant) {
+              auto tt = v->node()->input(0)->type()->expect<TensorType>();
+              auto ss = tt->symbolic_sizes();
+              if (!ss.rank()) {
+                shape.emplace_back(c10::nullopt);
+                continue;
+              }
+              auto norm_index = normIndex(
+                  *constant_as<int64_t>(v->node()->input(1)), *ss.rank());
+              if (!norm_index) {
+                shape.emplace_back(c10::nullopt);
+                continue;
+              }
+              shape.emplace_back(ss[*norm_index].value());
+            } else {
+              shape.emplace_back(c10::nullopt);
+            }
+          }
+          node_symbolic_input_indices.emplace_back(
+              i, c10::SymbolicShape(shape));
+        } else if (
+            node_->input(i)->node()->kind() == aten::size &&
+            node_->input(i)->uses().size() == 1) {
+          auto ten_inp = node_->input(i)->node()->input();
+          auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
+          node_symbolic_input_indices.emplace_back(i, ss);
         }
       }
     }
@@ -175,13 +220,12 @@ struct SymbolicShapeAnalyzer {
 
     std::unordered_map<int64_t, std::vector<Value*>> symbolic_shape_map;
 
-    for (auto tensor_index : node_input_tensor_indices) {
-      auto tensor_value = node_->inputs().at(tensor_index);
-      auto tensor_shape =
-          tensor_value->type()->expect<TensorType>()->symbolic_sizes();
+    for (const auto& index_symbolic_shape : node_symbolic_input_indices) {
+      auto index = index_symbolic_shape.first;
+      auto tensor_shape = index_symbolic_shape.second;
       TORCH_INTERNAL_ASSERT(tensor_shape.rank().has_value());
 
-      for (const auto& use : graph_->inputs().at(tensor_index)->uses()) {
+      for (const auto& use : graph_->inputs().at(index)->uses()) {
         // TODO: either decompose composite ops like slice or add handling here
         switch (use.user->kind()) {
           case aten::len: {
@@ -285,7 +329,8 @@ struct SymbolicShapeAnalyzer {
   // and non-fixed dimensions we cannot reason about at all
   // TODO: might be cleaner to store as a pair of index -> symbolic shape
   // but there were weird lifetime issues
-  std::vector<int64_t> node_input_tensor_indices;
+  std::vector<std::pair<int64_t, c10::SymbolicShape>>
+      node_symbolic_input_indices;
   std::shared_ptr<Graph> graph_;
   Node* node_;
 };