From 501d346da83f970c9c1bc491e7f797bf3a889278 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Wed, 20 Feb 2019 18:27:31 -0800 Subject: [PATCH] batched cleanups (#17288) Summary: Bunch of random stuff I came across while doing UDT stuff. Putting in a separate PR to avoid noise - fix up the alias analysis list ops to include fork/wait - improve dump() for aliasDb to print writes - Move BuiltinFunction::call() to sugaredvalue with the rest of the methods - formatting and includes Pull Request resolved: https://github.com/pytorch/pytorch/pull/17288 Differential Revision: D14147105 Pulled By: suo fbshipit-source-id: 62e2a922a1726b684347365dc42c72188f154e9c --- aten/src/ATen/core/ivalue.h | 3 +- aten/src/ATen/core/jit_type.h | 8 +-- torch/csrc/jit/passes/alias_analysis.cpp | 79 ++++++++++++++------------- torch/csrc/jit/passes/python_print.cpp | 2 +- torch/csrc/jit/passes/utils/alias_tracker.cpp | 13 +++++ torch/csrc/jit/script/compiler.cpp | 13 +---- torch/csrc/jit/script/init.cpp | 29 +++++----- torch/csrc/jit/script/sugared_value.cpp | 15 ++++- torch/csrc/jit/script/tree_views.h | 3 +- 9 files changed, 94 insertions(+), 71 deletions(-) diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index cb633ec..08b2298 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -3,10 +3,11 @@ #include #include +#include +#include #include #include #include -#include #include #include diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index f115b35..4db6f06 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1,10 +1,10 @@ #pragma once -#include -#include -#include -#include #include +#include +#include +#include +#include #include #include diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index e4b44e3..899bff7 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -239,6 +239,9 @@ void AliasDb::analyze(const std::shared_ptr& graph) { for (const auto& pr : tupleTypes) { makeAllAlias(pr.second, *aliasTracker_); } + for (const auto& pr : dictTypes) { + makeAllAlias(pr.second, *aliasTracker_); + } makeAllAlias(tensors, *aliasTracker_); analyze(graph->block()); @@ -321,7 +324,6 @@ void AliasDb::analyzeImpl(Node* node) { AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind())); } - const auto& schema = node->schema(); if (schema.is_vararg() || schema.is_varret()) { const auto hasMutableOutputs = std::any_of( @@ -973,47 +975,48 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { // WARNING: by adding a case to this list, you are asserting that you have // added a case for the unschematized node in AliasDb::analyze const static std::unordered_set handled = { - prim::If, - prim::Loop, - prim::FusionGroup, - prim::DifferentiableGraph, - prim::Constant, - prim::DictConstruct, - prim::ListConstruct, - prim::TupleConstruct, - prim::Undefined, - prim::FusedConcat, - prim::MMTreeReduce, - prim::MMBatchSide, - prim::None, - prim::BroadcastSizes, - prim::ChunkSizes, - prim::Function, - prim::TupleUnpack, - prim::TupleIndex, - prim::DictIndex, - prim::TupleSlice, - prim::ListUnpack, - prim::PythonOp, - prim::ConstantChunk, - prim::BroadcastingChunk, - aten::add, - aten::sub, - aten::mul, - aten::div, + prim::If, + prim::Loop, + prim::FusionGroup, + prim::DifferentiableGraph, + prim::Constant, + prim::DictConstruct, + prim::ListConstruct, + prim::TupleConstruct, + prim::Undefined, + prim::FusedConcat, + prim::MMTreeReduce, + prim::MMBatchSide, + prim::None, + prim::BroadcastSizes, + prim::ChunkSizes, + prim::Function, + prim::TupleUnpack, + prim::TupleIndex, + prim::DictIndex, + prim::TupleSlice, + prim::ListUnpack, + prim::PythonOp, + prim::ConstantChunk, + prim::BroadcastingChunk, + prim::fork, + aten::wait, + aten::add, + aten::sub, + aten::mul, + aten::div, }; // Operators that should not be used by alias analysis const static std::unordered_set purposefully_not_handled = { - prim::Print, - prim::Load, - prim::Store, - prim::Drop, - at::onnx::Reshape, - at::onnx::Shape, - prim::AnyDefined, - prim::AutogradAdd, - prim::fork, // TODO: fork aliasing / futures + prim::Print, + prim::Load, + prim::Store, + prim::Drop, + at::onnx::Reshape, + at::onnx::Shape, + prim::AnyDefined, + prim::AutogradAdd, }; return handled.count(symbol) || purposefully_not_handled.count(symbol); diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 7d249cd..4126bcc 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -1,9 +1,9 @@ -#include #include #include #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/passes/utils/alias_tracker.cpp b/torch/csrc/jit/passes/utils/alias_tracker.cpp index a231f01..e6cd358 100644 --- a/torch/csrc/jit/passes/utils/alias_tracker.cpp +++ b/torch/csrc/jit/passes/utils/alias_tracker.cpp @@ -180,6 +180,19 @@ void AliasTracker::dump() const { std::cout << wildcard->uniqueName() << ", "; } std::cout << "\n"; + + std::cout << "\n===4. Writes===\n"; + for (const auto& pr : writeIndex_) { + const auto node = pr.first; + const auto& values = pr.second; + std::cout << *node; + std::cout << " "; + for (const auto value : values) { + std::cout << value->uniqueName() << ", "; + } + std::cout << "\n"; + } + std::cout << "\n"; } std::unordered_set AliasTracker::Element:: diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 25f9607..35f0c23 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -6,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -488,16 +488,6 @@ static Value* ensureInt(const SourceRange& range, Value* v) { return v; } -std::shared_ptr BuiltinFunction::call( - const SourceRange& loc, - Method& m, - at::ArrayRef inputs, - at::ArrayRef attributes, - size_t n_binders) { - return std::make_shared( - emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true)); -} - inline bool isSupportedListElementType(const TypePtr& type) { return type->isSubtypeOf(TensorType::get()) || type->isSubtypeOf(NumberType::get()); @@ -533,6 +523,7 @@ struct to_ir { } method.setSchema(emitDef(def, self, graph->block())); + runCleanupPasses(graph); } diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index ca61850..f158c4c 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -5,9 +5,9 @@ #include #include #include +#include #include #include -#include #include #include @@ -211,7 +211,6 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { } }; - struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { explicit ConstantPythonTupleValue(py::object tup) : PythonValue(std::move(tup)) {} @@ -854,18 +853,20 @@ void initJitScriptBindings(PyObject* module) { .def( "_copy_method", [](std::shared_ptr m, - std::string name, - std::vector, std::string>> params, - std::shared_ptr orig) { - std::vector member_inputs; - for (auto& p : params) { - NamedParameter* np = std::get<0>(p)->find_parameter(std::get<1>(p)); - AT_ASSERT(np != nullptr); - member_inputs.push_back(np->slot()); - } - - Method* orig_method = orig->find_method(name); - m->create_method(name, orig_method->graph()->copy(), member_inputs); + std::string name, + std::vector, std::string>> + params, + std::shared_ptr orig) { + std::vector member_inputs; + for (auto& p : params) { + NamedParameter* np = + std::get<0>(p)->find_parameter(std::get<1>(p)); + AT_ASSERT(np != nullptr); + member_inputs.push_back(np->slot()); + } + + Method* orig_method = orig->find_method(name); + m->create_method(name, orig_method->graph()->copy(), member_inputs); }); py::class_(m, "ScriptMethod", py::dynamic_attr()) diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp index b103459..282d2fc 100644 --- a/torch/csrc/jit/script/sugared_value.cpp +++ b/torch/csrc/jit/script/sugared_value.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -56,6 +57,16 @@ builtin_cast_methods() { return builtin_cast_methods; } +std::shared_ptr BuiltinFunction::call( + const SourceRange& loc, + Method& m, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) { + return std::make_shared( + emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true)); +} + // support syntax sugar for x.foo(y, z) by allowing x.foo to return a // callable value that will resolve to foo(x, y, z) when called. std::shared_ptr SimpleValue::attr( @@ -95,7 +106,9 @@ std::shared_ptr SimpleValue::attr( auto names = tuple_type->names(); for (int i = 0; i < names.size(); i++) { if (names[i] == field) { - auto r = m.graph()->insertNode(m.graph()->createTupleIndex(getValue(), i))->output(); + auto r = m.graph() + ->insertNode(m.graph()->createTupleIndex(getValue(), i)) + ->output(); return std::make_shared(r); } } diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h index d41cc62..c34aaad 100644 --- a/torch/csrc/jit/script/tree_views.h +++ b/torch/csrc/jit/script/tree_views.h @@ -875,7 +875,8 @@ struct DictLiteral : public Expr { const SourceRange& range, const List& keys, const List& values) { - return DictLiteral(Compound::create(TK_DICT_LITERAL, range, {keys, values})); + return DictLiteral( + Compound::create(TK_DICT_LITERAL, range, {keys, values})); } }; -- 2.7.4