from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
from torch import nn
from torch.testing import FileCheck
-from typing import List
+from typing import Callable, List
import unittest
self.run_pass("peephole", foo.graph)
FileCheck().check("DictConstruct").check("len").run(foo.graph)
self.assertEqual(foo(), 1)
+
+ def test_peephole_slice_all_three_args(self):
+ def foo(x: int):
+ return [1, 2, x, 4, 5, 6, 7][-5:6:2]
+
+ graph = torch.jit.script(foo).graph
+ self.run_pass("peephole", graph)
+ FileCheck().check_not("aten::slice").run(graph)
+ self.checkScript(foo, (3, ))
+
+ def test_peephole_slice_one_empty_arg(self):
+ def check_helper(fn: Callable[[int], None]) -> None:
+ graph = torch.jit.script(fn).graph
+ self.run_pass("peephole", graph)
+ FileCheck().check_not("aten::slice").run(graph)
+ self.checkScript(fn, (3, ))
+
+ def foo(x: int):
+ return [1, 2, x, 4, 5, 6, 7][1::2]
+
+ check_helper(foo)
+
+ def foo(x: int):
+ return [1, 2, x, 4, 5, 6, 7][:5:3]
+
+ check_helper(foo)
+
+ def foo(x: int):
+ return [1, 2, x, 4, 5, 6, 7][0:4]
+
+ check_helper(foo)
+
+ def test_peephole_slice_two_empty_args(self):
+ def check_helper(fn: Callable[[int], None]) -> None:
+ graph = torch.jit.script(fn).graph
+ self.run_pass("peephole", graph)
+ FileCheck().check_not("aten::slice").run(graph)
+ self.checkScript(fn, (3, ))
+
+ def foo(x: int):
+ return [1, 2, x, 4, 5, 6, 7][::2]
+
+ check_helper(foo)
+
+ def foo(x: int):
+ return [1, 2, x, 4, 5, 6, 7][:5]
+
+ check_helper(foo)
+
+ def foo(x: int):
+ return [1, 2, x, 4, 5, 6, 7][1:]
+
+ check_helper(foo)
+
+ def test_peephole_slice_optimization_not_applied_list_modified(self):
+ @torch.jit.script
+ def foo():
+ li = [1, 2, 3, 4, 5, 6, 7]
+ li[0] = 0
+ return li[2:5]
+
+ self.run_pass("peephole", foo.graph)
+ FileCheck().check("aten::slice").run(foo.graph)
+
+ def test_peephole_slice_optimization_not_applied_non_const_args(self):
+ @torch.jit.script
+ def foo(x: int, y: int):
+ li = [1, 2, 3, 4, 5, 6, 7]
+ return li[x:y]
+
+ self.run_pass("peephole", foo.graph)
+ FileCheck().check("aten::slice").run(foo.graph)
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/value_refinement_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
+#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <torch/csrc/utils/memory.h>
+#include <limits>
namespace torch {
namespace jit {
}
auto first_input = n->input(0);
- if (first_input->type()->cast<ListType>() &&
+ if (first_input->type()->castRaw<ListType>() &&
!mutated_lists_.count(first_input)) {
if (!li_with_len_use.count(first_input)) {
li_with_len_use.insert(first_input);
private:
void checkForMutatedList(Value* v) {
- if (v->type()->cast<ListType>() && aliasDb_->hasWriters(v)) {
+ if (v->type()->castRaw<ListType>() && aliasDb_->hasWriters(v)) {
mutated_lists_.insert(v);
}
}
}
}
+ bool optimizeSlice(Node* slice_node, Node* list_construct_node) {
+ auto start_val = toIValue(slice_node->input(1));
+ auto end_val = toIValue(slice_node->input(2));
+ auto step_val = toIValue(slice_node->input(3));
+
+ // All args must be constant to apply this optimization.
+ if (start_val == c10::nullopt || end_val == c10::nullopt ||
+ step_val == c10::nullopt) {
+ return false;
+ }
+
+ int64_t start = start_val->isInt() ? start_val->to<int64_t>()
+ : std::numeric_limits<int64_t>::max();
+ int64_t end = end_val->isInt() ? end_val->to<int64_t>()
+ : std::numeric_limits<int64_t>::max();
+ int64_t step = step_val->isInt() ? step_val->to<int64_t>() : 1;
+
+ size_t list_size = list_construct_node->inputs().size();
+ size_t num_values = slice_indices_adjust(list_size, &start, &end, step);
+
+ WithInsertPoint guard(slice_node);
+ auto slice_list_construct =
+ graph_->insertNode(graph_->create(prim::ListConstruct));
+ slice_list_construct->output()->setType(slice_node->output()->type());
+ for (size_t i = start, j = 0; j < num_values; ++j) {
+ slice_list_construct->addInput(list_construct_node->input(i));
+ i += step;
+ }
+
+ slice_node->output()->replaceAllUsesWith(slice_list_construct->output());
+ if (mutated_lists_.count(slice_node->output())) {
+ mutated_lists_.insert(slice_list_construct->output());
+ }
+
+ return true;
+ }
+
bool runBlock(Block* block) {
bool changed = false;
for (Node* node : block->nodes()) {
// only optimizing list ops
if (node->inputs().size() == 0 ||
- !node->input(0)->type()->cast<ListType>()) {
+ !node->input(0)->type()->castRaw<ListType>()) {
continue;
}
continue;
}
+ auto list_creation_node = first_input->node();
+ if (list_creation_node->kind() != prim::ListConstruct) {
+ continue;
+ }
+
if (node->kind() == aten::len) {
- if (first_input->node()->kind() == prim::ListConstruct) {
- WithInsertPoint guard(node);
- node->output()->replaceAllUsesWith(graph_->insertConstant(
- static_cast<int64_t>(first_input->node()->inputs().size())));
- changed = true;
- }
+ WithInsertPoint guard(node);
+ node->output()->replaceAllUsesWith(graph_->insertConstant(
+ static_cast<int64_t>(first_input->node()->inputs().size())));
+ changed = true;
} else if (node->kind() == aten::__getitem__) {
- auto list_creation_node = first_input->node();
- if (list_creation_node->kind() == prim::ListConstruct) {
- if (auto index = toIValue(node->input(1))) {
- size_t list_size = list_creation_node->inputs().size();
- if (auto norm_index = normalizeIndex(index->toInt(), list_size)) {
- node->output()->replaceAllUsesWith(
- list_creation_node->input(*norm_index));
- changed = true;
- }
+ if (auto index = toIValue(node->input(1))) {
+ size_t list_size = list_creation_node->inputs().size();
+ if (auto norm_index = normalizeIndex(index->toInt(), list_size)) {
+ node->output()->replaceAllUsesWith(
+ list_creation_node->input(*norm_index));
+ changed = true;
}
}
} else if (node->kind() == prim::ListUnpack) {
- auto list_creation_node = first_input->node();
- if (list_creation_node->kind() == prim::ListConstruct) {
- // if sizes are unequal it's a runtime error
- if (list_creation_node->inputs().size() != node->outputs().size()) {
- continue;
- }
- for (size_t i = 0; i < node->outputs().size(); ++i) {
- node->output(i)->replaceAllUsesWith(list_creation_node->input(i));
- changed = true;
- }
+ // if sizes are unequal it's a runtime error
+ if (list_creation_node->inputs().size() != node->outputs().size()) {
+ continue;
+ }
+ for (size_t i = 0; i < node->outputs().size(); ++i) {
+ node->output(i)->replaceAllUsesWith(list_creation_node->input(i));
+ changed = true;
}
} else if (node->kind() == aten::add) {
if (node->inputs().size() != 2) {
if (mutated_lists_.count(second_input)) {
continue;
}
- if (first_input->node()->kind() != prim::ListConstruct ||
- second_input->node()->kind() != prim::ListConstruct) {
+ if (second_input->node()->kind() != prim::ListConstruct) {
continue;
}
WithInsertPoint guard(node);
mutated_lists_.insert(list_construct->output());
}
changed = true;
+ } else if (node->kind() == aten::slice) {
+ changed |= optimizeSlice(node, first_input->node());
}
}
return changed;