fix side-effects and aliasing for custom ops (#18711)
authorMichael Suo <suo@fb.com>
Fri, 5 Apr 2019 17:40:19 +0000 (10:40 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 17:48:14 +0000 (10:48 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18711
ghimport-source-id: c9caedc0660b2b7ba3730cd0e1a2e0e9c3cf422b

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18711 [jit] fix side-effects and aliasing for custom ops**

Previously we didn't track aliasing, mutation, or side effects for
custom ops. This PR adds in guards with the most conservative
assumptions possible: the op will
1) have side effects,
2) write to everything
3) produce a wildcard.

In order to tell whether a given operator is a custom op, this PR introduces
the concept of a "reserved" namespace (basically all our builtin namespaces).
Custom ops live in non-reserved namespaces, so a check on the namespace
is sufficient to tell whether a schema/node is "custom" or not.

This is just to get things correct for now. Follow-ups to this:
- Users should be able to specify aliasing/mutability without having to learn
the whole alias annotation schema.
- Relax assumptions a bit. In particular outputs can only alias input tensors,
they don't have to be wildcards.

Fixes #18490

Differential Revision: D14730978

fbshipit-source-id: 540b47a24ccf24145051609bdcc99c97e46e0fe0

aten/src/ATen/core/function_schema.h
test/cpp/jit/test.cpp
test/cpp/jit/test_alias_analysis.h
test/cpp/jit/test_custom_operators.h
test/cpp/jit/test_misc.h
torch/csrc/jit/custom_operator.h
torch/csrc/jit/ir.cpp
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h

index da1ec34..c5185c1 100644 (file)
@@ -146,12 +146,17 @@ public:
     return is_varret_;
   }
   bool is_mutable() const {
-    return std::any_of(
-        arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
-          const auto& aliasInfo = arg.alias_info();
-          return aliasInfo && aliasInfo.value().isWrite();
-        });
+    // see [custom operator aliasing]
+    const auto kind = Symbol::fromQualString(name_);
+    const auto is_custom_op = !kind.is_aten() && !kind.is_prim();
+    return is_custom_op ||
+        std::any_of(
+            arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
+              const auto& aliasInfo = arg.alias_info();
+              return aliasInfo && aliasInfo.value().isWrite();
+            });
   }
