Revert D14901379: [jit] Add options to Operator to enable registration of alias analy...
authorMichael Suo <suo@fb.com>
Wed, 17 Apr 2019 23:48:28 +0000 (16:48 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 17 Apr 2019 23:56:14 +0000 (16:56 -0700)
Differential Revision:
D14901379

Original commit changeset: d92a497e280f

fbshipit-source-id: 51d31491ab90907a6c95af5d8a59dff5e5ed36a4

test/cpp/jit/test.cpp
test/cpp/jit/test_alias_analysis.h
torch/csrc/jit/custom_operator.h
torch/csrc/jit/operator.cpp
torch/csrc/jit/operator.h
torch/csrc/jit/operator_options.h [deleted file]
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h

index 090591d..0975f14 100644 (file)
@@ -56,7 +56,6 @@ namespace jit {
   _(TopologicalMove)               \
   _(SubgraphUtils)                 \
   _(AliasAnalysis)                 \
-  _(AliasRegistration)             \
   _(WriteTracking)                 \
   _(Wildcards)                     \
   _(MemoryDAG)                     \
index 46bb934..2b61cfc 100644 (file)
@@ -605,37 +605,5 @@ void testMemoryDAG() {
     ASSERT_FALSE(t.mayAlias(foo, baz));
   }
 }
-
-void testAliasRegistration() {
-  {
-    auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::DEFAULT);
-    RegisterOperators reg({createOperator(
-        "foo::rand",
-        [](at::Tensor) -> at::Tensor {
-          return at::rand({2, 2});
-        },
-        opts)});
-    const auto rand_op = Symbol::fromQualString("foo::rand");
-    auto graph = std::make_shared<Graph>();
-    auto a = graph->addInput();
-    auto b = graph->insert(rand_op, {a});
-    AliasDb aliasDb(graph);
-    // Conservatively we assume there is a reference
-    ASSERT_TRUE(aliasDb.mayAlias(a, b));
-  }
-  {
-    auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::PURE);
-    RegisterOperators reg({createOperator(
-        "foo::pure", [](at::Tensor t) -> at::Tensor { return t * 2; }, opts)});
-    const auto rand_op = Symbol::fromQualString("foo::pure");
-    auto graph = std::make_shared<Graph>();
-    auto a = graph->addInput();
-    auto b = graph->insert(rand_op, {a});
-    AliasDb aliasDb(graph);
-    // PURE means there is no reference
-    ASSERT_FALSE(aliasDb.mayAlias(a, b));
-  }
-}
-
 } // namespace jit
 } // namespace torch
index 87c2f74..73b7cbb 100644 (file)
@@ -170,8 +170,7 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
 template <typename Implementation>
 Operator createOperator(
     const std::string& schemaOrName,
-    Implementation&& implementation,
-    OperatorOptions options = OperatorOptions()) {
+    Implementation&& implementation) {
   using Traits = c10::guts::infer_function_traits_t<Implementation>;
   using ArgumentTypes =
       c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
@@ -202,20 +201,16 @@ Operator createOperator(
         name.ns().toUnqualString());
   }
 
-  return Operator(
-      schema,
-      [implementation, schema](Stack& stack) {
-        ArgumentTuple tuple;
-        torch::jit::detail::callOperatorWithTuple(
-            schema,
-            std::move(
-                implementation), // NOLINT(bugprone-move-forwarding-reference)
-            stack,
-            tuple,
-            typename MakeIndices<kNumberOfArguments>::indices{});
-        return 0;
-      },
-      std::move(options));
+  return Operator(schema, [implementation, schema](Stack& stack) {
+    ArgumentTuple tuple;
+    torch::jit::detail::callOperatorWithTuple(
+        schema,
+        std::move(implementation), // NOLINT(bugprone-move-forwarding-reference)
+        stack,
+        tuple,
+        typename MakeIndices<kNumberOfArguments>::indices{});
+    return 0;
+  });
 }
 
 /// Registration class for new operators. Effectively calls
@@ -245,10 +240,9 @@ struct TORCH_API RegisterOperators {
   template <typename Implementation>
   RegisterOperators& op(
       const std::string& name,
-      Implementation&& implementation,
-      OperatorOptions options = OperatorOptions()) {
-    registerOperator(createOperator(
-        name, std::forward<Implementation>(implementation), options));
+      Implementation&& implementation) {
+    registerOperator(
+        createOperator(name, std::forward<Implementation>(implementation)));
     return *this;
   }
 };
index 89bac0c..f8fbdc9 100644 (file)
@@ -379,8 +379,7 @@ void registerOperator(Operator&& op) {
           op.schema().name(),
           ". File a bug to add a case for this operator.\n");
     }
