Add options to Operator to enable registration of alias analysis passes (#18589)
authorBram Wasti <bwasti@fb.com>
Wed, 17 Apr 2019 20:07:12 +0000 (13:07 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 17 Apr 2019 20:14:55 +0000 (13:14 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18589
ghimport-source-id: dab203f6be13bf41963848f5315235b6bbe45c08

Differential Revision: D14901379

Pulled By: bwasti

fbshipit-source-id: d92a497e280f1b0a63b11a9fd8ae9b48bf52e6bf

test/cpp/jit/test.cpp
test/cpp/jit/test_alias_analysis.h
torch/csrc/jit/custom_operator.h
torch/csrc/jit/operator.cpp
torch/csrc/jit/operator.h
torch/csrc/jit/operator_options.h [new file with mode: 0644]
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h

index 0975f14..090591d 100644 (file)
@@ -56,6 +56,7 @@ namespace jit {
   _(TopologicalMove)               \
   _(SubgraphUtils)                 \
   _(AliasAnalysis)                 \
+  _(AliasRegistration)             \
   _(WriteTracking)                 \
   _(Wildcards)                     \
   _(MemoryDAG)                     \
index 2b61cfc..46bb934 100644 (file)
@@ -605,5 +605,37 @@ void testMemoryDAG() {
     ASSERT_FALSE(t.mayAlias(foo, baz));
   }
 }
+
+void testAliasRegistration() {
+  {
+    auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::DEFAULT);
+    RegisterOperators reg({createOperator(
+        "foo::rand",
+        [](at::Tensor) -> at::Tensor {
+          return at::rand({2, 2});
+        },
+        opts)});
+    const auto rand_op = Symbol::fromQualString("foo::rand");
+    auto graph = std::make_shared<Graph>();
+    auto a = graph->addInput();
+    auto b = graph->insert(rand_op, {a});
+    AliasDb aliasDb(graph);
+    // Conservatively we assume there is a reference
+    ASSERT_TRUE(aliasDb.mayAlias(a, b));
+  }
+  {
+    auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::PURE);
+    RegisterOperators reg({createOperator(
+        "foo::pure", [](at::Tensor t) -> at::Tensor { return t * 2; }, opts)});
+    const auto rand_op = Symbol::fromQualString("foo::pure");
+    auto graph = std::make_shared<Graph>();
+    auto a = graph->addInput();
+    auto b = graph->insert(rand_op, {a});
+    AliasDb aliasDb(graph);
+    // PURE means there is no reference
+    ASSERT_FALSE(aliasDb.mayAlias(a, b));
+  }
+}
+
 } // namespace jit
 } // namespace torch
index 73b7cbb..87c2f74 100644 (file)
@@ -170,7 +170,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
 template <typename Implementation>
 Operator createOperator(
     const std::string& schemaOrName,
-    Implementation&& implementation) {
+    Implementation&& implementation,
+    OperatorOptions options = OperatorOptions()) {
   using Traits = c10::guts::infer_function_traits_t<Implementation>;
   using ArgumentTypes =
       c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
@@ -201,16 +202,20 @@ Operator createOperator(
         name.ns().toUnqualString());
   }
 
-  return Operator(schema, [implementation, schema](Stack& stack) {
-    ArgumentTuple tuple;
-    torch::jit::detail::callOperatorWithTuple(
-        schema,
-        std::move(implementation), // NOLINT(bugprone-move-forwarding-reference)
-        stack,
-        tuple,
-        typename MakeIndices<kNumberOfArguments>::indices{});
-    return 0;
-  });
+  return Operator(
+      schema,
+      [implementation, schema](Stack& stack) {
+        ArgumentTuple tuple;
+        torch::jit::detail::callOperatorWithTuple(
+            schema,
+            std::move(
+                implementation), // NOLINT(bugprone-move-forwarding-reference)
+            stack,
+            tuple,
+            typename MakeIndices<kNumberOfArguments>::indices{});
+        return 0;
+      },
+      std::move(options));
 }
 
 /// Registration class for new operators. Effectively calls
@@ -240,9 +245,10 @@ struct TORCH_API RegisterOperators {
   template <typename Implementation>
   RegisterOperators& op(
       const std::string& name,
-      Implementation&& implementation) {
-    registerOperator(
-        createOperator(name, std::forward<Implementation>(implementation)));
+      Implementation&& implementation,
+      OperatorOptions options = OperatorOptions()) {
+    registerOperator(createOperator(
+        name, std::forward<Implementation>(implementation), options));
     return *this;
   }
 };
index f8fbdc9..89bac0c 100644 (file)
@@ -379,7 +379,8 @@ void registerOperator(Operator&& op) {
           op.schema().name(),
           ". File a bug to add a case for this operator.\n");
     }
