#include <condition_variable>
#include <type_traits>
+#include <ATen/core/blob.h>
+#include <ATen/core/interned_strings.h>
#include <c10/core/Scalar.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
-#include <ATen/core/blob.h>
#include <c10/util/intrusive_ptr.h>
#include <ATen/core/Tensor.h>
#pragma once
-#include <ATen/core/ivalue.h>
-#include <ATen/core/interned_strings.h>
-#include <ATen/core/functional.h>
-#include <ATen/core/Type.h>
#include <ATen/core/TensorMethods.h>
+#include <ATen/core/Type.h>
+#include <ATen/core/functional.h>
+#include <ATen/core/interned_strings.h>
+#include <ATen/core/ivalue.h>
#include <c10/util/TypeList.h>
#include <caffe2/core/common.h>
for (const auto& pr : tupleTypes) {
makeAllAlias(pr.second, *aliasTracker_);
}
+ for (const auto& pr : dictTypes) {
+ makeAllAlias(pr.second, *aliasTracker_);
+ }
makeAllAlias(tensors, *aliasTracker_);
analyze(graph->block());
AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
}
-
const auto& schema = node->schema();
if (schema.is_vararg() || schema.is_varret()) {
const auto hasMutableOutputs = std::any_of(
// WARNING: by adding a case to this list, you are asserting that you have
// added a case for the unschematized node in AliasDb::analyze
const static std::unordered_set<Symbol> handled = {
- prim::If,
- prim::Loop,
- prim::FusionGroup,
- prim::DifferentiableGraph,
- prim::Constant,
- prim::DictConstruct,
- prim::ListConstruct,
- prim::TupleConstruct,
- prim::Undefined,
- prim::FusedConcat,
- prim::MMTreeReduce,
- prim::MMBatchSide,
- prim::None,
- prim::BroadcastSizes,
- prim::ChunkSizes,
- prim::Function,
- prim::TupleUnpack,
- prim::TupleIndex,
- prim::DictIndex,
- prim::TupleSlice,
- prim::ListUnpack,
- prim::PythonOp,
- prim::ConstantChunk,
- prim::BroadcastingChunk,
- aten::add,
- aten::sub,
- aten::mul,
- aten::div,
+ prim::If,
+ prim::Loop,
+ prim::FusionGroup,
+ prim::DifferentiableGraph,
+ prim::Constant,
+ prim::DictConstruct,
+ prim::ListConstruct,
+ prim::TupleConstruct,
+ prim::Undefined,
+ prim::FusedConcat,
+ prim::MMTreeReduce,
+ prim::MMBatchSide,
+ prim::None,
+ prim::BroadcastSizes,
+ prim::ChunkSizes,
+ prim::Function,
+ prim::TupleUnpack,
+ prim::TupleIndex,
+ prim::DictIndex,
+ prim::TupleSlice,
+ prim::ListUnpack,
+ prim::PythonOp,
+ prim::ConstantChunk,
+ prim::BroadcastingChunk,
+ prim::fork,
+ aten::wait,
+ aten::add,
+ aten::sub,
+ aten::mul,
+ aten::div,
};
// Operators that should not be used by alias analysis
const static std::unordered_set<Symbol> purposefully_not_handled = {
- prim::Print,
- prim::Load,
- prim::Store,
- prim::Drop,
- at::onnx::Reshape,
- at::onnx::Shape,
- prim::AnyDefined,
- prim::AutogradAdd,
- prim::fork, // TODO: fork aliasing / futures
+ prim::Print,
+ prim::Load,
+ prim::Store,
+ prim::Drop,
+ at::onnx::Reshape,
+ at::onnx::Shape,
+ prim::AnyDefined,
+ prim::AutogradAdd,
};
return handled.count(symbol) || purposefully_not_handled.count(symbol);
-#include <torch/csrc/jit/passes/python_print.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/attributes.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/ir_views.h>
+#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/module.h>
std::cout << wildcard->uniqueName() << ", ";
}
std::cout << "\n";
+
+ std::cout << "\n===4. Writes===\n";
+ for (const auto& pr : writeIndex_) {
+ const auto node = pr.first;
+ const auto& values = pr.second;
+ std::cout << *node;
+ std::cout << " ";
+ for (const auto value : values) {
+ std::cout << value->uniqueName() << ", ";
+ }
+ std::cout << "\n";
+ }
+ std::cout << "\n";
}
std::unordered_set<const AliasTracker::Element*> AliasTracker::Element::
-#include <torch/csrc/jit/script/compiler.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/final_returns.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/script/schema_matching.h>
return v;
}
-std::shared_ptr<SugaredValue> BuiltinFunction::call(
- const SourceRange& loc,
- Method& m,
- at::ArrayRef<NamedValue> inputs,
- at::ArrayRef<NamedValue> attributes,
- size_t n_binders) {
- return std::make_shared<SimpleValue>(
- emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true));
-}
-
inline bool isSupportedListElementType(const TypePtr& type) {
return type->isSubtypeOf(TensorType::get()) ||
type->isSubtypeOf(NumberType::get());
}
method.setSchema(emitDef(def, self, graph->block()));
+
runCleanupPasses(graph);
}
#include <torch/csrc/Layout.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/script/schema_matching.h>
#include <torch/csrc/jit/script/sugared_value.h>
-#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/csrc/jit/constants.h>
}
};
-
struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
explicit ConstantPythonTupleValue(py::object tup)
: PythonValue(std::move(tup)) {}
.def(
"_copy_method",
[](std::shared_ptr<Module> m,
- std::string name,
- std::vector<std::tuple<std::shared_ptr<Module>, std::string>> params,
- std::shared_ptr<Module> orig) {
- std::vector<at::Tensor*> member_inputs;
- for (auto& p : params) {
- NamedParameter* np = std::get<0>(p)->find_parameter(std::get<1>(p));
- AT_ASSERT(np != nullptr);
- member_inputs.push_back(np->slot());
- }
-
- Method* orig_method = orig->find_method(name);
- m->create_method(name, orig_method->graph()->copy(), member_inputs);
+ std::string name,
+ std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
+ params,
+ std::shared_ptr<Module> orig) {
+ std::vector<at::Tensor*> member_inputs;
+ for (auto& p : params) {
+ NamedParameter* np =
+ std::get<0>(p)->find_parameter(std::get<1>(p));
+ AT_ASSERT(np != nullptr);
+ member_inputs.push_back(np->slot());
+ }
+
+ Method* orig_method = orig->find_method(name);
+ m->create_method(name, orig_method->graph()->copy(), member_inputs);
});
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/schema_matching.h>
#include <torch/csrc/jit/script/sugared_value.h>
#include <torch/csrc/jit/script/tree_views.h>
#include <torch/csrc/jit/script/type_parser.h>
return builtin_cast_methods;
}
+std::shared_ptr<SugaredValue> BuiltinFunction::call(
+ const SourceRange& loc,
+ Method& m,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) {
+ return std::make_shared<SimpleValue>(
+ emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true));
+}
+
// support syntax sugar for x.foo(y, z) by allowing x.foo to return a
// callable value that will resolve to foo(x, y, z) when called.
std::shared_ptr<SugaredValue> SimpleValue::attr(
auto names = tuple_type->names();
for (int i = 0; i < names.size(); i++) {
if (names[i] == field) {
- auto r = m.graph()->insertNode(m.graph()->createTupleIndex(getValue(), i))->output();
+ auto r = m.graph()
+ ->insertNode(m.graph()->createTupleIndex(getValue(), i))
+ ->output();
return std::make_shared<SimpleValue>(r);
}
}
const SourceRange& range,
const List<Expr>& keys,
const List<Expr>& values) {
- return DictLiteral(Compound::create(TK_DICT_LITERAL, range, {keys, values}));
+ return DictLiteral(
+ Compound::create(TK_DICT_LITERAL, range, {keys, values}));
}
};