[JIT] Add aten::slice optimization (#63049)
author Mike Iovine <mikeiovine@fb.com>
Fri, 27 Aug 2021 17:10:48 +0000 (10:10 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 17:12:45 +0000 (10:12 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63049

Given a graph produced from a function like this:
```
def foo():
    li = [1, 2, 3, 4, 5, 6]
    return li[0:2]
```
This pass produces a graph like this:
```
def foo():
    li = [1, 2]
    return li
```

These changes are mostly adapted from https://github.com/pytorch/pytorch/pull/62297/

Test Plan: `buck test //caffe2/test:jit -- TestPeephole`

Reviewed By: eellison

Differential Revision: D30231044

fbshipit-source-id: d12ee39f68289a574f533041a5adb38b2f000dd5

test/jit/test_peephole.py
torch/csrc/jit/passes/peephole_list_idioms.cpp
torch/csrc/jit/passes/peephole_list_idioms.h

index 23de448..ecb4a06 100644 (file)
@@ -2,7 +2,7 @@ import torch
 from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
 from torch import nn
 from torch.testing import FileCheck
-from typing import List
+from typing import Callable, List
 
 import unittest
 
@@ -721,3 +721,75 @@ class TestPeephole(JitTestCase):
         self.run_pass("peephole", foo.graph)
         FileCheck().check("DictConstruct").check("len").run(foo.graph)
         self.assertEqual(foo(), 1)
+
+    def test_peephole_slice_all_three_args(self):
+        # Start, stop and step are all constants here, so after the peephole
+        # pass no aten::slice node may remain in the graph.
+        def foo(x: int):
+            return [1, 2, x, 4, 5, 6, 7][-5:6:2]
+
+        scripted_graph = torch.jit.script(foo).graph
+        self.run_pass("peephole", scripted_graph)
+        FileCheck().check_not("aten::slice").run(scripted_graph)
+        # The scripted function must still agree with eager Python.
+        self.checkScript(foo, (3, ))
+
+    def test_peephole_slice_one_empty_arg(self):
+        # Each case below omits exactly one of start/stop/step; the slice is
+        # still expected to be removed by the peephole pass.
+        def check_helper(fn: Callable[[int], List[int]]) -> None:
+            # Script `fn`, run the pass, require that aten::slice is gone, and
+            # compare the scripted result against eager Python for x == 3.
+            graph = torch.jit.script(fn).graph
+            self.run_pass("peephole", graph)
+            FileCheck().check_not("aten::slice").run(graph)
+            self.checkScript(fn, (3, ))
+
+        # Stop omitted.
+        def foo(x: int):
+            return [1, 2, x, 4, 5, 6, 7][1::2]
+
+        check_helper(foo)
+
+        # Start omitted.
+        def foo(x: int):
+            return [1, 2, x, 4, 5, 6, 7][:5:3]
+
+        check_helper(foo)
+
+        # Step omitted.
+        def foo(x: int):
+            return [1, 2, x, 4, 5, 6, 7][0:4]
+
+        check_helper(foo)
+
+    def test_peephole_slice_two_empty_args(self):
+        # Each case below supplies only one of start/stop/step; the slice is
+        # still expected to be removed by the peephole pass.
+        def check_helper(fn: Callable[[int], List[int]]) -> None:
+            # Script `fn`, run the pass, require that aten::slice is gone, and
+            # compare the scripted result against eager Python for x == 3.
+            graph = torch.jit.script(fn).graph
+            self.run_pass("peephole", graph)
+            FileCheck().check_not("aten::slice").run(graph)
+            self.checkScript(fn, (3, ))
+
+        # Only step given.
+        def foo(x: int):
+            return [1, 2, x, 4, 5, 6, 7][::2]
+
+        check_helper(foo)
+
+        # Only stop given.
+        def foo(x: int):
+            return [1, 2, x, 4, 5, 6, 7][:5]
+
+        check_helper(foo)
+
+        # Only start given.
+        def foo(x: int):
+            return [1, 2, x, 4, 5, 6, 7][1:]
+
+        check_helper(foo)
+
+    def test_peephole_slice_optimization_not_applied_list_modified(self):
+        # The list is written to before it is sliced, so the slice must be
+        # left in the graph by the peephole pass.
+        @torch.jit.script
+        def foo():
+            li = [1, 2, 3, 4, 5, 6, 7]
+            li[0] = 0
+            return li[2:5]
+
+        scripted_graph = foo.graph
+        self.run_pass("peephole", scripted_graph)
+        FileCheck().check("aten::slice").run(scripted_graph)
+
+    def test_peephole_slice_optimization_not_applied_non_const_args(self):
+        # Start and stop are graph inputs rather than constants, so the slice
+        # must be left in the graph by the peephole pass.
+        @torch.jit.script
+        def foo(x: int, y: int):
+            li = [1, 2, 3, 4, 5, 6, 7]
+            return li[x:y]
+
+        scripted_graph = foo.graph
+        self.run_pass("peephole", scripted_graph)
+        FileCheck().check("aten::slice").run(scripted_graph)
index f33f388..ec3d249 100644 (file)
@@ -7,7 +7,9 @@
 #include <torch/csrc/jit/passes/peephole_list_idioms.h>
 #include <torch/csrc/jit/passes/value_refinement_utils.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
+#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
 #include <torch/csrc/utils/memory.h>
+#include <limits>
 
 namespace torch {
 namespace jit {
@@ -57,7 +59,7 @@ struct ListLenRefiner {
       }
 
       auto first_input = n->input(0);
-      if (first_input->type()->cast<ListType>() &&
+      if (first_input->type()->castRaw<ListType>() &&
           !mutated_lists_.count(first_input)) {
         if (!li_with_len_use.count(first_input)) {
           li_with_len_use.insert(first_input);
@@ -172,7 +174,7 @@ struct PeepholeOptimizeListIdiomsImpl {
 
  private:
   void checkForMutatedList(Value* v) {
-    if (v->type()->cast<ListType>() && aliasDb_->hasWriters(v)) {
+    if (v->type()->castRaw<ListType>() && aliasDb_->hasWriters(v)) {
       mutated_lists_.insert(v);
     }
   }
@@ -191,6 +193,43 @@ struct PeepholeOptimizeListIdiomsImpl {
     }
   }
 
+  // Replaces the output of `slice_node` (an aten::slice whose list input is
+  // produced by `list_construct_node`) with a new prim::ListConstruct that
+  // holds only the selected elements. Applied only when start, end and step
+  // are all constants.
+  //
+  // Returns true if the graph was modified. The slice node itself is not
+  // removed here; its output is simply left without uses.
+  bool optimizeSlice(Node* slice_node, Node* list_construct_node) {
+    auto start_val = toIValue(slice_node->input(1));
+    auto end_val = toIValue(slice_node->input(2));
+    auto step_val = toIValue(slice_node->input(3));
+
+    // All args must be constant to apply this optimization.
+    if (start_val == c10::nullopt || end_val == c10::nullopt ||
+        step_val == c10::nullopt) {
+      return false;
+    }
+
+    // A constant that is not an int (presumably None, i.e. the bound was
+    // omitted) is passed to slice_indices_adjust as INT64_MAX; a non-int
+    // step falls back to 1.
+    int64_t start = start_val->isInt() ? start_val->to<int64_t>()
+                                       : std::numeric_limits<int64_t>::max();
+    int64_t end = end_val->isInt() ? end_val->to<int64_t>()
+                                   : std::numeric_limits<int64_t>::max();
+    int64_t step = step_val->isInt() ? step_val->to<int64_t>() : 1;
+
+    // NOTE(review): slice_indices_adjust is expected to reject a zero step.
+    // Leave the node untouched so that the error is reported by aten::slice
+    // if and when it actually executes, rather than by this pass.
+    if (step == 0) {
+      return false;
+    }
+
+    const int64_t list_size =
+        static_cast<int64_t>(list_construct_node->inputs().size());
+    const int64_t num_values =
+        slice_indices_adjust(list_size, &start, &end, step);
+
+    WithInsertPoint guard(slice_node);
+    auto slice_list_construct =
+        graph_->insertNode(graph_->create(prim::ListConstruct));
+    slice_list_construct->output()->setType(slice_node->output()->type());
+    // Walk the selected positions with signed arithmetic: `step` may be
+    // negative, so the index has to be able to count downwards.
+    int64_t index = start;
+    for (int64_t j = 0; j < num_values; ++j) {
+      slice_list_construct->addInput(
+          list_construct_node->input(static_cast<size_t>(index)));
+      index += step;
+    }
+
+    slice_node->output()->replaceAllUsesWith(slice_list_construct->output());
+    // Keep the mutation bookkeeping in sync: if the slice result was recorded
+    // as mutated, the list that replaces it must be recorded as well.
+    if (mutated_lists_.count(slice_node->output())) {
+      mutated_lists_.insert(slice_list_construct->output());
+    }
+
+    return true;
+  }
+
   bool runBlock(Block* block) {
     bool changed = false;
     for (Node* node : block->nodes()) {
@@ -200,7 +239,7 @@ struct PeepholeOptimizeListIdiomsImpl {
 
       // only optimizing list ops
       if (node->inputs().size() == 0 ||
-          !node->input(0)->type()->cast<ListType>()) {
+          !node->input(0)->type()->castRaw<ListType>()) {
         continue;
       }
 
@@ -211,36 +250,33 @@ struct PeepholeOptimizeListIdiomsImpl {
         continue;
       }
 
+      auto list_creation_node = first_input->node();
+      if (list_creation_node->kind() != prim::ListConstruct) {
+        continue;
+      }
+
       if (node->kind() == aten::len) {
-        if (first_input->node()->kind() == prim::ListConstruct) {
-          WithInsertPoint guard(node);
-          node->output()->replaceAllUsesWith(graph_->insertConstant(
-              static_cast<int64_t>(first_input->node()->inputs().size())));
-          changed = true;
-        }
+        WithInsertPoint guard(node);
+        node->output()->replaceAllUsesWith(graph_->insertConstant(
+            static_cast<int64_t>(first_input->node()->inputs().size())));
+        changed = true;
       } else if (node->kind() == aten::__getitem__) {
-        auto list_creation_node = first_input->node();
-        if (list_creation_node->kind() == prim::ListConstruct) {
-          if (auto index = toIValue(node->input(1))) {
-            size_t list_size = list_creation_node->inputs().size();
-            if (auto norm_index = normalizeIndex(index->toInt(), list_size)) {
-              node->output()->replaceAllUsesWith(
-                  list_creation_node->input(*norm_index));
-              changed = true;
-            }
+        if (auto index = toIValue(node->input(1))) {
+          size_t list_size = list_creation_node->inputs().size();
+          if (auto norm_index = normalizeIndex(index->toInt(), list_size)) {
+            node->output()->replaceAllUsesWith(
+                list_creation_node->input(*norm_index));
+            changed = true;
           }
         }
       } else if (node->kind() == prim::ListUnpack) {
-        auto list_creation_node = first_input->node();
-        if (list_creation_node->kind() == prim::ListConstruct) {
-          // if sizes are unequal it's a runtime error
-          if (list_creation_node->inputs().size() != node->outputs().size()) {
-            continue;
-          }
-          for (size_t i = 0; i < node->outputs().size(); ++i) {
-            node->output(i)->replaceAllUsesWith(list_creation_node->input(i));
-            changed = true;
-          }
+        // if sizes are unequal it's a runtime error
+        if (list_creation_node->inputs().size() != node->outputs().size()) {
+          continue;
+        }
+        for (size_t i = 0; i < node->outputs().size(); ++i) {
+          node->output(i)->replaceAllUsesWith(list_creation_node->input(i));
+          changed = true;
         }
       } else if (node->kind() == aten::add) {
         if (node->inputs().size() != 2) {
@@ -251,8 +287,7 @@ struct PeepholeOptimizeListIdiomsImpl {
         if (mutated_lists_.count(second_input)) {
           continue;
         }
-        if (first_input->node()->kind() != prim::ListConstruct ||
-            second_input->node()->kind() != prim::ListConstruct) {
+        if (second_input->node()->kind() != prim::ListConstruct) {
           continue;
         }
         WithInsertPoint guard(node);
@@ -270,6 +305,8 @@ struct PeepholeOptimizeListIdiomsImpl {
           mutated_lists_.insert(list_construct->output());
         }
         changed = true;
+      } else if (node->kind() == aten::slice) {
+        changed |= optimizeSlice(node, first_input->node());
       }
     }
     return changed;
index c8add48..d20df95 100644 (file)
@@ -51,6 +51,14 @@ namespace jit {
 //
 // This is only applied to lists that are not modified.
 //
+// 5. Slice
+// Given a function like this:
+//     def foo():
+//         return [1, 2, 3, 4, 5][0:2]
+// This pass produces (after dead code elimination):
+//     def foo():
+//         return [1, 2]
+//
 // Currently this is invoked as part of PeepholeOptimize
 // return true if graph is modified.
 // If `refine_list_len` is true will attempt to refine the len of lists through