_(TopologicalMove) \
_(SubgraphUtils) \
_(AliasAnalysis) \
+ _(AliasRegistration) \
_(WriteTracking) \
_(Wildcards) \
_(MemoryDAG) \
ASSERT_FALSE(t.mayAlias(foo, baz));
}
}
+
+void testAliasRegistration() {
+ {
+ auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::DEFAULT);
+ RegisterOperators reg({createOperator(
+ "foo::rand",
+ [](at::Tensor) -> at::Tensor {
+ return at::rand({2, 2});
+ },
+ opts)});
+ const auto rand_op = Symbol::fromQualString("foo::rand");
+ auto graph = std::make_shared<Graph>();
+ auto a = graph->addInput();
+ auto b = graph->insert(rand_op, {a});
+ AliasDb aliasDb(graph);
+ // Conservatively we assume there is a reference
+ ASSERT_TRUE(aliasDb.mayAlias(a, b));
+ }
+ {
+ auto opts = OperatorOptions().aliasAnalysis(AliasAnalysisKind::PURE);
+ RegisterOperators reg({createOperator(
+ "foo::pure", [](at::Tensor t) -> at::Tensor { return t * 2; }, opts)});
+ const auto rand_op = Symbol::fromQualString("foo::pure");
+ auto graph = std::make_shared<Graph>();
+ auto a = graph->addInput();
+ auto b = graph->insert(rand_op, {a});
+ AliasDb aliasDb(graph);
+ // PURE means there is no reference
+ ASSERT_FALSE(aliasDb.mayAlias(a, b));
+ }
+}
+
} // namespace jit
} // namespace torch
template <typename Implementation>
Operator createOperator(
const std::string& schemaOrName,
- Implementation&& implementation) {
+ Implementation&& implementation,
+ OperatorOptions options = OperatorOptions()) {
using Traits = c10::guts::infer_function_traits_t<Implementation>;
using ArgumentTypes =
c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
name.ns().toUnqualString());
}
- return Operator(schema, [implementation, schema](Stack& stack) {
- ArgumentTuple tuple;
- torch::jit::detail::callOperatorWithTuple(
- schema,
- std::move(implementation), // NOLINT(bugprone-move-forwarding-reference)
- stack,
- tuple,
- typename MakeIndices<kNumberOfArguments>::indices{});
- return 0;
- });
+ return Operator(
+ schema,
+ [implementation, schema](Stack& stack) {
+ ArgumentTuple tuple;
+ torch::jit::detail::callOperatorWithTuple(
+ schema,
+ std::move(
+ implementation), // NOLINT(bugprone-move-forwarding-reference)
+ stack,
+ tuple,
+ typename MakeIndices<kNumberOfArguments>::indices{});
+ return 0;
+ },
+ std::move(options));
}
/// Registration class for new operators. Effectively calls
template <typename Implementation>
RegisterOperators& op(
const std::string& name,
- Implementation&& implementation) {
- registerOperator(
- createOperator(name, std::forward<Implementation>(implementation)));
+ Implementation&& implementation,
+ OperatorOptions options = OperatorOptions()) {
+ registerOperator(createOperator(
+ name, std::forward<Implementation>(implementation), options));
return *this;
}
};
op.schema().name(),
". File a bug to add a case for this operator.\n");
}
- if (!aliasAnalysisHasSpecialCaseFor(s)) {
+ if (!aliasAnalysisHasSpecialCaseFor(s) &&
+ op.options().aliasAnalysis() == AliasAnalysisKind::DEFAULT) {
AT_ERROR(
"Missing special case in alias analysis for non-schematized"
" operator ",
// it now to implement correct semantic checking for script
#pragma once
+#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir.h>
-#include <ATen/core/stack.h>
+#include <torch/csrc/jit/operator_options.h>
#include <ATen/ATen.h>
#include <ATen/core/function_schema.h>
*/
struct TORCH_API Operator {
- Operator(FunctionSchema schema, OperationCreator op_creator)
+ Operator(
+ FunctionSchema schema,
+ OperationCreator op_creator,
+ OperatorOptions options = OperatorOptions())
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
- op_creator_(std::move(op_creator)) {}
+ op_creator_(std::move(op_creator)),
+ options_(std::move(options)) {}
- Operator(const std::string& schema, OperationCreator op_creator)
- : schema_string_(schema), op_creator_(std::move(op_creator)) {}
+ Operator(
+ const std::string& schema,
+ OperationCreator op_creator,
+ OperatorOptions options = OperatorOptions())
+ : schema_string_(schema),
+ op_creator_(std::move(op_creator)),
+ options_(std::move(options)) {}
// Helper constructor to register `op` to run
// run for _every_ IR Node where n.kind() == name, regardless of arguments.
// This is accomplished by marking the schema varargs and having no required
// arguments. This is used for things like prim::While or prim::If that can
// take a number of different valid input types and lengths.
- Operator(Symbol name, OperationCreator op_creator)
+ Operator(
+ Symbol name,
+ OperationCreator op_creator,
+ OperatorOptions options = OperatorOptions())
: Operator(
FunctionSchema(
name,
{},
/*is_vararg*/ true,
/*is_varret*/ true),
- std::move(op_creator)) {}
+ std::move(op_creator),
+ std::move(options)) {}
- Operator(FunctionSchema schema, Operation op)
+ Operator(
+ FunctionSchema schema,
+ Operation op,
+ OperatorOptions options = OperatorOptions())
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
- op_(std::make_shared<Operation>(std::move(op))) {}
+ op_(std::make_shared<Operation>(std::move(op))),
+ options_(std::move(options)) {}
- Operator(const std::string& schema, Operation op)
+ Operator(
+ const std::string& schema,
+ Operation op,
+ OperatorOptions options = OperatorOptions())
: schema_string_(schema),
- op_(std::make_shared<Operation>(std::move(op))) {}
+ op_(std::make_shared<Operation>(std::move(op))),
+ options_(std::move(options)) {}
bool matches(const Node* node) const;
return *schema_;
}
+ const OperatorOptions& options() const {
+ return options_;
+ }
+
private:
mutable c10::optional<std::string> schema_string_;
// cannot use c10::optional because windows has issues that require an
// NB: std::function has a default state (where it == nullptr).
std::shared_ptr<Operation> op_;
OperationCreator op_creator_;
+ OperatorOptions options_;
};
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
--- /dev/null
+#pragma once
+
+#include <torch/csrc/jit/passes/alias_analysis.h>
+
+namespace torch {
+namespace jit {
+
+enum class AliasAnalysisKind {
+ DEFAULT, // The most conservative alias analysis type, assumes side-effects
+ PURE
+};
+
+struct OperatorOptions {
+ OperatorOptions(){};
+
+ OperatorOptions aliasAnalysis(AliasAnalysisKind aak) const noexcept {
+ OperatorOptions r = *this;
+ r.aliasAnalysisKind_ = aak;
+ return r;
+ }
+
+ const AliasAnalysisKind& aliasAnalysis() const {
+ return aliasAnalysisKind_;
+ }
+ AliasAnalysisKind aliasAnalysisKind_ = AliasAnalysisKind::DEFAULT;
+};
+
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/passes/alias_analysis.h>
+#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/utils/memory.h>
}
}
+// Returns true if analysis was run using
+// the registered analyzer.
+bool AliasDb::tryRegisteredAnalysis(Node* node) {
+ const Operator& op = getOperatorFor(node);
+ auto analysis = op.options().aliasAnalysis();
+ switch (analysis) {
+ case AliasAnalysisKind::PURE:
+ analyzeCreator(node);
+ return true;
+ case AliasAnalysisKind::DEFAULT:
+ return false;
+ }
+ return false;
+}
+
// The basic strategy is:
// 1. Retrieve alias information for every input.
// 2. Use the node's schema's alias annotations to propgagate alias/write
// These ops do nothing
return;
default:
+ if (tryRegisteredAnalysis(node)) {
+ return;
+ }
AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
}
void analyzeWait(Node* node);
void analyzeSetAttr(Node* node);
void analyzeCustomOp(Node* node);
+ bool tryRegisteredAnalysis(Node* node);
/**
* Alias manipulation methods