Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18711
ghimport-source-id:
c9caedc0660b2b7ba3730cd0e1a2e0e9c3cf422b
Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18711 [jit] fix side-effects and aliasing for custom ops**
Previously we didn't track aliasing, mutation, or side effects for
custom ops. This PR adds in guards with the most conservative
assumptions possible: the op will
1) have side effects,
2) write to everything, and
3) produce a wildcard.
In order to tell whether a given operator is a custom op, this PR introduces
the concept of a "reserved" namespace (basically all our builtin namespaces).
Custom ops live in non-reserved namespaces, so a check on the namespace
is sufficient to tell whether a schema/node is "custom" or not.
This is just to get things correct for now. Follow-ups to this:
- Users should be able to specify aliasing/mutability without having to learn
the whole alias annotation schema.
- Relax the assumptions a bit. In particular, outputs can only alias input
tensors; they don't have to be wildcards.
Fixes #18490
Differential Revision:
D14730978
fbshipit-source-id:
540b47a24ccf24145051609bdcc99c97e46e0fe0
return is_varret_;
}
bool is_mutable() const {
- return std::any_of(
- arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
- const auto& aliasInfo = arg.alias_info();
- return aliasInfo && aliasInfo.value().isWrite();
- });
+ // see [custom operator aliasing]
+ const auto kind = Symbol::fromQualString(name_);
+ const auto is_custom_op = !kind.is_aten() && !kind.is_prim();
+ return is_custom_op ||
+ std::any_of(
+ arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
+ const auto& aliasInfo = arg.alias_info();
+ return aliasInfo && aliasInfo.value().isWrite();
+ });
}
+
c10::optional<int> argumentIndexWithName(const std::string& name) const {
for(size_t i = 0; i < arguments().size(); ++i) {
if(name == arguments()[i].name())
_(ControlFlow) \
_(CreateAutodiffSubgraphs) \
_(CustomOperators) \
+ _(CustomOperatorAliasing) \
_(Differentiate) \
_(DifferentiateWithRequiresGrad) \
_(DynamicDAG) \
}
void testWriteTracking() {
- RegisterOperators reg({createOperator(
- "foo::creates_alias(Tensor(a) x) -> Tensor(a)",
- [](at::Tensor a) { return a; })});
- const auto creates_alias = Symbol::fromQualString("foo::creates_alias");
- const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
+ RegisterOperators reg(
+ {Operator("prim::creates_alias(Tensor(a) x) -> Tensor(a)", [](Stack& s) {
+ return 0;
+ })});
+ const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
{
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
}
void testWildcards() {
- RegisterOperators reg({createOperator(
- "foo::returns_wildcard(Tensor a) -> Tensor(*)",
- [](at::Tensor a) { return a; }),
- createOperator(
- "foo::writes(Tensor(z!) a) -> Tensor(a)",
- [](at::Tensor a) { return a; })});
- const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
- const auto writes = Symbol::fromQualString("foo::writes");
+ RegisterOperators reg(
+ {Operator(
+ "prim::returns_wildcard(Tensor a) -> Tensor(*)",
+ [](Stack& stack) { return 0; }),
+ Operator("prim::writes(Tensor(z!) a) -> Tensor(a)", [](Stack& stack) {
+ return 0;
+ })});
+ const auto returns_wildcard =
+ Symbol::fromQualString("prim::returns_wildcard");
+ const auto writes = Symbol::fromQualString("prim::writes");
auto graph = std::make_shared<Graph>();
const auto a = graph->addInput();
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/custom_operator.h"
+#include "torch/csrc/jit/irparser.h"
+#include "torch/csrc/jit/passes/alias_analysis.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
namespace torch {
namespace jit {
tracer::abandon();
}
+ {
+ // Try to create an op using a reserved namespace
+ ASSERT_THROWS_WITH(
+ createOperator(
+ "aten::op(float[] f) -> int",
+ [](const std::vector<double>& f) -> int64_t { return f.size(); }),
+ "Tried to register a custom operator to a reserved namespace");
+ }
+}
+
+void testCustomOperatorAliasing() {
+ RegisterOperators reg({createOperator(
+ "foo::aliasing", [](at::Tensor a, at::Tensor b) -> at::Tensor {
+ a.add_(b);
+ return a;
+ })});
+ auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::aliasing"));
+
+ {
+ auto graph = std::make_shared<Graph>();
+ script::parseIR(
+ R"IR(
+graph(%x: Tensor, %y: Tensor):
+ %ret : Tensor = foo::aliasing(%x, %y)
+ return (%ret)
+ )IR",
+ graph.get());
+
+ auto opNode = *graph->block()->nodes().begin();
+
+ AliasDb aliasDb(graph);
+ for (const auto input : opNode->inputs()) {
+ // The custom op writes to all its inputs
+ ASSERT_TRUE(aliasDb.writesToAlias(opNode, {input}));
+ // The output should be a wildcard and thus alias all inputs
+ ASSERT_TRUE(aliasDb.mayAlias(opNode->output(), input));
+ }
+ }
+ {
+ // DCE should not remove a custom op
+ auto graph = std::make_shared<Graph>();
+ const auto text = R"IR(
+graph(%x: Tensor, %y: Tensor):
+ # CHECK: foo::aliasing
+ %ret : Tensor = foo::aliasing(%x, %y)
+ return (%x)
+ )IR";
+ script::parseIR(text, graph.get());
+ EliminateDeadCode(graph);
+
+ testing::FileCheck().run(text, *graph);
+ }
}
} // namespace test
} // namespace jit
void testNoneSchemaMatch() {
RegisterOperators reg({
Operator(
- "test::test_none() -> int?",
+ "prim::test_none() -> int?",
[](const Node* node) {
return [](Stack& stack) {
push(stack, IValue());
};
}),
Operator(
- "test::is_none(int? a) -> bool",
+ "prim::is_none(int? a) -> bool",
[](const Node* node) {
return [](Stack& stack) {
IValue a = pop(stack);
auto r = std::make_shared<Graph>();
auto& g = *r;
- auto opt_int = g.insert(Symbol::fromQualString("test::test_none"), {});
- auto out_bool = g.insert(Symbol::fromQualString("test::is_none"), {opt_int});
+ auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
+ auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
g.registerOutput(out_bool);
ConstantPropagation(r);
auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);
+ // [custom operator aliasing] Currently, we have no way for the user to
+ // specify the alias annotations for a custom op. Therefore, we have to:
+ // 1. Assume that custom ops will mutate all inputs, have side effects, and
+ // produce wildcard outputs.
+ // 2. Have some way of distinguishing between a custom op and a builtin op
+ // so that we can apply the above rule.
+ // We do this by manually whitelisting "aten" and "prim" namespaces as
+ // builtins.
+ //
+ // We don't want to preserve this distinction between custom/builtin ops, as
+ // it is fragile and hard to maintain. When we provide a way for op
+ // registration to specify alias annotations, we should fix up builtins to
+ // use that and remove all references to this note.
+ Symbol name = Symbol::fromQualString(schema.name());
+ if (name.is_aten() || name.is_prim() || name.is_onnx()) {
+ AT_ERROR(
+ "Tried to register a custom operator to a reserved namespace: ",
+ name.ns().toUnqualString());
+ }
+
return Operator(schema, [implementation, schema](Stack& stack) {
ArgumentTuple tuple;
torch::jit::detail::callOperatorWithTuple(
case prim::SetAttr:
case aten::warn:
case prim::AddStatValue:
- case prim::TimePoint:
+ case prim::TimePoint:
return true;
}
- return false;
+ // All other builtin ops are known to be safe.
+ // see [custom operator aliasing]
+ if (kind_.is_aten() || kind_.is_prim() || kind_.is_onnx()) {
+ return false;
+ }
+
+ // Custom ops may have arbitrary side effects
+ return true;
}
// Assign this node a topological position, to facilitate fast isBefore() and
}
}
+ // see [custom operator aliasing]
+ if (!node->kind().is_aten() && !node->kind().is_prim()) {
+ return analyzeCustomOp(node);
+ }
+
// Bind formal alias annotation to actual alias sets
std::unordered_map<Symbol, Value*> formalToActual;
for (size_t i = 0; i < schema.arguments().size(); i++) {
}
// Register the fact that `n` writes to `v`.
void AliasDb::registerWrite(const Value* v, Node* n) {
+ if (!shouldAnnotate(v)) {
+ // don't need to register a write if the value isn't mutable
+ return;
+ }
+
numWrites_++;
if (isWildcard(v)) {
registerWrite(self, node);
}
+// Custom ops may write to any input and produce wildcards
+void AliasDb::analyzeCustomOp(Node* node) {
+ for (const auto input : node->inputs()) {
+ registerWrite(input, node);
+ }
+
+ // TODO(suo): we can make the more refined assumption that outputs may only
+ // alias any input.
+ for (const auto output : node->outputs()) {
+ setWildcard(output);
+ }
+}
+
// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
// This is an intermediate node used only in the graph fuser.
void AliasDb::analyzeBroadcastingChunk(Node* node) {
void analyzeFork(Node* node);
void analyzeWait(Node* node);
void analyzeSetAttr(Node* node);
+ void analyzeCustomOp(Node* node);
/**
* Alias manipulation methods