batched cleanups (#17288)
author Michael Suo <suo@fb.com>
Thu, 21 Feb 2019 02:27:31 +0000 (18:27 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Feb 2019 02:31:53 +0000 (18:31 -0800)
Summary:
A batch of small cleanups I came across while doing UDT (user-defined type) work; splitting them into a separate PR to keep that change free of noise:
- fix up the alias analysis special-case op lists to include fork/wait (see the sketch after this list)
- improve AliasDb's dump() to also print writes
- move BuiltinFunction::call() to sugared_value.cpp with the rest of the SugaredValue methods
- formatting and include cleanups
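
Roughly, the corresponding special case in AliasDb's analysis might look like the sketch below (illustrative only; analyzeFork/analyzeWait are assumed helper names, not necessarily what the code uses):

    // Illustrative sketch only -- not code from this PR. This is the kind of
    // special case the WARNING in aliasAnalysisHasSpecialCaseFor() demands
    // once prim::fork / aten::wait are moved into the `handled` set.
    void AliasDb::analyzeImpl(Node* node) {
      switch (node->kind()) {
        case prim::fork:
          // The forked subgraph can read and write its captured inputs, so it
          // must be analyzed rather than ignored (assumed helper).
          return analyzeFork(node);
        case aten::wait:
          // wait() returns the future's value, which may alias values written
          // by the forked function (assumed helper).
          return analyzeWait(node);
        default:
          break; // fall through to the schema-based analysis
      }
      // ... existing schema-based handling ...
    }
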
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17288

Differential Revision: D14147105

Pulled By: suo

fbshipit-source-id: 62e2a922a1726b684347365dc42c72188f154e9c

aten/src/ATen/core/ivalue.h
aten/src/ATen/core/jit_type.h
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/passes/utils/alias_tracker.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/sugared_value.cpp
torch/csrc/jit/script/tree_views.h

diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index cb633ec..08b2298 100644
@@ -3,10 +3,11 @@
 #include <condition_variable>
 #include <type_traits>
 
+#include <ATen/core/blob.h>
+#include <ATen/core/interned_strings.h>
 #include <c10/core/Scalar.h>
 #include <c10/core/TensorImpl.h>
 #include <c10/core/UndefinedTensorImpl.h>
-#include <ATen/core/blob.h>
 #include <c10/util/intrusive_ptr.h>
 
 #include <ATen/core/Tensor.h>
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index f115b35..4db6f06 100644
@@ -1,10 +1,10 @@
 #pragma once
 
-#include <ATen/core/ivalue.h>
-#include <ATen/core/interned_strings.h>
-#include <ATen/core/functional.h>
-#include <ATen/core/Type.h>
 #include <ATen/core/TensorMethods.h>
+#include <ATen/core/Type.h>
+#include <ATen/core/functional.h>
+#include <ATen/core/interned_strings.h>
+#include <ATen/core/ivalue.h>
 #include <c10/util/TypeList.h>
 #include <caffe2/core/common.h>
 
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp
index e4b44e3..899bff7 100644
@@ -239,6 +239,9 @@ void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
   for (const auto& pr : tupleTypes) {
     makeAllAlias(pr.second, *aliasTracker_);
   }
+  for (const auto& pr : dictTypes) {
+    makeAllAlias(pr.second, *aliasTracker_);
+  }
   makeAllAlias(tensors, *aliasTracker_);
 
   analyze(graph->block());
@@ -321,7 +324,6 @@ void AliasDb::analyzeImpl(Node* node) {
       AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
   }
 
-
   const auto& schema = node->schema();
   if (schema.is_vararg() || schema.is_varret()) {
     const auto hasMutableOutputs = std::any_of(
@@ -973,47 +975,48 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
   // WARNING: by adding a case to this list, you are asserting that you have
   // added a case for the unschematized node in AliasDb::analyze
   const static std::unordered_set<Symbol> handled = {
-    prim::If,
-    prim::Loop,
-    prim::FusionGroup,
-    prim::DifferentiableGraph,
-    prim::Constant,
-    prim::DictConstruct,
-    prim::ListConstruct,
-    prim::TupleConstruct,
-    prim::Undefined,
-    prim::FusedConcat,
-    prim::MMTreeReduce,
-    prim::MMBatchSide,
-    prim::None,
-    prim::BroadcastSizes,
-    prim::ChunkSizes,
-    prim::Function,
-    prim::TupleUnpack,
-    prim::TupleIndex,
-    prim::DictIndex,
-    prim::TupleSlice,
-    prim::ListUnpack,
-    prim::PythonOp,
-    prim::ConstantChunk,
-    prim::BroadcastingChunk,
-    aten::add,
-    aten::sub,
-    aten::mul,
-    aten::div,
+      prim::If,
+      prim::Loop,
+      prim::FusionGroup,
+      prim::DifferentiableGraph,
+      prim::Constant,
+      prim::DictConstruct,
+      prim::ListConstruct,
+      prim::TupleConstruct,
+      prim::Undefined,
+      prim::FusedConcat,
+      prim::MMTreeReduce,
+      prim::MMBatchSide,
+      prim::None,
+      prim::BroadcastSizes,
+      prim::ChunkSizes,
+      prim::Function,
+      prim::TupleUnpack,
+      prim::TupleIndex,
+      prim::DictIndex,
+      prim::TupleSlice,
+      prim::ListUnpack,
+      prim::PythonOp,
+      prim::ConstantChunk,
+      prim::BroadcastingChunk,
+      prim::fork,
+      aten::wait,
+      aten::add,
+      aten::sub,
+      aten::mul,
+      aten::div,
   };
 
   // Operators that should not be used by alias analysis
   const static std::unordered_set<Symbol> purposefully_not_handled = {
-    prim::Print,
-    prim::Load,
-    prim::Store,
-    prim::Drop,
-    at::onnx::Reshape,
-    at::onnx::Shape,
-    prim::AnyDefined,
-    prim::AutogradAdd,
-    prim::fork, // TODO: fork aliasing / futures
+      prim::Print,
+      prim::Load,
+      prim::Store,
+      prim::Drop,
+      at::onnx::Reshape,
+      at::onnx::Shape,
+      prim::AnyDefined,
+      prim::AutogradAdd,
   };
 
   return handled.count(symbol) || purposefully_not_handled.count(symbol);
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 7d249cd..4126bcc 100644
@@ -1,9 +1,9 @@
-#include <torch/csrc/jit/passes/python_print.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/attributes.h>
 #include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/ir_views.h>
+#include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/script/module.h>
diff --git a/torch/csrc/jit/passes/utils/alias_tracker.cpp b/torch/csrc/jit/passes/utils/alias_tracker.cpp
index a231f01..e6cd358 100644
@@ -180,6 +180,19 @@ void AliasTracker::dump() const {
     std::cout << wildcard->uniqueName() << ", ";
   }
   std::cout << "\n";
+
+  std::cout << "\n===4. Writes===\n";
+  for (const auto& pr : writeIndex_) {
+    const auto node = pr.first;
+    const auto& values = pr.second;
+    std::cout << *node;
+    std::cout << "  ";
+    for (const auto value : values) {
+      std::cout << value->uniqueName() << ", ";
+    }
+    std::cout << "\n";
+  }
+  std::cout << "\n";
 }
 
 std::unordered_set<const AliasTracker::Element*> AliasTracker::Element::
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 25f9607..35f0c23 100644
@@ -1,4 +1,3 @@
-#include <torch/csrc/jit/script/compiler.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -6,6 +5,7 @@
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/script/final_returns.h>
 #include <torch/csrc/jit/script/parser.h>
 #include <torch/csrc/jit/script/schema_matching.h>
@@ -488,16 +488,6 @@ static Value* ensureInt(const SourceRange& range, Value* v) {
   return v;
 }
 
-std::shared_ptr<SugaredValue> BuiltinFunction::call(
-    const SourceRange& loc,
-    Method& m,
-    at::ArrayRef<NamedValue> inputs,
-    at::ArrayRef<NamedValue> attributes,
-    size_t n_binders) {
-  return std::make_shared<SimpleValue>(
-      emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true));
-}
-
 inline bool isSupportedListElementType(const TypePtr& type) {
   return type->isSubtypeOf(TensorType::get()) ||
       type->isSubtypeOf(NumberType::get());
@@ -533,6 +523,7 @@ struct to_ir {
     }
 
     method.setSchema(emitDef(def, self, graph->block()));
+
     runCleanupPasses(graph);
   }
 
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index ca61850..f158c4c 100644
@@ -5,9 +5,9 @@
 #include <torch/csrc/Layout.h>
 #include <torch/csrc/jit/import.h>
 #include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/module.h>
 #include <torch/csrc/jit/script/schema_matching.h>
 #include <torch/csrc/jit/script/sugared_value.h>
-#include <torch/csrc/jit/script/module.h>
 #include <torch/csrc/jit/testing/file_check.h>
 
 #include <torch/csrc/jit/constants.h>
@@ -211,7 +211,6 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
   }
 };
 
-
 struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
   explicit ConstantPythonTupleValue(py::object tup)
       : PythonValue(std::move(tup)) {}
@@ -854,18 +853,20 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "_copy_method",
           [](std::shared_ptr<Module> m,
-            std::string name,
-            std::vector<std::tuple<std::shared_ptr<Module>, std::string>> params,
-            std::shared_ptr<Module> orig) {
-              std::vector<at::Tensor*> member_inputs;
-              for (auto& p : params) {
-                NamedParameter* np = std::get<0>(p)->find_parameter(std::get<1>(p));
-                AT_ASSERT(np != nullptr);
-                member_inputs.push_back(np->slot());
-              }
-
-              Method* orig_method = orig->find_method(name);
-              m->create_method(name, orig_method->graph()->copy(), member_inputs);
+             std::string name,
+             std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
+                 params,
+             std::shared_ptr<Module> orig) {
+            std::vector<at::Tensor*> member_inputs;
+            for (auto& p : params) {
+              NamedParameter* np =
+                  std::get<0>(p)->find_parameter(std::get<1>(p));
+              AT_ASSERT(np != nullptr);
+              member_inputs.push_back(np->slot());
+            }
+
+            Method* orig_method = orig->find_method(name);
+            m->create_method(name, orig_method->graph()->copy(), member_inputs);
           });
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp
index b103459..282d2fc 100644
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/schema_matching.h>
 #include <torch/csrc/jit/script/sugared_value.h>
 #include <torch/csrc/jit/script/tree_views.h>
 #include <torch/csrc/jit/script/type_parser.h>
@@ -56,6 +57,16 @@ builtin_cast_methods() {
   return builtin_cast_methods;
 }
 
+std::shared_ptr<SugaredValue> BuiltinFunction::call(
+    const SourceRange& loc,
+    Method& m,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    size_t n_binders) {
+  return std::make_shared<SimpleValue>(
+      emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true));
+}
+
 // support syntax sugar for x.foo(y, z) by allowing x.foo to return a
 // callable value that will resolve to foo(x, y, z) when called.
 std::shared_ptr<SugaredValue> SimpleValue::attr(
@@ -95,7 +106,9 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
     auto names = tuple_type->names();
     for (int i = 0; i < names.size(); i++) {
       if (names[i] == field) {
-        auto r = m.graph()->insertNode(m.graph()->createTupleIndex(getValue(), i))->output();
+        auto r = m.graph()
+                     ->insertNode(m.graph()->createTupleIndex(getValue(), i))
+                     ->output();
         return std::make_shared<SimpleValue>(r);
       }
     }
diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h
index d41cc62..c34aaad 100644
@@ -875,7 +875,8 @@ struct DictLiteral : public Expr {
       const SourceRange& range,
       const List<Expr>& keys,
       const List<Expr>& values) {
-    return DictLiteral(Compound::create(TK_DICT_LITERAL, range, {keys, values}));
+    return DictLiteral(
+        Compound::create(TK_DICT_LITERAL, range, {keys, values}));
   }
 };