From 3a031c414a5c5ecbd8fc6320e25fc9373940346c Mon Sep 17 00:00:00 2001 From: Bram Wasti Date: Wed, 17 Apr 2019 13:07:12 -0700 Subject: [PATCH] Add options to Operator to enable registration of alias analysis passes (#18589) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18589 ghimport-source-id: dab203f6be13bf41963848f5315235b6bbe45c08 Differential Revision: D14901379 Pulled By: bwasti fbshipit-source-id: d92a497e280f1b0a63b11a9fd8ae9b48bf52e6bf --- test/cpp/jit/test.cpp | 1 + test/cpp/jit/test_alias_analysis.h | 32 +++++++++++++++++++++ torch/csrc/jit/custom_operator.h | 34 +++++++++++++--------- torch/csrc/jit/operator.cpp | 3 +- torch/csrc/jit/operator.h | 49 +++++++++++++++++++++++++------- torch/csrc/jit/operator_options.h | 29 +++++++++++++++++++ torch/csrc/jit/passes/alias_analysis.cpp | 19 +++++++++++++ torch/csrc/jit/passes/alias_analysis.h | 1 + 8 files changed, 142 insertions(+), 26 deletions(-) create mode 100644 torch/csrc/jit/operator_options.h diff --git a/test/cpp/jit/test.cpp b/test/cpp/jit/test.cpp index 0975f14..090591d 100644 --- a/test/cpp/jit/test.cpp +++ b/test/cpp/jit/test.cpp @@ -56,6 +56,7 @@ namespace jit { _(TopologicalMove) \ _(SubgraphUtils) \ _(AliasAnalysis) \ + _(AliasRegistration) \ _(WriteTracking) \ _(Wildcards) \ _(MemoryDAG) \ diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h index 2b61cfc..46bb934 100644 --- a/test/cpp/jit/test_alias_analysis.h +++ b/test/cpp/jit/test_alias_analysis.h @@ -605,5 +605,37 @@ void testMemoryDAG() { ASSERT_FALSE(t.mayAlias(foo, baz)); } } + +void testAliasRegistration() { + { + auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::DEFAULT); + RegisterOperators reg({createOperator( + "foo::rand", + [](at::Tensor) -> at::Tensor { + return at::rand({2, 2}); + }, + opts)}); + const auto rand_op = Symbol::fromQualString("foo::rand"); + auto graph = std::make_shared(); + auto a = graph->addInput(); + auto b = graph->insert(rand_op, {a}); + AliasDb aliasDb(graph); + // Conservatively we assume there is a reference + ASSERT_TRUE(aliasDb.mayAlias(a, b)); + } + { + auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::PURE); + RegisterOperators reg({createOperator( + "foo::pure", [](at::Tensor t) -> at::Tensor { return t * 2; }, opts)}); + const auto rand_op = Symbol::fromQualString("foo::pure"); + auto graph = std::make_shared(); + auto a = graph->addInput(); + auto b = graph->insert(rand_op, {a}); + AliasDb aliasDb(graph); + // PURE means there is no reference + ASSERT_FALSE(aliasDb.mayAlias(a, b)); + } +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h index 73b7cbb..87c2f74 100644 --- a/torch/csrc/jit/custom_operator.h +++ b/torch/csrc/jit/custom_operator.h @@ -170,7 +170,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) { template Operator createOperator( const std::string& schemaOrName, - Implementation&& implementation) { + Implementation&& implementation, + OperatorOptions options = OperatorOptions()) { using Traits = c10::guts::infer_function_traits_t; using ArgumentTypes = c10::guts::typelist::map_t; @@ -201,16 +202,20 @@ Operator createOperator( name.ns().toUnqualString()); } - return Operator(schema, [implementation, schema](Stack& stack) { - ArgumentTuple tuple; - torch::jit::detail::callOperatorWithTuple( - schema, - std::move(implementation), // NOLINT(bugprone-move-forwarding-reference) - stack, - tuple, - typename MakeIndices::indices{}); - return 0; - }); + return Operator( + schema, + [implementation, schema](Stack& stack) { + ArgumentTuple tuple; + torch::jit::detail::callOperatorWithTuple( + schema, + std::move( + implementation), // NOLINT(bugprone-move-forwarding-reference) + stack, + tuple, + typename MakeIndices::indices{}); + return 0; + }, + std::move(options)); } /// Registration class for new operators. Effectively calls @@ -240,9 +245,10 @@ struct TORCH_API RegisterOperators { template RegisterOperators& op( const std::string& name, - Implementation&& implementation) { - registerOperator( - createOperator(name, std::forward(implementation))); + Implementation&& implementation, + OperatorOptions options = OperatorOptions()) { + registerOperator(createOperator( + name, std::forward(implementation), options)); return *this; } }; diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index f8fbdc9..89bac0c 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -379,7 +379,8 @@ void registerOperator(Operator&& op) { op.schema().name(), ". File a bug to add a case for this operator.\n"); } - if (!aliasAnalysisHasSpecialCaseFor(s)) { + if (!aliasAnalysisHasSpecialCaseFor(s) && + op.options().aliasAnalysis() == AliasAnalysisKind::DEFAULT) { AT_ERROR( "Missing special case in alias analysis for non-schematized" " operator ", diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h index 491d1bc..fed1c1b 100644 --- a/torch/csrc/jit/operator.h +++ b/torch/csrc/jit/operator.h @@ -3,9 +3,10 @@ // it now to implement correct semantic checking for script #pragma once +#include #include #include -#include +#include #include #include @@ -58,19 +59,31 @@ using OperationCreator = std::function; */ struct TORCH_API Operator { - Operator(FunctionSchema schema, OperationCreator op_creator) + Operator( + FunctionSchema schema, + OperationCreator op_creator, + OperatorOptions options = OperatorOptions()) : schema_(std::make_shared(std::move(schema))), - op_creator_(std::move(op_creator)) {} + op_creator_(std::move(op_creator)), + options_(std::move(options)) {} - Operator(const std::string& schema, OperationCreator op_creator) - : schema_string_(schema), op_creator_(std::move(op_creator)) {} + Operator( + const std::string& schema, + OperationCreator op_creator, + OperatorOptions options = OperatorOptions()) + : schema_string_(schema), + op_creator_(std::move(op_creator)), + options_(std::move(options)) {} // Helper constructor to register `op` to run // run for _every_ IR Node where n.kind() == name, regardless of arguments. // This is accomplished by marking the schema varargs and having no required // arguments. This is used for things like prim::While or prim::If that can // take a number of different valid input types and lengths. - Operator(Symbol name, OperationCreator op_creator) + Operator( + Symbol name, + OperationCreator op_creator, + OperatorOptions options = OperatorOptions()) : Operator( FunctionSchema( name, @@ -79,15 +92,24 @@ struct TORCH_API Operator { {}, /*is_vararg*/ true, /*is_varret*/ true), - std::move(op_creator)) {} + std::move(op_creator), + std::move(options)) {} - Operator(FunctionSchema schema, Operation op) + Operator( + FunctionSchema schema, + Operation op, + OperatorOptions options = OperatorOptions()) : schema_(std::make_shared(std::move(schema))), - op_(std::make_shared(std::move(op))) {} + op_(std::make_shared(std::move(op))), + options_(std::move(options)) {} - Operator(const std::string& schema, Operation op) + Operator( + const std::string& schema, + Operation op, + OperatorOptions options = OperatorOptions()) : schema_string_(schema), - op_(std::make_shared(std::move(op))) {} + op_(std::make_shared(std::move(op))), + options_(std::move(options)) {} bool matches(const Node* node) const; @@ -110,6 +132,10 @@ struct TORCH_API Operator { return *schema_; } + const OperatorOptions& options() const { + return options_; + } + private: mutable c10::optional schema_string_; // cannot use c10::optional because windows has issues that require an @@ -121,6 +147,7 @@ struct TORCH_API Operator { // NB: std::function has a default state (where it == nullptr). std::shared_ptr op_; OperationCreator op_creator_; + OperatorOptions options_; }; TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); diff --git a/torch/csrc/jit/operator_options.h b/torch/csrc/jit/operator_options.h new file mode 100644 index 0000000..3e06ade --- /dev/null +++ b/torch/csrc/jit/operator_options.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +enum class AliasAnalysisKind { + DEFAULT, // The most conservative alias analysis type, assumes side-effects + PURE +}; + +struct OperatorOptions { + OperatorOptions(){}; + + OperatorOptions aliasAnalysis(AliasAnalysisKind aak) const noexcept { + OperatorOptions r = *this; + r.aliasAnalysisKind_ = aak; + return r; + } + + const AliasAnalysisKind& aliasAnalysis() const { + return aliasAnalysisKind_; + } + AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::DEFAULT; +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 10f2414..59895e6 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -1,5 +1,6 @@ #include +#include #include #include @@ -346,6 +347,21 @@ void AliasDb::analyze(Node* node) { } } +// Returns true if analysis was run using +// the registered analyzer. +bool AliasDb::tryRegisteredAnalysis(Node* node) { + const Operator& op = getOperatorFor(node); + auto analysis = op.options().aliasAnalysis(); + switch (analysis) { + case AliasAnalysisKind::PURE: + analyzeCreator(node); + return true; + case AliasAnalysisKind::DEFAULT: + return false; + } + return false; +} + // The basic strategy is: // 1. Retrieve alias information for every input. // 2. Use the node's schema's alias annotations to propgagate alias/write @@ -409,6 +425,9 @@ void AliasDb::analyzeImpl(Node* node) { // These ops do nothing return; default: + if (tryRegisteredAnalysis(node)) { + return; + } AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind())); } diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h index 2574e61..2ae8521 100644 --- a/torch/csrc/jit/passes/alias_analysis.h +++ b/torch/csrc/jit/passes/alias_analysis.h @@ -183,6 +183,7 @@ class AliasDb { void analyzeWait(Node* node); void analyzeSetAttr(Node* node); void analyzeCustomOp(Node* node); + bool tryRegisteredAnalysis(Node* node); /** * Alias manipulation methods -- 2.7.4