From fefa6d305ea3e820afe64cec015d2f6746d9ca88 Mon Sep 17 00:00:00 2001
From: Michael Suo
Date: Fri, 5 Apr 2019 10:40:19 -0700
Subject: [PATCH] fix side-effects and aliasing for custom ops (#18711)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18711
ghimport-source-id: c9caedc0660b2b7ba3730cd0e1a2e0e9c3cf422b

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18711 [jit] fix side-effects and aliasing for custom ops**

Previously we didn't track aliasing, mutation, or side effects for
custom ops. This PR adds guards with the most conservative assumptions
possible: the op will 1) have side effects, 2) write to everything, and
3) produce a wildcard.

In order to tell whether a given operator is a custom op, this PR
introduces the concept of a "reserved" namespace (basically all our
builtin namespaces). Custom ops live in non-reserved namespaces, so a
check on the namespace is sufficient to tell whether a schema/node is
"custom" or not.

This is just to get things correct for now. Follow-ups to this:
- Users should be able to specify aliasing/mutability without having to
  learn the whole alias annotation schema.
- Relax the assumptions a bit. In particular, outputs only need to
  alias input tensors; they don't have to be wildcards.

Fixes #18490

Differential Revision: D14730978

fbshipit-source-id: 540b47a24ccf24145051609bdcc99c97e46e0fe0
---
 aten/src/ATen/core/function_schema.h     | 15 ++++++---
 test/cpp/jit/test.cpp                    |  1 +
 test/cpp/jit/test_alias_analysis.h       | 28 ++++++++--------
 test/cpp/jit/test_custom_operators.h     | 55 ++++++++++++++++++++++++++++++++
 test/cpp/jit/test_misc.h                 |  8 ++---
 torch/csrc/jit/custom_operator.h         | 20 ++++++++++++
 torch/csrc/jit/ir.cpp                    | 11 +++++--
 torch/csrc/jit/passes/alias_analysis.cpp | 23 +++++++++++++
 torch/csrc/jit/passes/alias_analysis.h   |  1 +
 9 files changed, 138 insertions(+), 24 deletions(-)

diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h
index da1ec34..c5185c1 100644
--- a/aten/src/ATen/core/function_schema.h
+++ b/aten/src/ATen/core/function_schema.h
@@ -146,12 +146,17 @@ public:
     return is_varret_;
   }
   bool is_mutable() const {
-    return std::any_of(
-        arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
-          const auto& aliasInfo = arg.alias_info();
-          return aliasInfo && aliasInfo.value().isWrite();
-        });
+    // see [custom operator aliasing]
+    const auto kind = Symbol::fromQualString(name_);
+    const auto is_custom_op = !kind.is_aten() && !kind.is_prim();
+    return is_custom_op ||
+        std::any_of(
+            arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
+              const auto& aliasInfo = arg.alias_info();
+              return aliasInfo && aliasInfo.value().isWrite();
+            });
   }
+
   c10::optional<int> argumentIndexWithName(const std::string& name) const {
     for(size_t i = 0; i < arguments().size(); ++i) {
       if(name == arguments()[i].name())
diff --git a/test/cpp/jit/test.cpp b/test/cpp/jit/test.cpp
index cc5b26f..7145c1a 100644
--- a/test/cpp/jit/test.cpp
+++ b/test/cpp/jit/test.cpp
@@ -40,6 +40,7 @@ namespace jit {
   _(ControlFlow)                   \
   _(CreateAutodiffSubgraphs)       \
   _(CustomOperators)               \
+  _(CustomOperatorAliasing)        \
   _(Differentiate)                 \
   _(DifferentiateWithRequiresGrad) \
   _(DynamicDAG)                    \
diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h
index 69d0e77..2b61cfc 100644
--- a/test/cpp/jit/test_alias_analysis.h
+++ b/test/cpp/jit/test_alias_analysis.h
@@ -454,11 +454,11 @@ void testAliasAnalysis() {
 }
 
 void testWriteTracking() {
-  RegisterOperators reg({createOperator(
-      "foo::creates_alias(Tensor(a) x) -> Tensor(a)",
-      [](at::Tensor a) { return a; })});
-  const auto creates_alias = Symbol::fromQualString("foo::creates_alias");
-  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
+  RegisterOperators reg(
+      {Operator("prim::creates_alias(Tensor(a) x) -> Tensor(a)", [](Stack& s) {
+        return 0;
+      })});
+  const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
   {
     auto graph = std::make_shared<Graph>();
     auto a = graph->addInput();
@@ -491,14 +491,16 @@
 }
 
 void testWildcards() {
-  RegisterOperators reg({createOperator(
-      "foo::returns_wildcard(Tensor a) -> Tensor(*)",
-      [](at::Tensor a) { return a; }),
-      createOperator(
-          "foo::writes(Tensor(z!) a) -> Tensor(a)",
-          [](at::Tensor a) { return a; })});
-  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
-  const auto writes = Symbol::fromQualString("foo::writes");
+  RegisterOperators reg(
+      {Operator(
+           "prim::returns_wildcard(Tensor a) -> Tensor(*)",
+           [](Stack& stack) { return 0; }),
+       Operator("prim::writes(Tensor(z!) a) -> Tensor(a)", [](Stack& stack) {
+         return 0;
+       })});
+  const auto returns_wildcard =
+      Symbol::fromQualString("prim::returns_wildcard");
+  const auto writes = Symbol::fromQualString("prim::writes");
 
   auto graph = std::make_shared<Graph>();
   const auto a = graph->addInput();
diff --git a/test/cpp/jit/test_custom_operators.h b/test/cpp/jit/test_custom_operators.h
index 87e72f4..bc0e071 100644
--- a/test/cpp/jit/test_custom_operators.h
+++ b/test/cpp/jit/test_custom_operators.h
@@ -4,6 +4,9 @@
 #include "test/cpp/jit/test_utils.h"
 
 #include "torch/csrc/jit/custom_operator.h"
+#include "torch/csrc/jit/irparser.h"
+#include "torch/csrc/jit/passes/alias_analysis.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 
 namespace torch {
 namespace jit {
@@ -206,6 +209,58 @@ void testCustomOperators() {
     tracer::abandon();
   }
 
+  {
+    // Try to create an op using a reserved namespace
+    ASSERT_THROWS_WITH(
+        createOperator(
+            "aten::op(float[] f) -> int",
+            [](const std::vector<double>& f) -> int64_t { return f.size(); }),
+        "Tried to register a custom operator to a reserved namespace");
+  }
+}
+
+void testCustomOperatorAliasing() {
+  RegisterOperators reg({createOperator(
+      "foo::aliasing", [](at::Tensor a, at::Tensor b) -> at::Tensor {
+        a.add_(b);
+        return a;
+      })});
+  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::aliasing"));
+
+  {
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(
+        R"IR(
+graph(%x: Tensor, %y: Tensor):
+  %ret : Tensor = foo::aliasing(%x, %y)
+  return (%ret)
+  )IR",
+        graph.get());
+
+    auto opNode = *graph->block()->nodes().begin();
+
+    AliasDb aliasDb(graph);
+    for (const auto input : opNode->inputs()) {
+      // The custom op writes to all its inputs
+      ASSERT_TRUE(aliasDb.writesToAlias(opNode, {input}));
+      // The output should be a wildcard and thus alias all inputs
+      ASSERT_TRUE(aliasDb.mayAlias(opNode->output(), input));
+    }
+  }
+  {
+    // DCE should not remove a custom op
+    auto graph = std::make_shared<Graph>();
+    const auto text = R"IR(
+graph(%x: Tensor, %y: Tensor):
+  # CHECK: foo::aliasing
+  %ret : Tensor = foo::aliasing(%x, %y)
+  return (%x)
+  )IR";
+    script::parseIR(text, graph.get());
+    EliminateDeadCode(graph);
+
+    testing::FileCheck().run(text, *graph);
+  }
 }
 } // namespace test
 } // namespace jit
diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h
index 7a314ad..8b93c21 100644
--- a/test/cpp/jit/test_misc.h
+++ b/test/cpp/jit/test_misc.h
@@ -672,7 +672,7 @@ void testAutogradProfiler() {
 void testNoneSchemaMatch() {
   RegisterOperators reg({
       Operator(
-          "test::test_none() -> int?",
+          "prim::test_none() -> int?",
           [](const Node* node) {
             return [](Stack& stack) {
               push(stack, IValue());
@@ -680,7 +680,7 @@ void testNoneSchemaMatch() {
             };
           }),
       Operator(
-          "test::is_none(int? a) -> bool",
+          "prim::is_none(int? a) -> bool",
           [](const Node* node) {
             return [](Stack& stack) {
               IValue a = pop(stack);
@@ -700,8 +700,8 @@ void testNoneSchemaMatch() {
 
   auto r = std::make_shared<Graph>();
   auto& g = *r;
-  auto opt_int = g.insert(Symbol::fromQualString("test::test_none"), {});
-  auto out_bool = g.insert(Symbol::fromQualString("test::is_none"), {opt_int});
+  auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
+  auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
   g.registerOutput(out_bool);
 
   ConstantPropagation(r);
diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h
index 22b206a..73b7cbb 100644
--- a/torch/csrc/jit/custom_operator.h
+++ b/torch/csrc/jit/custom_operator.h
@@ -181,6 +181,26 @@ Operator createOperator(
 
   auto schema = torch::jit::detail::inferAndCheckSchema(schemaOrName);
 
+  // [custom operator aliasing] Currently, we have no way for the user to
+  // specify the alias annotations for a custom op. Therefore, we have to:
+  // 1. Assume that custom ops will mutate all inputs, have side effects, and
+  //    produce wildcard outputs.
+  // 2. Have some way of distinguishing between a custom op and a builtin op
+  //    so that we can apply the above rule.
+  // We do this by manually whitelisting "aten" and "prim" namespaces as
+  // builtins.
+  //
+  // We don't want to preserve this distinction between custom/builtin ops, as
+  // it is fragile and hard to maintain. When we provide a way for op
+  // registration to specify alias annotations, we should fix up builtins to
+  // use that and remove all references to this note.
+  Symbol name = Symbol::fromQualString(schema.name());
+  if (name.is_aten() || name.is_prim() || name.is_onnx()) {
+    AT_ERROR(
+        "Tried to register a custom operator to a reserved namespace: ",
+        name.ns().toUnqualString());
+  }
+
   return Operator(schema, [implementation, schema](Stack& stack) {
     ArgumentTuple tuple;
     torch::jit::detail::callOperatorWithTuple(
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index b0ef049..9d9a768 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -846,10 +846,17 @@ bool Node::hasSideEffects() const {
     case prim::SetAttr:
     case aten::warn:
     case prim::AddStatValue:
-    case prim::TimePoint:
+    case prim::TimePoint:
       return true;
   }
-  return false;
+  // All other builtin ops are known to be safe.
+  // see [custom operator aliasing]
+  if (kind_.is_aten() || kind_.is_prim() || kind_.is_onnx()) {
+    return false;
+  }
+
+  // Custom ops may have arbitrary side effects
+  return true;
 }
 
 // Assign this node a topological position, to facilitate fast isBefore() and
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp
index 7220632..10f2414 100644
--- a/torch/csrc/jit/passes/alias_analysis.cpp
+++ b/torch/csrc/jit/passes/alias_analysis.cpp
@@ -428,6 +428,11 @@ void AliasDb::analyzeImpl(Node* node) {
     }
   }
 
+  // see [custom operator aliasing]
+  if (!node->kind().is_aten() && !node->kind().is_prim()) {
+    return analyzeCustomOp(node);
+  }
+
   // Bind formal alias annotation to actual alias sets
   std::unordered_map<Symbol, Value*> formalToActual;
   for (size_t i = 0; i < schema.arguments().size(); i++) {
@@ -516,6 +521,11 @@
 }
 // Register the fact that `n` writes to `v`.
 void AliasDb::registerWrite(const Value* v, Node* n) {
+  if (!shouldAnnotate(v)) {
+    // don't need to register a write if the value isn't mutable
+    return;
+  }
+
   numWrites_++;
 
   if (isWildcard(v)) {
@@ -682,6 +692,19 @@ void AliasDb::analyzeSetAttr(Node* node) {
   registerWrite(self, node);
 }
 
+// Custom ops may write to any input and produce wildcards
+void AliasDb::analyzeCustomOp(Node* node) {
+  for (const auto input : node->inputs()) {
+    registerWrite(input, node);
+  }
+
+  // TODO(suo): we can make the more refined assumption that outputs may only
+  // alias any input.
+  for (const auto output : node->outputs()) {
+    setWildcard(output);
+  }
+}
+
 // BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
 // This is an intermediate node used only in the graph fuser.
 void AliasDb::analyzeBroadcastingChunk(Node* node) {
diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h
index bed1d63..2574e61 100644
--- a/torch/csrc/jit/passes/alias_analysis.h
+++ b/torch/csrc/jit/passes/alias_analysis.h
@@ -182,6 +182,7 @@ class AliasDb {
   void analyzeFork(Node* node);
   void analyzeWait(Node* node);
  void analyzeSetAttr(Node* node);
+  void analyzeCustomOp(Node* node);
 
   /**
    * Alias manipulation methods
-- 
2.7.4
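
Below is a usage sketch, not part of the patch itself, showing what the new
conservative behavior means for a custom-op author. It mirrors the tests added
in test_custom_operators.h and uses only APIs that appear in the diff above
(createOperator, RegisterOperators, AliasDb, EliminateDeadCode,
script::parseIR, testing::FileCheck); the "my_ops" namespace and the op body
are hypothetical.

#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/irparser.h"
#include "torch/csrc/jit/passes/alias_analysis.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/testing/file_check.h"

using namespace torch::jit;

// "my_ops" is not a reserved namespace, so this registration succeeds.
// (Registering into "aten", "prim", or "onnx" would now throw "Tried to
// register a custom operator to a reserved namespace".) The inferred schema
// carries no alias annotations, so the JIT cannot see the mutation below.
static RegisterOperators reg({createOperator(
    "my_ops::add_inplace", [](at::Tensor a, at::Tensor b) -> at::Tensor {
      a.add_(b); // hidden mutation the schema cannot express
      return a;
    })});

void demoConservativeAssumptions() {
  auto graph = std::make_shared<Graph>();
  const auto ir = R"IR(
graph(%x: Tensor, %y: Tensor):
  # CHECK: my_ops::add_inplace
  %out : Tensor = my_ops::add_inplace(%x, %y)
  return (%x)
)IR";
  script::parseIR(ir, graph.get());
  Node* op = *graph->block()->nodes().begin();

  AliasDb aliasDb(graph);
  // 1) The custom op is assumed to have side effects...
  AT_ASSERT(op->hasSideEffects());
  for (const auto input : op->inputs()) {
    // 2) ...to write to every input...
    AT_ASSERT(aliasDb.writesToAlias(op, {input}));
    // 3) ...and to produce a wildcard output that may alias any input.
    AT_ASSERT(aliasDb.mayAlias(op->output(), input));
  }

  // Consequently, dead code elimination keeps the call even though %out
  // is unused.
  EliminateDeadCode(graph);
  testing::FileCheck().run(ir, *graph);
}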