+
   c10::optional<int> argumentIndexWithName(const std::string& name) const {
     for(size_t i = 0; i < arguments().size(); ++i) {
       if(name == arguments()[i].name())
index cc5b26f..7145c1a 100644 (file)
@@ -40,6 +40,7 @@ namespace jit {
   _(ControlFlow)                   \
   _(CreateAutodiffSubgraphs)       \
   _(CustomOperators)               \
+  _(CustomOperatorAliasing)        \
   _(Differentiate)                 \
   _(DifferentiateWithRequiresGrad) \
   _(DynamicDAG)                    \
index 69d0e77..2b61cfc 100644 (file)
@@ -454,11 +454,11 @@ void testAliasAnalysis() {
 }
 
 void testWriteTracking() {
-  RegisterOperators reg({createOperator(
-      "foo::creates_alias(Tensor(a) x) -> Tensor(a)",
-      [](at::Tensor a) { return a; })});
-  const auto creates_alias = Symbol::fromQualString("foo::creates_alias");
-  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
+  RegisterOperators reg(
+      {Operator("prim::creates_alias(Tensor(a) x) -> Tensor(a)", [](Stack& s) {
+        return 0;
+      })});
+  const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
   {
     auto graph = std::make_shared<Graph>();
     auto a = graph->addInput();
@@ -491,14 +491,16 @@ void testWriteTracking() {
 }
 
 void testWildcards() {
-  RegisterOperators reg({createOperator(
-                             "foo::returns_wildcard(Tensor a) -> Tensor(*)",
-                             [](at::Tensor a) { return a; }),
-                         createOperator(
-                             "foo::writes(Tensor(z!) a) -> Tensor(a)",
-                             [](at::Tensor a) { return a; })});
-  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
-  const auto writes = Symbol::fromQualString("foo::writes");
+  RegisterOperators reg(
+      {Operator(
+           "prim::returns_wildcard(Tensor a) -> Tensor(*)",
+           [](Stack& stack) { return 0; }),
+       Operator("prim::writes(Tensor(z!) a) -> Tensor(a)", [](Stack& stack) {
+         return 0;
+       })});
+  const auto returns_wildcard =
+      Symbol::fromQualString("prim::returns_wildcard");
+  const auto writes = Symbol::fromQualString("prim::writes");
 
   auto graph = std::make_shared<Graph>();
   const auto a = graph->addInput();
index 87e72f4..bc0e071 100644 (file)
@@ -4,6 +4,9 @@
 #include "test/cpp/jit/test_utils.h"
 
 #include "torch/csrc/jit/custom_operator.h"
+#include "torch/csrc/jit/irparser.h"
+#include "torch/csrc/jit/passes/alias_analysis.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 
 namespace torch {
 namespace jit {
@@ -206,6 +209,58 @@ void testCustomOperators() {
 
     tracer::abandon();
   }
+  {
+    // Try to create an op using a reserved namespace
+    ASSERT_THROWS_WITH(
+        createOperator(
+            "aten::op(float[] f) -> int",
+            [](const std::vector<double>& f) -> int64_t { return f.size(); }),
+        "Tried to register a custom operator to a reserved namespace");
+  }
+}
+
+void testCustomOperatorAliasing() {
+  RegisterOperators reg({createOperator(
+      "foo::aliasing", [](at::Tensor a, at::Tensor b) -> at::Tensor {
+        a.add_(b);
+        return a;
+      })});
+  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::aliasing"));
+
+  {
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(
+        R"IR(
+graph(%x: Tensor, %y: Tensor):
+  %ret : Tensor = foo::aliasing(%x, %y)
+  return (%ret)
+  )IR",
+        graph.get());
+
+    auto opNode = *graph->block()->nodes().begin();
+
+    AliasDb aliasDb(graph);
+    for (const auto input : opNode->inputs()) {
+      // The custom op writes to all its inputs
+      ASSERT_TRUE(aliasDb.writesToAlias(opNode, {input}));
+      // The output should be a wildcard and thus alias all inputs
+      ASSERT_TRUE(aliasDb.mayAlias(opNode->output(), input));
+    }
+  }
+  {
+    // DCE should not remove a custom op
+    auto graph = std::make_shared<Graph>();
+    const auto text = R"IR(
+graph(%x: Tensor, %y: Tensor):
+  # CHECK: foo::aliasing
+  %ret : Tensor = foo::aliasing(%x, %y)
+  return (%x)
+  )IR";
+    script::parseIR(text, graph.get());
+    EliminateDeadCode(graph);
+
+    testing::FileCheck().run(text, *graph);
+  }
 }
 } // namespace test
 } // namespace jit
index 7a314ad..8b93c21 100644 (file)
@@ -672,7 +672,7 @@ void testAutogradProfiler() {
 void testNoneSchemaMatch() {
   RegisterOperators reg({
       Operator(
-          "test::test_none() -> int?",
+          "prim::test_none() -> int?",
           [](const Node* node) {
             return [](Stack& stack) {
               push(stack, IValue());
@@ -680,7 +680,7 @@ void testNoneSchemaMatch() {
             };
           }),
       Operator(
-          "test::is_none(int? a) -> bool",
+          "prim::is_none(int? a) -> bool",
           [](const Node* node) {
             return [](Stack& stack) {
               IValue a = pop(stack);
@@ -700,8 +700,8 @@ void testNoneSchemaMatch() {
 
   auto r = std::make_shared<Graph>();
   auto& g = *r;
-  auto opt_int = g.insert(Symbol::fromQualString("test::test_none"), {});
-  auto out_bool = g.insert(Symbol::fromQualString("test::is_none"), {opt_int});
+  auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
+  auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
   g.registerOutput(out_bool);
   ConstantPropagation(r);
 
index 22b206a..73b7cbb 100644 (file)
@@ -181,6 +181,26 @@ Operator createOperator(
 
   auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);
 
+  // [custom operator aliasing] Currently, we have no way for the user to
+  // specify the alias annotations for a custom op. Therefore, we have to:
+  //   1. Assume that custom ops will mutate all inputs, have side effects, and
+  //      produce wildcard outputs.
+  //   2. Have some way of distinguishing between a custom op and a builtin op
+  //      so that we can apply the above rule.
+  // We do this by manually whitelisting "aten" and "prim" namespaces as
+  // builtins.
+  //
+  // We don't want to preserve this distinction between custom/builtin ops, as
+  // it is fragile and hard to maintain. When we provide a way for op
+  // registration to specify alias annotations, we should fix up builtins to
+  // use that and remove all references to this note.
+  Symbol name = Symbol::fromQualString(schema.name());
+  if (name.is_aten() || name.is_prim() || name.is_onnx()) {
+    AT_ERROR(
+        "Tried to register a custom operator to a reserved namespace: ",
+        name.ns().toUnqualString());
+  }
+
   return Operator(schema, [implementation, schema](Stack& stack) {
     ArgumentTuple tuple;
     torch::jit::detail::callOperatorWithTuple(
index b0ef049..9d9a768 100644 (file)
@@ -846,10 +846,17 @@ bool Node::hasSideEffects() const {
     case prim::SetAttr:
     case aten::warn:
     case prim::AddStatValue:
-     case prim::TimePoint:
+    case prim::TimePoint:
       return true;
   }
-  return false;
+  // All other builtin ops are known to be safe.
+  // see [custom operator aliasing]
+  if (kind_.is_aten() || kind_.is_prim() || kind_.is_onnx()) {
+    return false;
+  }
+
+  // Custom ops may have arbitrary side effects
+  return true;
 }
 
 // Assign this node a topological position, to facilitate fast isBefore() and
index 7220632..10f2414 100644 (file)
@@ -428,6 +428,11 @@ void AliasDb::analyzeImpl(Node* node) {
     }
   }
 
+  // see [custom operator aliasing]
+  if (!node->kind().is_aten() && !node->kind().is_prim()) {
+    return analyzeCustomOp(node);
+  }
+
   // Bind formal alias annotation to actual alias sets
   std::unordered_map<Symbol, Value*> formalToActual;
   for (size_t i = 0; i < schema.arguments().size(); i++) {
@@ -516,6 +521,11 @@ void AliasDb::analyzeImpl(Node* node) {
 }
 // Register the fact that `n` writes to `v`.
 void AliasDb::registerWrite(const Value* v, Node* n) {
+  if (!shouldAnnotate(v)) {
+    // don't need to register a write if the value isn't mutable
+    return;
+  }
+
   numWrites_++;
 
   if (isWildcard(v)) {
@@ -682,6 +692,19 @@ void AliasDb::analyzeSetAttr(Node* node) {
   registerWrite(self, node);
 }
 
+// Custom ops may write to any input and produce wildcards
+void AliasDb::analyzeCustomOp(Node* node) {
+  for (const auto input : node->inputs()) {
+    registerWrite(input, node);
+  }
+
+  // TODO(suo): we can make the more refined assumption that outputs may only
+  // alias any input.
+  for (const auto output : node->outputs()) {
+    setWildcard(output);
+  }
+}
+
 // BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
 // This is an intermediate node used only in the graph fuser.
 void AliasDb::analyzeBroadcastingChunk(Node* node) {
index bed1d63..2574e61 100644 (file)
@@ -182,6 +182,7 @@ class AliasDb {
   void analyzeFork(Node* node);
   void analyzeWait(Node* node);
   void analyzeSetAttr(Node* node);
+  void analyzeCustomOp(Node* node);
 
   /**
    * Alias manipulation methods