From babd4499783abc699faf36f3a72a9fc491e0e572 Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Fri, 27 Aug 2021 10:10:48 -0700 Subject: [PATCH] [JIT] Add aten::slice optimization (#63049) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63049 Given a graph produced from a function like this: ``` def foo(): li = [1, 2, 3, 4, 5, 6] return li[0:2] ``` This pass produces a graph like this: ``` def foo(): li = [1, 2] return li ``` These changes are mostly adapted from https://github.com/pytorch/pytorch/pull/62297/ Test Plan: `buck test //caffe2/test:jit -- TestPeephole` Reviewed By: eellison Differential Revision: D30231044 fbshipit-source-id: d12ee39f68289a574f533041a5adb38b2f000dd5 --- test/jit/test_peephole.py | 74 +++++++++++++++++++- torch/csrc/jit/passes/peephole_list_idioms.cpp | 97 ++++++++++++++++++-------- torch/csrc/jit/passes/peephole_list_idioms.h | 8 +++ 3 files changed, 148 insertions(+), 31 deletions(-) diff --git a/test/jit/test_peephole.py b/test/jit/test_peephole.py index 23de448..ecb4a06 100644 --- a/test/jit/test_peephole.py +++ b/test/jit/test_peephole.py @@ -2,7 +2,7 @@ import torch from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything from torch import nn from torch.testing import FileCheck -from typing import List +from typing import Callable, List import unittest @@ -721,3 +721,75 @@ class TestPeephole(JitTestCase): self.run_pass("peephole", foo.graph) FileCheck().check("DictConstruct").check("len").run(foo.graph) self.assertEqual(foo(), 1) + + def test_peephole_slice_all_three_args(self): + def foo(x: int): + return [1, 2, x, 4, 5, 6, 7][-5:6:2] + + graph = torch.jit.script(foo).graph + self.run_pass("peephole", graph) + FileCheck().check_not("aten::slice").run(graph) + self.checkScript(foo, (3, )) + + def test_peephole_slice_one_empty_arg(self): + def check_helper(fn: Callable[[int], None]) -> None: + graph = torch.jit.script(fn).graph + self.run_pass("peephole", graph) + 
FileCheck().check_not("aten::slice").run(graph) + self.checkScript(fn, (3, )) + + def foo(x: int): + return [1, 2, x, 4, 5, 6, 7][1::2] + + check_helper(foo) + + def foo(x: int): + return [1, 2, x, 4, 5, 6, 7][:5:3] + + check_helper(foo) + + def foo(x: int): + return [1, 2, x, 4, 5, 6, 7][0:4] + + check_helper(foo) + + def test_peephole_slice_two_empty_args(self): + def check_helper(fn: Callable[[int], None]) -> None: + graph = torch.jit.script(fn).graph + self.run_pass("peephole", graph) + FileCheck().check_not("aten::slice").run(graph) + self.checkScript(fn, (3, )) + + def foo(x: int): + return [1, 2, x, 4, 5, 6, 7][::2] + + check_helper(foo) + + def foo(x: int): + return [1, 2, x, 4, 5, 6, 7][:5] + + check_helper(foo) + + def foo(x: int): + return [1, 2, x, 4, 5, 6, 7][1:] + + check_helper(foo) + + def test_peephole_slice_optimization_not_applied_list_modified(self): + @torch.jit.script + def foo(): + li = [1, 2, 3, 4, 5, 6, 7] + li[0] = 0 + return li[2:5] + + self.run_pass("peephole", foo.graph) + FileCheck().check("aten::slice").run(foo.graph) + + def test_peephole_slice_optimization_not_applied_non_const_args(self): + @torch.jit.script + def foo(x: int, y: int): + li = [1, 2, 3, 4, 5, 6, 7] + return li[x:y] + + self.run_pass("peephole", foo.graph) + FileCheck().check("aten::slice").run(foo.graph) diff --git a/torch/csrc/jit/passes/peephole_list_idioms.cpp b/torch/csrc/jit/passes/peephole_list_idioms.cpp index f33f388..ec3d249 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.cpp +++ b/torch/csrc/jit/passes/peephole_list_idioms.cpp @@ -7,7 +7,9 @@ #include #include #include +#include #include +#include namespace torch { namespace jit { @@ -57,7 +59,7 @@ struct ListLenRefiner { } auto first_input = n->input(0); - if (first_input->type()->cast() && + if (first_input->type()->castRaw() && !mutated_lists_.count(first_input)) { if (!li_with_len_use.count(first_input)) { li_with_len_use.insert(first_input); @@ -172,7 +174,7 @@ struct 
PeepholeOptimizeListIdiomsImpl { private: void checkForMutatedList(Value* v) { - if (v->type()->cast() && aliasDb_->hasWriters(v)) { + if (v->type()->castRaw() && aliasDb_->hasWriters(v)) { mutated_lists_.insert(v); } } @@ -191,6 +193,43 @@ struct PeepholeOptimizeListIdiomsImpl { } } + bool optimizeSlice(Node* slice_node, Node* list_construct_node) { + auto start_val = toIValue(slice_node->input(1)); + auto end_val = toIValue(slice_node->input(2)); + auto step_val = toIValue(slice_node->input(3)); + + // All args must be constant to apply this optimization. + if (start_val == c10::nullopt || end_val == c10::nullopt || + step_val == c10::nullopt) { + return false; + } + + int64_t start = start_val->isInt() ? start_val->to() + : std::numeric_limits::max(); + int64_t end = end_val->isInt() ? end_val->to() + : std::numeric_limits::max(); + int64_t step = step_val->isInt() ? step_val->to() : 1; + + size_t list_size = list_construct_node->inputs().size(); + size_t num_values = slice_indices_adjust(list_size, &start, &end, step); + + WithInsertPoint guard(slice_node); + auto slice_list_construct = + graph_->insertNode(graph_->create(prim::ListConstruct)); + slice_list_construct->output()->setType(slice_node->output()->type()); + for (size_t i = start, j = 0; j < num_values; ++j) { + slice_list_construct->addInput(list_construct_node->input(i)); + i += step; + } + + slice_node->output()->replaceAllUsesWith(slice_list_construct->output()); + if (mutated_lists_.count(slice_node->output())) { + mutated_lists_.insert(slice_list_construct->output()); + } + + return true; + } + bool runBlock(Block* block) { bool changed = false; for (Node* node : block->nodes()) { @@ -200,7 +239,7 @@ struct PeepholeOptimizeListIdiomsImpl { // only optimizing list ops if (node->inputs().size() == 0 || - !node->input(0)->type()->cast()) { + !node->input(0)->type()->castRaw()) { continue; } @@ -211,36 +250,33 @@ struct PeepholeOptimizeListIdiomsImpl { continue; } + auto list_creation_node = 
first_input->node(); + if (list_creation_node->kind() != prim::ListConstruct) { + continue; + } + if (node->kind() == aten::len) { - if (first_input->node()->kind() == prim::ListConstruct) { - WithInsertPoint guard(node); - node->output()->replaceAllUsesWith(graph_->insertConstant( - static_cast(first_input->node()->inputs().size()))); - changed = true; - } + WithInsertPoint guard(node); + node->output()->replaceAllUsesWith(graph_->insertConstant( + static_cast(first_input->node()->inputs().size()))); + changed = true; } else if (node->kind() == aten::__getitem__) { - auto list_creation_node = first_input->node(); - if (list_creation_node->kind() == prim::ListConstruct) { - if (auto index = toIValue(node->input(1))) { - size_t list_size = list_creation_node->inputs().size(); - if (auto norm_index = normalizeIndex(index->toInt(), list_size)) { - node->output()->replaceAllUsesWith( - list_creation_node->input(*norm_index)); - changed = true; - } + if (auto index = toIValue(node->input(1))) { + size_t list_size = list_creation_node->inputs().size(); + if (auto norm_index = normalizeIndex(index->toInt(), list_size)) { + node->output()->replaceAllUsesWith( + list_creation_node->input(*norm_index)); + changed = true; } } } else if (node->kind() == prim::ListUnpack) { - auto list_creation_node = first_input->node(); - if (list_creation_node->kind() == prim::ListConstruct) { - // if sizes are unequal it's a runtime error - if (list_creation_node->inputs().size() != node->outputs().size()) { - continue; - } - for (size_t i = 0; i < node->outputs().size(); ++i) { - node->output(i)->replaceAllUsesWith(list_creation_node->input(i)); - changed = true; - } + // if sizes are unequal it's a runtime error + if (list_creation_node->inputs().size() != node->outputs().size()) { + continue; + } + for (size_t i = 0; i < node->outputs().size(); ++i) { + node->output(i)->replaceAllUsesWith(list_creation_node->input(i)); + changed = true; } } else if (node->kind() == aten::add) { if 
(node->inputs().size() != 2) { @@ -251,8 +287,7 @@ struct PeepholeOptimizeListIdiomsImpl { if (mutated_lists_.count(second_input)) { continue; } - if (first_input->node()->kind() != prim::ListConstruct || - second_input->node()->kind() != prim::ListConstruct) { + if (second_input->node()->kind() != prim::ListConstruct) { continue; } WithInsertPoint guard(node); @@ -270,6 +305,8 @@ struct PeepholeOptimizeListIdiomsImpl { mutated_lists_.insert(list_construct->output()); } changed = true; + } else if (node->kind() == aten::slice) { + changed |= optimizeSlice(node, first_input->node()); } } return changed; diff --git a/torch/csrc/jit/passes/peephole_list_idioms.h b/torch/csrc/jit/passes/peephole_list_idioms.h index c8add48..d20df95 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.h +++ b/torch/csrc/jit/passes/peephole_list_idioms.h @@ -51,6 +51,14 @@ namespace jit { // // This is only applied to lists that are not modified. // +// 5. Slice +// Given a function like this: +// def foo(): +// return [1, 2, 3, 4, 5][0:2] +// This pass produces (after dead code elimination): +// def foo(): +// return [1, 2] +// // Currently this is invoked as part of PeepholeOptimize // return true if graph is modified. // If `refine_list_len` is true will attempt to refine the len of lists through -- 2.7.4