-    if (!aliasAnalysisHasSpecialCaseFor(s)) {
+    if (!aliasAnalysisHasSpecialCaseFor(s) &&
+        op.options().aliasAnalysis() == AliasAnalysisKind::DEFAULT) {
       AT_ERROR(
           "Missing special case in alias analysis for non-schematized"
           " operator ",
index 491d1bc..fed1c1b 100644 (file)
@@ -3,9 +3,10 @@
 // it now to implement correct semantic checking for script
 #pragma once
 
+#include <ATen/core/stack.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/ir.h>
-#include <ATen/core/stack.h>
+#include <torch/csrc/jit/operator_options.h>
 
 #include <ATen/ATen.h>
 #include <ATen/core/function_schema.h>
@@ -58,19 +59,31 @@ using OperationCreator = std::function<Operation(const Node*)>;
  */
 
 struct TORCH_API Operator {
-  Operator(FunctionSchema schema, OperationCreator op_creator)
+  Operator(
+      FunctionSchema schema,
+      OperationCreator op_creator,
+      OperatorOptions options = OperatorOptions())
       : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
-        op_creator_(std::move(op_creator)) {}
+        op_creator_(std::move(op_creator)),
+        options_(std::move(options)) {}
 
-  Operator(const std::string& schema, OperationCreator op_creator)
-      : schema_string_(schema), op_creator_(std::move(op_creator)) {}
+  Operator(
+      const std::string& schema,
+      OperationCreator op_creator,
+      OperatorOptions options = OperatorOptions())
+      : schema_string_(schema),
+        op_creator_(std::move(op_creator)),
+        options_(std::move(options)) {}
 
   // Helper constructor to register `op` to run
   // run for _every_ IR Node where n.kind() == name, regardless of arguments.
   // This is accomplished by marking the schema varargs and having no required
   // arguments. This is used for things like prim::While or prim::If that can
   // take a number of different valid input types and lengths.
-  Operator(Symbol name, OperationCreator op_creator)
+  Operator(
+      Symbol name,
+      OperationCreator op_creator,
+      OperatorOptions options = OperatorOptions())
       : Operator(
             FunctionSchema(
                 name,
@@ -79,15 +92,24 @@ struct TORCH_API Operator {
                 {},
                 /*is_vararg*/ true,
                 /*is_varret*/ true),
-            std::move(op_creator)) {}
+            std::move(op_creator),
+            std::move(options)) {}
 
-  Operator(FunctionSchema schema, Operation op)
+  Operator(
+      FunctionSchema schema,
+      Operation op,
+      OperatorOptions options = OperatorOptions())
       : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
-        op_(std::make_shared<Operation>(std::move(op))) {}
+        op_(std::make_shared<Operation>(std::move(op))),
+        options_(std::move(options)) {}
 
-  Operator(const std::string& schema, Operation op)
+  Operator(
+      const std::string& schema,
+      Operation op,
+      OperatorOptions options = OperatorOptions())
       : schema_string_(schema),
-        op_(std::make_shared<Operation>(std::move(op))) {}
+        op_(std::make_shared<Operation>(std::move(op))),
+        options_(std::move(options)) {}
 
   bool matches(const Node* node) const;
 
@@ -110,6 +132,10 @@ struct TORCH_API Operator {
     return *schema_;
   }
 
+  const OperatorOptions& options() const {
+    return options_;
+  }
+
  private:
   mutable c10::optional<std::string> schema_string_;
   // cannot use c10::optional because windows has issues that require an
@@ -121,6 +147,7 @@ struct TORCH_API Operator {
   // NB: std::function has a default state (where it == nullptr).
   std::shared_ptr<Operation> op_;
   OperationCreator op_creator_;
+  OperatorOptions options_;
 };
 
 TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
diff --git a/torch/csrc/jit/operator_options.h b/torch/csrc/jit/operator_options.h
new file mode 100644 (file)
index 0000000..3e06ade
--- /dev/null
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <torch/csrc/jit/passes/alias_analysis.h>
+
+namespace torch {
+namespace jit {
+
+enum class AliasAnalysisKind {
+  DEFAULT, // The most conservative alias analysis type, assumes side-effects
+  PURE
+};
+
+struct OperatorOptions {
+  OperatorOptions(){};
+
+  OperatorOptions aliasAnalysis(AliasAnalysisKind aak) const noexcept {
+    OperatorOptions r = *this;
+    r.aliasAnalysisKind_ = aak;
+    return r;
+  }
+
+  const AliasAnalysisKind& aliasAnalysis() const {
+    return aliasAnalysisKind_;
+  }
+  AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::DEFAULT;
+};
+
+} // namespace jit
+} // namespace torch
index 10f2414..59895e6 100644 (file)
@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/passes/alias_analysis.h>
 
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/utils/memory.h>
 
@@ -346,6 +347,21 @@ void AliasDb::analyze(Node* node) {
   }
 }
 
+// Returns true if analysis was run using
+// the registered analyzer.
+bool AliasDb::tryRegisteredAnalysis(Node* node) {
+  const Operator& op = getOperatorFor(node);
+  auto analysis = op.options().aliasAnalysis();
+  switch (analysis) {
+    case AliasAnalysisKind::PURE:
+      analyzeCreator(node);
+      return true;
+    case AliasAnalysisKind::DEFAULT:
+      return false;
+  }
+  return false;
+}
+
 // The basic strategy is:
 //   1. Retrieve alias information for every input.
 //   2. Use the node's schema's alias annotations to propgagate alias/write
@@ -409,6 +425,9 @@ void AliasDb::analyzeImpl(Node* node) {
       // These ops do nothing
       return;
     default:
+      if (tryRegisteredAnalysis(node)) {
+        return;
+      }
       AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
   }
 
index 2574e61..2ae8521 100644 (file)
@@ -183,6 +183,7 @@ class AliasDb {
   void analyzeWait(Node* node);
   void analyzeSetAttr(Node* node);
   void analyzeCustomOp(Node* node);
+  bool tryRegisteredAnalysis(Node* node);
 
   /**
    * Alias manipulation methods