From 000e3a08817edec5681c8b10078f61ff185ce090 Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Fri, 13 Aug 2021 10:18:03 -0700 Subject: [PATCH] [Static Runtime] Add pass to eliminate __getitem__/DictConstruct calls (#62429) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62429 Introduce a new pass to eliminate calls to `prim::DictConstruct/aten::__getitem__`. Given a graph like this: ``` %2 : Dict = prim::DictConstruct(%key, %value) %3 : Tensor = aten::__getitem__(%2, %key) %4 : Tensor = op(%3) ``` This pass produces a graph like this (after dead code elimination): ``` %4 : Tensor = op(%value) ``` This optimization is applied in the static runtime. Test Plan: `buck test //caffe2/test:jit -- TestPeephole` **local.forward performance summary** About 3% runtime benefit. All `DictConstruct` calls optimized out, `__getitem__` calls reduced significantly (~50% of them are cut out) P438354810 **local_request_only.forward performance summary** About 14% runtime benefit. Again, all `DictConstruct` calls optimized out, 50% `__getitem__` calls removed. P438359742 There is some variance with runtime measurements, so take these numbers with a grain of salt. 
Also note that the benefit does not exist in the shrunk model since there are no `DictConstruct` calls Reviewed By: hlu1 Differential Revision: D29995087 fbshipit-source-id: f376376a46ff808115afd2d60446e5db8f6f752f --- test/jit/test_peephole.py | 135 ++++++++++++ tools/build_variables.bzl | 1 + torch/csrc/jit/passes/peephole.cpp | 2 + torch/csrc/jit/passes/peephole_dict_idioms.cpp | 272 +++++++++++++++++++++++++ torch/csrc/jit/passes/peephole_dict_idioms.h | 38 ++++ torch/csrc/jit/passes/peephole_list_idioms.h | 47 ++++- 6 files changed, 494 insertions(+), 1 deletion(-) create mode 100644 torch/csrc/jit/passes/peephole_dict_idioms.cpp create mode 100644 torch/csrc/jit/passes/peephole_dict_idioms.h diff --git a/test/jit/test_peephole.py b/test/jit/test_peephole.py index d2f4d32..8ba1714 100644 --- a/test/jit/test_peephole.py +++ b/test/jit/test_peephole.py @@ -571,3 +571,138 @@ class TestPeephole(JitTestCase): FileCheck().check("graph").check("):").check_next("aten::Int") \ .check_next("ListConstruct").check_next("return").run(foo.graph) self.assertEqual(foo(0, 1, 2, 3), [1, 3]) + + def test_peephole_dict_getitem_simple(self): + @torch.jit.script + def foo(a: int, b: int): + d = {0: a, 1: b} + x = d[1] + y = d[0] + return x, y + + self.run_pass("peephole", foo.graph) + FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) + self.assertEqual(foo(0, 1), (1, 0)) + + @torch.jit.script + def foo(a: int, b: int): + d = {'0': a, '1': b} + x = d['1'] + y = d['0'] + return x, y + + self.run_pass("peephole", foo.graph) + FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) + self.assertEqual(foo(0, 1), (1, 0)) + + @torch.jit.script + def foo(a: int, b: int): + d = {0.0: a, 1.0: b} + x = d[1.0] + y = d[0.0] + return x, y + + self.run_pass("peephole", foo.graph) + FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) + self.assertEqual(foo(0, 1), (1, 0)) + + def 
test_peephole_dict_getitem_no_optimization_missing_key(self): + @torch.jit.script + def foo(): + d = {0: 1} + return d[2] + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) + + def test_peephole_dict_getitem_no_optimization_get_input_arg(self): + # Here we don't know if the input arg is in the dict, so we can't + # make the optimization. + @torch.jit.script + def foo(a: int): + d = {0: 1} + return d[a] + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) + self.assertEqual(foo(0), 1) + + def test_peephole_dict_getitem_no_optimization_dict_modified(self): + @torch.jit.script + def foo(): + d = {0: 1} + d[0] = 2 + return d[0] + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) + self.assertEqual(foo(), 2) + + def test_peephole_dict_getitem_no_optimization_overlapping_keys(self): + @torch.jit.script + def foo(): + d = {0: 1, 0: 2} # noqa: F601 + return d[0] + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) + + def test_peephole_dict_getitem_no_optimization_keys_might_overlap(self): + @torch.jit.script + def foo(x: int): + d = {0: 1, x: 2} + return d[x] + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) + + def test_peephole_dict_getitem_no_optimization_unsupported_type(self): + @torch.jit.script + def foo(): + a = torch.rand((2, 2)) + d = {a: 1} + return d[a] + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) + self.assertEqual(foo(), 1) + + def test_peephole_dict_len(self): + @torch.jit.script + def foo(): + d = {0: 1, 1: 2} + return len(d) + + self.run_pass("peephole", foo.graph) + FileCheck().check_not("DictConstruct").check_not("len").run(foo.graph) + self.assertEqual(foo(), 2) + + def 
test_peephole_dict_len_no_optimization_overlapping_keys(self): + @torch.jit.script + def foo(): + d = {0: 1, 0: 2} # noqa: F601 + return len(d) + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("len").run(foo.graph) + self.assertEqual(foo(), 1) + + def test_peephole_dict_len_no_optimization_keys_might_overlap(self): + @torch.jit.script + def foo(x: int): + d = {0: 1, x: 2} + return len(d) + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("len").run(foo.graph) + + def test_peephole_dict_len_no_optimization_unsupported_type(self): + @torch.jit.script + def foo(): + a = torch.rand((2, 2)) + d = {a: 1} + return len(d) + + self.run_pass("peephole", foo.graph) + FileCheck().check("DictConstruct").check("len").run(foo.graph) + self.assertEqual(foo(), 1) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 0538f6d..e7958cc 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -219,6 +219,7 @@ core_sources_full_mobile = [ "torch/csrc/jit/passes/lower_grad_of.cpp", "torch/csrc/jit/passes/lower_tuples.cpp", "torch/csrc/jit/passes/normalize_ops.cpp", + "torch/csrc/jit/passes/peephole_dict_idioms.cpp", "torch/csrc/jit/passes/peephole_list_idioms.cpp", "torch/csrc/jit/passes/value_refinement_utils.cpp", "torch/csrc/jit/passes/peephole_alias_sensitive.cpp", diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index bacc4bc..efb7597 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -32,6 +33,7 @@ struct PeepholeOptimizeImpl { bool run() { bool changed = optimizeBlock(graph_->block()); changed |= PeepholeOptimizeListIdioms(graph_); + changed |= PeepholeOptimizeDictIdioms(graph_); changed |= PeepholeOptimizeAliasSensitive(graph_); changed |= PeepholeOptimizeNonTensor(graph_); return changed; diff --git 
a/torch/csrc/jit/passes/peephole_dict_idioms.cpp b/torch/csrc/jit/passes/peephole_dict_idioms.cpp new file mode 100644 index 0000000..b3b4ed3 --- /dev/null +++ b/torch/csrc/jit/passes/peephole_dict_idioms.cpp @@ -0,0 +1,272 @@ +#include +#include + +namespace torch { +namespace jit { + +namespace { + +class DictNodeImplBase { + public: + virtual ~DictNodeImplBase() = default; + + virtual bool contains(const IValue&) const = 0; + virtual size_t size() const = 0; + virtual Value* get(const IValue&) const = 0; + + bool canOptimize() { + return !has_overlap_ && !has_non_const_key_; + } + + protected: + bool has_overlap_ = false; + bool has_non_const_key_ = false; +}; + +template +class DictNodeImpl : public DictNodeImplBase { + public: + DictNodeImpl( + std::function ivalue_converter, + Node* dict_creation_node) + : ivalue_converter_(std::move(ivalue_converter)) { + for (size_t i = 0; i < dict_creation_node->inputs().size(); i += 2) { + auto key_opt = toIValue(dict_creation_node->input(i)); + + // Key is not constant if we cannot convert to IValue + if (key_opt == c10::nullopt) { + has_non_const_key_ = true; + continue; + } + + KeyType key = ivalue_converter_(*key_opt); + if (dict_.find(key) == dict_.end()) { + dict_.emplace(key, dict_creation_node->input(i + 1)); + } else { + has_overlap_ = true; + } + } + } + + bool contains(const IValue& ivalue) const override { + auto key = ivalue_converter_(ivalue); + return dict_.find(key) != dict_.end(); + } + + size_t size() const override { + return dict_.size(); + } + + Value* get(const IValue& ivalue) const override { + auto val = ivalue_converter_(ivalue); + auto loc = dict_.find(val); + if (loc != dict_.end()) { + return loc->second; + } + TORCH_CHECK(false, "Cannot get non-existent key"); + } + + private: + std::unordered_map dict_; + std::function ivalue_converter_; +}; + +class DictNode { + public: + explicit DictNode(Node* dict_creation_node) { + auto dict_type = dict_creation_node->output()->type(); + auto 
key_value_types = dict_type->containedTypes(); + TORCH_CHECK( + key_value_types.size() == 2, "Dict must have 2 contained types"); + const auto& key_type = key_value_types[0]; + + switch (key_type->kind()) { + case TypeKind::IntType: { + auto ivalue_converter = [](const IValue& ival) { return ival.toInt(); }; + impl_ = std::make_unique>( + std::move(ivalue_converter), dict_creation_node); + break; + } + + case TypeKind::FloatType: { + auto ivalue_converter = [](const IValue& ival) { + return ival.toDouble(); + }; + impl_ = std::make_unique>( + std::move(ivalue_converter), dict_creation_node); + break; + } + + case TypeKind::StringType: { + auto ivalue_converter = [](const IValue& ival) { + return *ival.toString(); + }; + impl_ = std::make_unique>( + std::move(ivalue_converter), dict_creation_node); + break; + } + + default: + impl_ = nullptr; + } + } + + bool canOptimize() const { + if (impl_) { + return impl_->canOptimize(); + } + return false; + } + + size_t size() const { + if (impl_) { + return impl_->size(); + } + return 0; + } + + c10::optional getOrNullopt(const IValue& key) const { + if (impl_ && impl_->contains(key)) { + return impl_->get(key); + } + return c10::nullopt; + } + + private: + std::unique_ptr impl_; +}; + +bool isDict(Value* v) { + return v->type()->castRaw() != nullptr; +} + +class PeepholeOptimizeDictIdiomsImpl { + public: + explicit PeepholeOptimizeDictIdiomsImpl(std::shared_ptr graph) + : graph_(std::move(graph)), aliasDb_(std::make_unique(graph_)) {} + + bool run() { + collectMutatedDicts(graph_->block()); + return runBlock(graph_->block()); + } + + private: + void checkForMutatedDicts(Value* v) { + if (isDict(v) && aliasDb_->hasWriters(v)) { + mutated_dicts_.insert(v); + } + } + + void collectMutatedDicts(Block* b) { + for (Value* v : b->inputs()) { + checkForMutatedDicts(v); + } + for (Node* n : b->nodes()) { + for (Value* v : n->outputs()) { + checkForMutatedDicts(v); + } + for (Block* block : n->blocks()) { + 
collectMutatedDicts(block); + } + } + } + + const DictNode& getDictNode(Node* creation_node) { + auto cached = dict_cache_.find(creation_node); + if (cached == dict_cache_.end()) { + cached = + dict_cache_.emplace(creation_node, DictNode(creation_node)).first; + } + + return cached->second; + } + + c10::optional getValueFromDict(Node* dict_creation_node, Value* key) { + const DictNode& dict_node = getDictNode(dict_creation_node); + auto key_opt = toIValue(key); + // Key is not constant if we cannot convert to IValue + if (key_opt == c10::nullopt) { + return c10::nullopt; + } + IValue key_ival = *key_opt; + if (dict_node.canOptimize()) { + return dict_node.getOrNullopt(key_ival); + } + return c10::nullopt; + } + + c10::optional computeLen(Node* dict_creation_node) { + const DictNode& dict_node = getDictNode(dict_creation_node); + if (dict_node.canOptimize()) { + return static_cast(dict_node.size()); + } + return c10::nullopt; + } + + bool optimizeLen(Node* len_node, Node* creation_node) { + if (creation_node->kind() == prim::DictConstruct) { + auto len = computeLen(creation_node); + if (len != c10::nullopt) { + WithInsertPoint guard(len_node); + len_node->output()->replaceAllUsesWith(graph_->insertConstant(len)); + return true; + } + } + return false; + } + + bool optimizeGetItem(Node* getitem_node, Node* creation_node) { + if (creation_node->kind() == prim::DictConstruct) { + auto key = getitem_node->input(1); + auto value = getValueFromDict(creation_node, key); + if (value != c10::nullopt) { + getitem_node->output()->replaceAllUsesWith(*value); + return true; + } + } + return false; + } + + bool runBlock(Block* block) { + bool changed = false; + for (Node* node : block->nodes()) { + for (Block* b : node->blocks()) { + changed |= runBlock(b); + } + + // only optimizing dict ops + if (node->inputs().size() == 0 || !isDict(node->input(0))) { + continue; + } + + auto first_input = node->input(0); + + // only optimizing ops with unmutated inputs + if 
(mutated_dicts_.count(first_input)) { + continue; + } + + if (node->kind() == aten::len) { + changed |= optimizeLen(node, first_input->node()); + } else if (node->kind() == aten::__getitem__) { + changed |= optimizeGetItem(node, first_input->node()); + } + } + return changed; + } + + std::shared_ptr graph_; + std::unordered_set mutated_dicts_; + std::unique_ptr aliasDb_; + std::unordered_map dict_cache_; +}; + +} // namespace + +bool PeepholeOptimizeDictIdioms(const std::shared_ptr& graph) { + PeepholeOptimizeDictIdiomsImpl opt(graph); + return opt.run(); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/peephole_dict_idioms.h b/torch/csrc/jit/passes/peephole_dict_idioms.h new file mode 100644 index 0000000..283c313 --- /dev/null +++ b/torch/csrc/jit/passes/peephole_dict_idioms.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +// Peephole Optimizes Dict Ops such as len() and __getitem__ +// 1. getitem optimizations +// Given a function like this: +// def foo(): +// d = {0 : 1} +// x = d[0] +// return x +// This pass produces (after dead code elimination): +// def foo(): +// return 1 +// +// This optimization can only happen if the dict is not modified +// and the dict has constant, non-overlapping keys. +// +// 2. len optimizations +// Given a function like this: +// def foo(): +// d = {0 : 1} +// return len(d) +// This pass produces (after dead code elimination): +// def foo(): +// return 1 +// +// This has the same requirements as the getitem optimizations. +// +// Currently this is invoked as part of PeepholeOptimize +// return true if graph is modified. 
+TORCH_API bool PeepholeOptimizeDictIdioms(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/peephole_list_idioms.h b/torch/csrc/jit/passes/peephole_list_idioms.h index 33e44da..c8add48 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.h +++ b/torch/csrc/jit/passes/peephole_list_idioms.h @@ -5,7 +5,52 @@ namespace torch { namespace jit { -// Peephole Optimizes List Ops such as len(li) and li[1]. +// Peephole Optimizes List ops such as len(li) and li[1]. +// 1. Construct/Unpack optimizations +// Given a function like this: +// def foo(a, b): +// li = [a, b] +// x, y = li +// return x, y +// This pass produces (after dead code elimination): +// def foo(a, b): +// return a, b +// +// This is only applied to lists that are not modified. +// +// 2. getitem optimizations +// Given a function like this: +// def foo(a, b): +// li = [a, b] +// x = li[0] +// return x +// This pass produces (after dead code elimination): +// def foo(a, b): +// return a +// +// This optimization can only happen if the list is not modified. +// +// 3. len optimizations +// Given a function like this: +// def foo(): +// li = [1, 2] +// return len(li) +// This pass produces (after dead code elimination): +// def foo(): +// return 2 +// +// This has the same requirements as the getitem optimizations. +// +// 4. ListConstruct + ListConstruct +// Given a function like this: +// def foo(): +// return [1, 2] + [3, 4] +// This pass produces (after dead code elimination): +// def foo(): +// return [1, 2, 3, 4] +// +// This is only applied to lists that are not modified. +// // Currently this is invoked as part of PeepholeOptimize // return true if graph is modified. // If `refine_list_len` is true will attempt to refine the len of lists through -- 2.7.4