[Static Runtime] Add pass to eliminate __getitem__/DictConstruct calls (#62429)
authorMike Iovine <mikeiovine@fb.com>
Fri, 13 Aug 2021 17:18:03 +0000 (10:18 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 17:21:16 +0000 (10:21 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62429

Introduce a new pass to eliminate calls to `prim::DictConstruct/aten::__getitem__`. Given a graph like this:
```
%2 : Dict = prim::DictConstruct(%key, %value)
%3 : Tensor = aten::__getitem__(%2, %key)
%4 : Tensor = op(%3)
```
This pass produces a graph like this (after dead code elimination):
```
%4 : Tensor = op(%value)
```

This optimization is applied in the static runtime.

Test Plan:
`buck test //caffe2/test:jit -- TestPeephole`

**local.forward performance summary**
About 3% runtime benefit. All `DictConstruct` calls optimized out, `__getitem__` calls reduced significantly (~50% of them are cut out)
P438354810

**local_request_only.forward performance summary**
About 14% runtime benefit. Again, all `DictConstruct` calls optimized out, 50% `__getitem__` calls removed.
P438359742

There is some variance with runtime measurements, so take these numbers with a grain of salt. Also note that the benefit does not exist in the shrunk model since there are no `DictConstruct` calls

Reviewed By: hlu1

Differential Revision: D29995087

fbshipit-source-id: f376376a46ff808115afd2d60446e5db8f6f752f

test/jit/test_peephole.py
tools/build_variables.bzl
torch/csrc/jit/passes/peephole.cpp
torch/csrc/jit/passes/peephole_dict_idioms.cpp [new file with mode: 0644]
torch/csrc/jit/passes/peephole_dict_idioms.h [new file with mode: 0644]
torch/csrc/jit/passes/peephole_list_idioms.h

index d2f4d32..8ba1714 100644 (file)
@@ -571,3 +571,138 @@ class TestPeephole(JitTestCase):
         FileCheck().check("graph").check("):").check_next("aten::Int") \
                    .check_next("ListConstruct").check_next("return").run(foo.graph)
         self.assertEqual(foo(0, 1, 2, 3), [1, 3])
+
+    def test_peephole_dict_getitem_simple(self):
+        @torch.jit.script
+        def foo(a: int, b: int):
+            d = {0: a, 1: b}
+            x = d[1]
+            y = d[0]
+            return x, y
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
+        self.assertEqual(foo(0, 1), (1, 0))
+
+        @torch.jit.script
+        def foo(a: int, b: int):
+            d = {'0': a, '1': b}
+            x = d['1']
+            y = d['0']
+            return x, y
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
+        self.assertEqual(foo(0, 1), (1, 0))
+
+        @torch.jit.script
+        def foo(a: int, b: int):
+            d = {0.0: a, 1.0: b}
+            x = d[1.0]
+            y = d[0.0]
+            return x, y
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
+        self.assertEqual(foo(0, 1), (1, 0))
+
+    def test_peephole_dict_getitem_no_optimization_missing_key(self):
+        @torch.jit.script
+        def foo():
+            d = {0: 1}
+            return d[2]
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
+
+    def test_peephole_dict_getitem_no_optimization_get_input_arg(self):
+        # Here we don't know if the input arg is in the dict, so we can't
+        # make the optimization.
+        @torch.jit.script
+        def foo(a: int):
+            d = {0: 1}
+            return d[a]
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
+        self.assertEqual(foo(0), 1)
+
+    def test_peephole_dict_getitem_no_optimization_dict_modified(self):
+        @torch.jit.script
+        def foo():
+            d = {0: 1}
+            d[0] = 2
+            return d[0]
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
+        self.assertEqual(foo(), 2)
+
+    def test_peephole_dict_getitem_no_optimization_overlapping_keys(self):
+        @torch.jit.script
+        def foo():
+            d = {0: 1, 0: 2}  # noqa: F601
+            return d[0]
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
+
+    def test_peephole_dict_getitem_no_optimization_keys_might_overlap(self):
+        @torch.jit.script
+        def foo(x: int):
+            d = {0: 1, x: 2}
+            return d[x]
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
+
+    def test_peephole_dict_getitem_no_optimization_unsupported_type(self):
+        @torch.jit.script
+        def foo():
+            a = torch.rand((2, 2))
+            d = {a: 1}
+            return d[a]
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
+        self.assertEqual(foo(), 1)
+
+    def test_peephole_dict_len(self):
+        @torch.jit.script
+        def foo():
+            d = {0: 1, 1: 2}
+            return len(d)
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check_not("DictConstruct").check_not("len").run(foo.graph)
+        self.assertEqual(foo(), 2)
+
+    def test_peephole_dict_len_no_optimization_overlapping_keys(self):
+        @torch.jit.script
+        def foo():
+            d = {0: 1, 0: 2}  # noqa: F601
+            return len(d)
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("len").run(foo.graph)
+        self.assertEqual(foo(), 1)
+
+    def test_peephole_dict_len_no_optimization_keys_might_overlap(self):
+        @torch.jit.script
+        def foo(x: int):
+            d = {0: 1, x: 2}
+            return len(d)
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("len").run(foo.graph)
+
+    def test_peephole_dict_len_no_optimization_unsupported_type(self):
+        @torch.jit.script
+        def foo():
+            a = torch.rand((2, 2))
+            d = {a: 1}
+            return len(d)
+
+        self.run_pass("peephole", foo.graph)
+        FileCheck().check("DictConstruct").check("len").run(foo.graph)
+        self.assertEqual(foo(), 1)
index 0538f6d..e7958cc 100644 (file)
@@ -219,6 +219,7 @@ core_sources_full_mobile = [
     "torch/csrc/jit/passes/lower_grad_of.cpp",
     "torch/csrc/jit/passes/lower_tuples.cpp",
     "torch/csrc/jit/passes/normalize_ops.cpp",
+    "torch/csrc/jit/passes/peephole_dict_idioms.cpp",
     "torch/csrc/jit/passes/peephole_list_idioms.cpp",
     "torch/csrc/jit/passes/value_refinement_utils.cpp",
     "torch/csrc/jit/passes/peephole_alias_sensitive.cpp",
index bacc4bc..efb7597 100644 (file)
@@ -7,6 +7,7 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/peephole_alias_sensitive.h>
+#include <torch/csrc/jit/passes/peephole_dict_idioms.h>
 #include <torch/csrc/jit/passes/peephole_list_idioms.h>
 #include <torch/csrc/jit/passes/peephole_non_tensor.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
@@ -32,6 +33,7 @@ struct PeepholeOptimizeImpl {
   bool run() {
     bool changed = optimizeBlock(graph_->block());
     changed |= PeepholeOptimizeListIdioms(graph_);
+    changed |= PeepholeOptimizeDictIdioms(graph_);
     changed |= PeepholeOptimizeAliasSensitive(graph_);
     changed |= PeepholeOptimizeNonTensor(graph_);
     return changed;
diff --git a/torch/csrc/jit/passes/peephole_dict_idioms.cpp b/torch/csrc/jit/passes/peephole_dict_idioms.cpp
new file mode 100644 (file)
index 0000000..b3b4ed3
--- /dev/null
@@ -0,0 +1,272 @@
+#include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/passes/peephole_dict_idioms.h>
+
+namespace torch {
+namespace jit {
+
+namespace {
+
+class DictNodeImplBase {
+ public:
+  virtual ~DictNodeImplBase() = default;
+
+  virtual bool contains(const IValue&) const = 0;
+  virtual size_t size() const = 0;
+  virtual Value* get(const IValue&) const = 0;
+
+  bool canOptimize() {
+    return !has_overlap_ && !has_non_const_key_;
+  }
+
+ protected:
+  bool has_overlap_ = false;
+  bool has_non_const_key_ = false;
+};
+
+template <class KeyType>
+class DictNodeImpl : public DictNodeImplBase {
+ public:
+  DictNodeImpl(
+      std::function<KeyType(const IValue&)> ivalue_converter,
+      Node* dict_creation_node)
+      : ivalue_converter_(std::move(ivalue_converter)) {
+    for (size_t i = 0; i < dict_creation_node->inputs().size(); i += 2) {
+      auto key_opt = toIValue(dict_creation_node->input(i));
+
+      // Key is not constant if we cannot convert to IValue
+      if (key_opt == c10::nullopt) {
+        has_non_const_key_ = true;
+        continue;
+      }
+
+      KeyType key = ivalue_converter_(*key_opt);
+      if (dict_.find(key) == dict_.end()) {
+        dict_.emplace(key, dict_creation_node->input(i + 1));
+      } else {
+        has_overlap_ = true;
+      }
+    }
+  }
+
+  bool contains(const IValue& ivalue) const override {
+    auto key = ivalue_converter_(ivalue);
+    return dict_.find(key) != dict_.end();
+  }
+
+  size_t size() const override {
+    return dict_.size();
+  }
+
+  Value* get(const IValue& ivalue) const override {
+    auto val = ivalue_converter_(ivalue);
+    auto loc = dict_.find(val);
+    if (loc != dict_.end()) {
+      return loc->second;
+    }
+    TORCH_CHECK(false, "Cannot get non-existent key");
+  }
+
+ private:
+  std::unordered_map<KeyType, Value*> dict_;
+  std::function<KeyType(const IValue&)> ivalue_converter_;
+};
+
+class DictNode {
+ public:
+  explicit DictNode(Node* dict_creation_node) {
+    auto dict_type = dict_creation_node->output()->type();
+    auto key_value_types = dict_type->containedTypes();
+    TORCH_CHECK(
+        key_value_types.size() == 2, "Dict must have 2 contained types");
+    const auto& key_type = key_value_types[0];
+
+    switch (key_type->kind()) {
+      case TypeKind::IntType: {
+        auto ivalue_converter = [](const IValue& ival) { return ival.toInt(); };
+        impl_ = std::make_unique<DictNodeImpl<int64_t>>(
+            std::move(ivalue_converter), dict_creation_node);
+        break;
+      }
+
+      case TypeKind::FloatType: {
+        auto ivalue_converter = [](const IValue& ival) {
+          return ival.toDouble();
+        };
+        impl_ = std::make_unique<DictNodeImpl<double>>(
+            std::move(ivalue_converter), dict_creation_node);
+        break;
+      }
+
+      case TypeKind::StringType: {
+        auto ivalue_converter = [](const IValue& ival) {
+          return *ival.toString();
+        };
+        impl_ = std::make_unique<DictNodeImpl<std::string>>(
+            std::move(ivalue_converter), dict_creation_node);
+        break;
+      }
+
+      default:
+        impl_ = nullptr;
+    }
+  }
+
+  bool canOptimize() const {
+    if (impl_) {
+      return impl_->canOptimize();
+    }
+    return false;
+  }
+
+  size_t size() const {
+    if (impl_) {
+      return impl_->size();
+    }
+    return 0;
+  }
+
+  c10::optional<Value*> getOrNullopt(const IValue& key) const {
+    if (impl_ && impl_->contains(key)) {
+      return impl_->get(key);
+    }
+    return c10::nullopt;
+  }
+
+ private:
+  std::unique_ptr<DictNodeImplBase> impl_;
+};
+
+bool isDict(Value* v) {
+  return v->type()->castRaw<DictType>() != nullptr;
+}
+
+class PeepholeOptimizeDictIdiomsImpl {
+ public:
+  explicit PeepholeOptimizeDictIdiomsImpl(std::shared_ptr<Graph> graph)
+      : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}
+
+  bool run() {
+    collectMutatedDicts(graph_->block());
+    return runBlock(graph_->block());
+  }
+
+ private:
+  void checkForMutatedDicts(Value* v) {
+    if (isDict(v) && aliasDb_->hasWriters(v)) {
+      mutated_dicts_.insert(v);
+    }
+  }
+
+  void collectMutatedDicts(Block* b) {
+    for (Value* v : b->inputs()) {
+      checkForMutatedDicts(v);
+    }
+    for (Node* n : b->nodes()) {
+      for (Value* v : n->outputs()) {
+        checkForMutatedDicts(v);
+      }
+      for (Block* block : n->blocks()) {
+        collectMutatedDicts(block);
+      }
+    }
+  }
+
+  const DictNode& getDictNode(Node* creation_node) {
+    auto cached = dict_cache_.find(creation_node);
+    if (cached == dict_cache_.end()) {
+      cached =
+          dict_cache_.emplace(creation_node, DictNode(creation_node)).first;
+    }
+
+    return cached->second;
+  }
+
+  c10::optional<Value*> getValueFromDict(Node* dict_creation_node, Value* key) {
+    const DictNode& dict_node = getDictNode(dict_creation_node);
+    auto key_opt = toIValue(key);
+    // Key is not constant if we cannot convert to IValue
+    if (key_opt == c10::nullopt) {
+      return c10::nullopt;
+    }
+    IValue key_ival = *key_opt;
+    if (dict_node.canOptimize()) {
+      return dict_node.getOrNullopt(key_ival);
+    }
+    return c10::nullopt;
+  }
+
+  c10::optional<int64_t> computeLen(Node* dict_creation_node) {
+    const DictNode& dict_node = getDictNode(dict_creation_node);
+    if (dict_node.canOptimize()) {
+      return static_cast<int64_t>(dict_node.size());
+    }
+    return c10::nullopt;
+  }
+
+  bool optimizeLen(Node* len_node, Node* creation_node) {
+    if (creation_node->kind() == prim::DictConstruct) {
+      auto len = computeLen(creation_node);
+      if (len != c10::nullopt) {
+        WithInsertPoint guard(len_node);
+        len_node->output()->replaceAllUsesWith(graph_->insertConstant(len));
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool optimizeGetItem(Node* getitem_node, Node* creation_node) {
+    if (creation_node->kind() == prim::DictConstruct) {
+      auto key = getitem_node->input(1);
+      auto value = getValueFromDict(creation_node, key);
+      if (value != c10::nullopt) {
+        getitem_node->output()->replaceAllUsesWith(*value);
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool runBlock(Block* block) {
+    bool changed = false;
+    for (Node* node : block->nodes()) {
+      for (Block* b : node->blocks()) {
+        changed |= runBlock(b);
+      }
+
+      // only optimizing dict ops
+      if (node->inputs().size() == 0 || !isDict(node->input(0))) {
+        continue;
+      }
+
+      auto first_input = node->input(0);
+
+      // only optimizing ops with unmutated inputs
+      if (mutated_dicts_.count(first_input)) {
+        continue;
+      }
+
+      if (node->kind() == aten::len) {
+        changed |= optimizeLen(node, first_input->node());
+      } else if (node->kind() == aten::__getitem__) {
+        changed |= optimizeGetItem(node, first_input->node());
+      }
+    }
+    return changed;
+  }
+
+  std::shared_ptr<Graph> graph_;
+  std::unordered_set<Value*> mutated_dicts_;
+  std::unique_ptr<AliasDb> aliasDb_;
+  std::unordered_map<Node*, DictNode> dict_cache_;
+};
+
+} // namespace
+
+bool PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph>& graph) {
+  PeepholeOptimizeDictIdiomsImpl opt(graph);
+  return opt.run();
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/peephole_dict_idioms.h b/torch/csrc/jit/passes/peephole_dict_idioms.h
new file mode 100644 (file)
index 0000000..283c313
--- /dev/null
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+// Peephole Optimizes Dict Ops such as len() and __getitem__
+// 1. getitem optimizations
+// Given a function like this:
+//     def foo():
+//         d = {0 : 1}
+//         x = d[0]
+//         return x
+// This pass produces (after dead code elimination):
+//     def foo(a, b):
+//         return 1
+//
+// This optimization can only happen if the dict is not modified
+// and the dict has constant, non overlapping keys.
+//
+// 2. len optimizations
+// Given a function like this:
+//     def foo():
+//         d = {0 : 1}
+//         return len(d)
+// This pass produces (after dead code elimination):
+//     def foo():
+//         return 1
+//
+// This has the same requirements as the getitem optimizations.
+//
+// Currently this is invoked as part of PeepholeOptimize
+// return true if graph is modified.
+TORCH_API bool PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
index 33e44da..c8add48 100644 (file)
@@ -5,7 +5,52 @@
 namespace torch {
 namespace jit {
 
-// Peephole Optimizes List Ops such as len(li) and li[1].
+// Peephole Optimizes List ops such as len(li) and li[1].
+// 1. Construct/Unpack optimizations
+// Given a function like this:
+//    def foo(a, b):
+//        li = [a, b]
+//        x, y = li
+//        return x, y
+// This pass produces (after dead code elimination):
+//    def foo(a, b):
+//        return a, b
+//
+// This is only applied to lists that are not modified.
+//
+// 2. getitem optimizations
+// Given a function like this:
+//     def foo(a, b):
+//         li = [a, b]
+//         x = li[0]
+//         return x
+// This pass produces (after dead code elimination):
+//     def foo(a, b):
+//         return a
+//
+// This optimization can only happen if the list is not modified.
+//
+// 3. len optimizations
+// Given a function like this:
+//     def foo():
+//         li = [1, 2]
+//         return len(li)
+// This pass produces (after dead code elimination):
+//     def foo():
+//         return 2
+//
+// This has the same requirements as the getitem optimizations.
+//
+// 4. ListConstruct + ListConstruct
+// Given a function like this:
+//     def foo():
+//         return [1, 2] + [3, 4]
+// This pass produces (after dead code elimination):
+//     def foo():
+//         return [1, 2, 3, 4]
+//
+// This is only applied to lists that are not modified.
+//
 // Currently this is invoked as part of PeepholeOptimize
 // return true if graph is modified.
 // If `refine_list_len` is true will attempt to refine the len of lists through