-    if (!aliasAnalysisHasSpecialCaseFor(s) &&
-        op.options().aliasAnalysis() == AliasAnalysisKind::DEFAULT) {
+    if (!aliasAnalysisHasSpecialCaseFor(s)) {
       AT_ERROR(
           "Missing special case in alias analysis for non-schematized"
           " operator ",
index fed1c1b..491d1bc 100644 (file)
@@ -3,10 +3,9 @@
 // it now to implement correct semantic checking for script
 #pragma once
 
-#include <ATen/core/stack.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/operator_options.h>
+#include <ATen/core/stack.h>
 
 #include <ATen/ATen.h>
 #include <ATen/core/function_schema.h>
@@ -59,31 +58,19 @@ using OperationCreator = std::function<Operation(const Node*)>;
  */
 
 struct TORCH_API Operator {
-  Operator(
-      FunctionSchema schema,
-      OperationCreator op_creator,
-      OperatorOptions options = OperatorOptions())
+  Operator(FunctionSchema schema, OperationCreator op_creator)
       : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
-        op_creator_(std::move(op_creator)),
-        options_(std::move(options)) {}
+        op_creator_(std::move(op_creator)) {}
 
-  Operator(
-      const std::string& schema,
-      OperationCreator op_creator,
-      OperatorOptions options = OperatorOptions())
-      : schema_string_(schema),
-        op_creator_(std::move(op_creator)),
-        options_(std::move(options)) {}
+  Operator(const std::string& schema, OperationCreator op_creator)
+      : schema_string_(schema), op_creator_(std::move(op_creator)) {}
 
   // Helper constructor to register `op` to run
   // run for _every_ IR Node where n.kind() == name, regardless of arguments.
   // This is accomplished by marking the schema varargs and having no required
   // arguments. This is used for things like prim::While or prim::If that can
   // take a number of different valid input types and lengths.
-  Operator(
-      Symbol name,
-      OperationCreator op_creator,
-      OperatorOptions options = OperatorOptions())
+  Operator(Symbol name, OperationCreator op_creator)
       : Operator(
             FunctionSchema(
                 name,
@@ -92,24 +79,15 @@ struct TORCH_API Operator {
                 {},
                 /*is_vararg*/ true,
                 /*is_varret*/ true),
-            std::move(op_creator),
-            std::move(options)) {}
+            std::move(op_creator)) {}
 
-  Operator(
-      FunctionSchema schema,
-      Operation op,
-      OperatorOptions options = OperatorOptions())
+  Operator(FunctionSchema schema, Operation op)
       : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
-        op_(std::make_shared<Operation>(std::move(op))),
-        options_(std::move(options)) {}
+        op_(std::make_shared<Operation>(std::move(op))) {}
 
-  Operator(
-      const std::string& schema,
-      Operation op,
-      OperatorOptions options = OperatorOptions())
+  Operator(const std::string& schema, Operation op)
       : schema_string_(schema),
-        op_(std::make_shared<Operation>(std::move(op))),
-        options_(std::move(options)) {}
+        op_(std::make_shared<Operation>(std::move(op))) {}
 
   bool matches(const Node* node) const;
 
@@ -132,10 +110,6 @@ struct TORCH_API Operator {
     return *schema_;
   }
 
-  const OperatorOptions& options() const {
-    return options_;
-  }
-
  private:
   mutable c10::optional<std::string> schema_string_;
   // cannot use c10::optional because windows has issues that require an
@@ -147,7 +121,6 @@ struct TORCH_API Operator {
   // NB: std::function has a default state (where it == nullptr).
   std::shared_ptr<Operation> op_;
   OperationCreator op_creator_;
-  OperatorOptions options_;
 };
 
 TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
diff --git a/torch/csrc/jit/operator_options.h b/torch/csrc/jit/operator_options.h
deleted file mode 100644 (file)
index 3e06ade..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/passes/alias_analysis.h>
-
-namespace torch {
-namespace jit {
-
-enum class AliasAnalysisKind {
-  DEFAULT, // The most conservative alias analysis type, assumes side-effects
-  PURE
-};
-
-struct OperatorOptions {
-  OperatorOptions(){};
-
-  OperatorOptions aliasAnalysis(AliasAnalysisKind aak) const noexcept {
-    OperatorOptions r = *this;
-    r.aliasAnalysisKind_ = aak;
-    return r;
-  }
-
-  const AliasAnalysisKind& aliasAnalysis() const {
-    return aliasAnalysisKind_;
-  }
-  AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::DEFAULT;
-};
-
-} // namespace jit
-} // namespace torch
index 59895e6..10f2414 100644 (file)
@@ -1,6 +1,5 @@
 #include <torch/csrc/jit/passes/alias_analysis.h>
 
-#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/utils/memory.h>
 
@@ -347,21 +346,6 @@ void AliasDb::analyze(Node* node) {
   }
 }
 
-// Returns true if analysis was run using
-// the registered analyzer.
-bool AliasDb::tryRegisteredAnalysis(Node* node) {
-  const Operator& op = getOperatorFor(node);
-  auto analysis = op.options().aliasAnalysis();
-  switch (analysis) {
-    case AliasAnalysisKind::PURE:
-      analyzeCreator(node);
-      return true;
-    case AliasAnalysisKind::DEFAULT:
-      return false;
-  }
-  return false;
-}
-
 // The basic strategy is:
 //   1. Retrieve alias information for every input.
 //   2. Use the node's schema's alias annotations to propgagate alias/write
@@ -425,9 +409,6 @@ void AliasDb::analyzeImpl(Node* node) {
       // These ops do nothing
       return;
     default:
-      if (tryRegisteredAnalysis(node)) {
-        return;
-      }
       AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
   }
 
index 2ae8521..2574e61 100644 (file)
@@ -183,7 +183,6 @@ class AliasDb {
   void analyzeWait(Node* node);
   void analyzeSetAttr(Node* node);
   void analyzeCustomOp(Node* node);
-  bool tryRegisteredAnalysis(Node* node);
 
   /**
    * Alias manipulation methods