}
return c10::nullopt;
}
- FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
- return FunctionSchema(
- name(),
- overload_name(),
- std::move(new_arguments),
- returns(),
- is_vararg(),
- is_varret());
- }
- // Check that inputs have the correct types and appends any missing default
- // values.
- void checkAndNormalizeInputs(std::vector<IValue>& inputs) const;
};
inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
return str.str();
}
-inline void FunctionSchema::checkAndNormalizeInputs(std::vector<IValue>& inputs) const {
- // Do we have more inputs than the schema accepts?
- AT_CHECK(
- inputs.size() <= arguments().size(),
- "Expected at most ",
- arguments().size(),
- " argument(s) for operator '",
- name(),
- "', but received ",
- inputs.size(),
- " argument(s). Declaration: ",
- *this);
-
- for (size_t pos = 0; pos < arguments().size(); ++pos) {
- const auto& argument = arguments()[pos];
- if (pos < inputs.size()) {
- if (!isSubvalueOf(inputs[pos], argument.type())) {
- AT_ERROR(
- "Expected value of type ",
- *argument.type(),
- " for argument '",
- argument.name(),
- "' in position ",
- pos,
- ", but instead got value of type ",
- attemptToRecoverType(inputs[pos])->str(),
- ". Declaration: ",
- *this);
- }
- } else if (argument.default_value()) {
- inputs.push_back(*argument.default_value());
- } else {
- AT_ERROR(
- name(),
- "() is missing value for argument '",
- argument.name(),
- "'. Declaration: ",
- *this);
- }
- }
-}
-
} // namespace c10
namespace torch {
namespace jit {
namespace script {
-struct CompilationUnit;
-struct Function;
+struct Module;
+struct Method;
}
} // namespace jit
} // namespace torch
struct ClassType;
using ClassTypePtr = std::shared_ptr<ClassType>;
-using ::torch::jit::script::CompilationUnit;
-using ::torch::jit::script::Function;
+using ::torch::jit::script::Module;
+using ::torch::jit::script::Method;
// This represents a class in TorchScript.
struct CAFFE2_API ClassType : public Type {
// Create a user type and register it globally.
static ClassTypePtr create(
const std::string& name,
- std::shared_ptr<CompilationUnit> module);
+ std::shared_ptr<Module> module);
// Create a type representing a Module,
// These do not have methods, and are not globally registered
- static ClassTypePtr createModuleType(std::shared_ptr<CompilationUnit> module);
+ static ClassTypePtr createModuleType();
// returns nullptr if there is no type with that name
static ClassTypePtr get(const std::string& name);
return attributeNames_[slot];
}
- Function* getMethod(const std::string& name) const;
- CompilationUnit& compilation_unit();
- const CompilationUnit& compilation_unit() const;
- std::vector<Function*> methods() const;
-
+ Method* getMethod(const std::string& name) const;
+ std::vector<Method*> methods() const;
const std::string& name() const {
return typename_;
static const TypeKind Kind = TypeKind::ClassType;
private:
- ClassType(std::string name, std::shared_ptr<CompilationUnit> cu)
+ ClassType(std::string name, std::shared_ptr<Module> module)
: Type(TypeKind::ClassType),
typename_(std::move(name)),
- compilation_unit_(std::move(cu)) {}
+ module_(std::move(module)) {}
// Name of type (note that this has to be globally unique).
std::string typename_;
std::vector<std::string> attributeNames_;
std::vector<TypePtr> attributeTypes_;
// Holds method attributes
- std::shared_ptr<CompilationUnit> compilation_unit_;
+ std::shared_ptr<Module> module_;
};
} // namespace c10
ClassTypePtr ClassType::create(
const std::string& name,
- std::shared_ptr<CompilationUnit> cu) {
- auto ptr = ClassTypePtr(new ClassType(name, std::move(cu)));
+ std::shared_ptr<Module> module) {
+ auto ptr = ClassTypePtr(new ClassType(name, std::move(module)));
getRegistry().registerType(name, ptr);
return ptr;
}
-ClassTypePtr ClassType::createModuleType(std::shared_ptr<CompilationUnit> cu) {
- return ClassTypePtr(new ClassType("Module", std::move(cu)));
+ClassTypePtr ClassType::createModuleType() {
+ return ClassTypePtr(new ClassType("Module", nullptr));
}
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
- auto ptr = ClassTypePtr(new ClassType(typename_, compilation_unit_));
+ auto ptr = ClassTypePtr(new ClassType(typename_, module_));
AT_ASSERT(numAttributes() == refined_slots.size());
for(size_t i = 0; i < attributeNames_.size(); ++i) {
AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
_(NoneSchemaMatch) \
_(ClassParser) \
_(PeepholeOptimize) \
- _(RecordFunction) \
- _(ModuleDefine)
+ _(RecordFunction)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
#include "ATen/core/ivalue.h"
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/jit/script/module.h"
-#include "torch/jit.h"
#include "onnx/onnx_pb.h"
return a
)JIT";
void testControlFlow() {
- auto cu = compile(cf_examples);
-
+ auto cu = std::make_shared<script::Module>();
+ script::defineMethodsInModule(
+ cu, cf_examples, script::nativeResolver, c10::nullopt);
auto run = [&](const std::string& name, std::vector<IValue> stack) {
- auto graph = cu->get_function(name).graph();
+ auto graph = cu->get_method(name).graph();
Code code(graph);
InterpreterState interp(code);
interp.run(stack);
}
void invokeTestRecordFunction(at::Tensor& t) {
- autograd::profiler::GetPackedInputsCallback inputs_cb = [t]() {
- Stack st;
- pack(st, t);
- return st;
- };
+ autograd::profiler::GetPackedInputsCallback inputs_cb =
+ [t]() {
+ Stack st;
+ pack(st, t);
+ return st;
+ };
autograd::profiler::RecordFunction guard("test", inputs_cb);
t.add_(torch::ones_like(t));
}
void testRecordFunction() {
std::vector<std::vector<int64_t>> input_sizes;
- autograd::profiler::pushCallback(
- [&input_sizes](const autograd::profiler::RecordFunction& fn) {
- for (const auto& input : fn.inputs()) {
- if (input.isTensor()) {
- std::vector<int64_t> t = input.toTensor().sizes().vec();
- input_sizes.push_back(t);
- }
- }
- });
+ autograd::profiler::pushCallback([&input_sizes](
+ const autograd::profiler::RecordFunction& fn) {
+ for (const auto& input : fn.inputs()) {
+ if (input.isTensor()) {
+ std::vector<int64_t> t = input.toTensor().sizes().vec();
+ input_sizes.push_back(t);
+ }
+ }
+ });
auto t = torch::randn({1, 2, 3}, at::kCPU);
invokeTestRecordFunction(t);
// test nested RecordFunctions
std::vector<std::string> nested_names;
- autograd::profiler::pushCallback(
- [&nested_names](const autograd::profiler::RecordFunction& fn) {
- nested_names.push_back(getFullName(&fn));
- });
+ autograd::profiler::pushCallback([&nested_names](
+ const autograd::profiler::RecordFunction& fn) {
+ nested_names.push_back(getFullName(&fn));
+ });
{
autograd::profiler::RecordFunction guard("outer");
- invokeTestRecordFunctionNested();
- ;
+ invokeTestRecordFunctionNested();;
}
autograd::profiler::popCallback();
// checking that constant propagation ran wo/failure
AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
}
-
-void testModuleDefine() {
- auto m = std::make_shared<script::Module>();
- m->register_parameter("foo", torch::ones({}), false);
- m->define(R"(
- def add_it(self, x, b : int = 4):
- return self.foo + x + b
- )");
- auto result = m->run_method("add_it", torch::ones({}));
- AT_ASSERT(result.toTensor().item<float>() == 6)
-}
-
} // namespace test
} // namespace jit
} // namespace torch
graph:
GraphProto {
name: "torch-jit-export"
- inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 4},{name: "2", type:Tensor dims: 3 3}]
+ inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 3},{name: "2", type:Tensor dims: 3 4}]
outputs: [{name: "6", type:Tensor dims: 2 4}]
- initializers: [TensorProto shape: [3 4],TensorProto shape: [3 3]]
+ initializers: [TensorProto shape: [3 3],TensorProto shape: [3 4]]
nodes: [
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
- Node {type: "Gemm", inputs: [x,2,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]},
+ Node {type: "Gemm", inputs: [x,1,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
- Node {type: "Gemm", inputs: [4,1,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]}
+ Node {type: "Gemm", inputs: [4,2,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]}
]
}
opset_import: [OperatorSetIdProto { domain: }],
graph:
GraphProto {
name: "torch-jit-export"
- inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20},{name: "2", type:Tensor dims: 20 10}]
+ inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}]
outputs: [{name: "8", type:Tensor dims: 1 20}]
- initializers: [TensorProto shape: [20],TensorProto shape: [20 10]]
+ initializers: [TensorProto shape: [20 10],TensorProto shape: [20]]
nodes: [
Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []},
Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]},
outputs: [{name: "10", type:Tensor dims: 1 20}]
initializers: []
nodes: [
- Node {type: "Gemm", inputs: [3,2,1], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+ Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}
outputs: [{name: "11", type:Tensor dims: 1 20}]
initializers: []
nodes: [
- Node {type: "Gemm", inputs: [3,2,1], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+ Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}
outputs: [{name: "12", type:Tensor dims: 1 20}]
initializers: []
nodes: [
- Node {type: "Gemm", inputs: [3,2,1], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+ Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}
def foo(self, input):
self.call_foo(input)
- with self.assertRaisesRegex(RuntimeError, 'called recursively'):
+ with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
M()
def test_script_kwargs_fn_call(self):
"torch/csrc/jit/register_quantized_ops.cpp",
"torch/csrc/jit/scope.cpp",
"torch/csrc/jit/script/compiler.cpp",
- "torch/csrc/api/src/jit.cpp",
"torch/csrc/jit/script/edit_distance.cpp",
"torch/csrc/jit/script/logging.cpp",
"torch/csrc/jit/script/final_returns.cpp",
${TORCH_SRC_DIR}/csrc/jit/register_quantized_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/scope.cpp
${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
- ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/random.cpp
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/sequential.cpp
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/stream.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
${TORCH_SRC_DIR}/csrc/jit/python_tracer.cpp
${TORCH_SRC_DIR}/csrc/jit/script/init.cpp
${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/script/python_tree_views.cpp
${TORCH_SRC_DIR}/csrc/multiprocessing/init.cpp
${TORCH_SRC_DIR}/csrc/nn/THNN.cpp
/// )JIT");
/// IValue output = module->run_method("relu_script", a, b);
/// \endrst
-TORCH_API std::shared_ptr<script::CompilationUnit> compile(const std::string& source);
+TORCH_API std::shared_ptr<script::Module> compile(const std::string& source);
} // namespace jit
} // namespace torch
namespace torch {
namespace jit {
-std::shared_ptr<script::CompilationUnit> compile(const std::string& source) {
- auto module = std::make_shared<script::CompilationUnit>();
- module->define(source, script::nativeResolver, nullptr);
+std::shared_ptr<script::Module> compile(const std::string& source) {
+ auto module = std::make_shared<script::Module>();
+ defineMethodsInModule(
+ module, source, script::nativeResolver, /*self=*/c10::nullopt);
return module;
}
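A minimal usage sketch of this reverted entry point, assuming two previously created at::Tensor values `a` and `b` and reusing the relu_script body from the doc comment above:

// Sketch only: compile standalone script source into a Module, then invoke a
// method by name, as the torch::jit::compile doc comment describes.
auto module = torch::jit::compile(R"JIT(
  def relu_script(a, b):
      return torch.relu(a + b)
)JIT");
IValue output = module->run_method("relu_script", a, b);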
namespace jit {
namespace script {
+// this is a much simpler accessor that only handles modules, parameters, and
+// methods. It does not depend on python to work.
+struct ModuleAccessorValue : public SugaredValue {
+ ModuleAccessorValue(std::shared_ptr<Module> module)
+ : module(std::move(module)) {}
+ std::string kind() const override {
+ return "module";
+ }
+ // select an attribute on it, e.g. `this.field`
+ std::shared_ptr<SugaredValue> attr(
+ const SourceRange& loc,
+ Method& m,
+ const std::string& field) override {
+ if (std::shared_ptr<Module> v = module->find_module(field)) {
+ return std::make_shared<ModuleAccessorValue>(std::move(v));
+ } else if (script::Slot* v = module->find_parameter(field)) {
+ return std::make_shared<SimpleValue>(m.get_or_add_parameter(*v));
+ } else if (script::Slot* v = module->find_buffer(field)) {
+ return std::make_shared<SimpleValue>(m.get_or_add_parameter(*v));
+ } else if (script::Slot* v = module->find_attribute(field)) {
+ return std::make_shared<script::SimpleValue>(
+ m.get_or_add_attribute(*v));
+ } else if (Method* m = module->find_method(field)) {
+ return std::make_shared<MethodValue>(shared_from_this(), *m);
+ } else {
+ throw ErrorReport(loc) << "unknown attr: " << field;
+ }
+ }
+
+ private:
+ std::shared_ptr<Module> module;
+};
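A brief sketch of how the importer uses this accessor (the forward/weight names are only illustrative; the actual binding appears in import_methods further down):

// Sketch: a deserialized body such as
//   def forward(self, x):
//       return self.weight + x
// is compiled with `self` bound to a ModuleAccessorValue, so `self.weight`
// resolves through attr() into get_or_add_parameter(...) on the Method being
// built rather than into a getattr on a first-class module object.
auto self = std::make_shared<ModuleAccessorValue>(mod);
defineMethodsInModule(mod, definitions, resolvers, Self(self));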
+
struct OpsValue : public SugaredValue {
OpsValue(size_t version) : version_(version) {}
std::string kind() const override {
}
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) override {
return std::make_shared<BuiltinModule>(field, version_);
}
std::string kind() const override {
return "constant";
}
- Value* asValue(const SourceRange& loc, Function& m) override {
+ Value* asValue(const SourceRange& loc, Method& m) override {
return m.graph()->insertConstant(value_);
}
};
// select an attribute on it, e.g. `this.field`
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) override {
const char* field_s = field.c_str();
char* end;
};
resolver_ = [&](const std::string& name,
- Function& m,
+ Method& m,
const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
auto it = env_.find(name);
if (it == env_.end()) {
const std::vector<at::Tensor>& constant_table_;
std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_;
std::function<std::shared_ptr<
- SugaredValue>(const std::string& name, Function& m, const SourceRange& loc)>
+ SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>
resolver_;
size_t parseVersionNumber() {
definitions.emplace_back(def);
resolvers.emplace_back(importer.resolver_);
}
- auto self = [&](Value* v) {
- v->setType(mod->module_object()->type());
- return std::make_shared<SimpleValue>(v);
- };
- mod->module_object()->type()->compilation_unit().define(definitions, resolvers, self);
+ auto self = std::make_shared<ModuleAccessorValue>(mod);
+ defineMethodsInModule(mod, definitions, resolvers, Self(self));
}
void import_libs(
resolvers.emplace_back(importer.resolver_);
}
- auto cu = std::make_shared<CompilationUnit>();
- auto class_type = ClassType::create(class_def.name().name(), cu);
- auto self = [&](Value* v) {
- v->setType(class_type);
- return std::make_shared<SimpleValue>(v);
- };
- cu->define(definitions, resolvers, self);
+ auto mod = std::make_shared<Module>();
+ Self self(ClassType::create(class_def.name().name(), mod));
+ defineMethodsInModule(mod, definitions, resolvers, self);
}
}
const std::string& field,
Value* newValue);
TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
- TORCH_API Value* insertGetAttr(Value* obj, const std::string& field) {
- return insertNode(createGetAttr(obj, field))->output();
- }
-
// Note: defined in python_ir.cpp and can be used only in python extension
Node* createPythonOp(
THPObjectPtr&& pyobj,
norm_invstd = 1 / (eps + torch.sqrt(norm_var))
return ((input - norm_mean) * norm_invstd)
)SCRIPT";
- script::CompilationUnit cu;
- cu.define(source, script::nativeResolver, nullptr);
- *graph_ptr = cu.get_function("batch_norm").graph();
+ auto module = std::make_shared<script::Module>();
+ defineMethodsInModule(
+ module, source, script::nativeResolver, /*self=*/c10::nullopt);
+ *graph_ptr = module->get_method("batch_norm").graph();
},
&bn_graph);
[](const Argument& arg) { return arg.default_value(); });
printFunction(graph, name, is_class, defaults, ivalue_names);
}
- void printFunction(
- script::Function& method,
- bool is_class) {
- const std::string& name = method.name();
- Graph& graph = *method.graph();
- auto defaults = fmap(
- method.getSchema().arguments(),
- [](const Argument& arg) { return arg.default_value(); });
- printFunction(graph, name, is_class, defaults, {});
- }
void printModule(script::Module& module) {
std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
createTensorToParameterNameMap(
out << "class " << classType->name() << ":\n";
{
const auto guard = WithIndented();
+ std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
for (auto& method : classType->methods()) {
- printFunction(*method, /*is_class=*/true);
+ printMethod(*method, /*is_class=*/true, extra_ivalue_names);
}
}
}
this->cconv = other->cconv;
Py_INCREF(other->pyobj.get());
this->pyobj = THPObjectPtr(other->pyobj.get());
- this->ignore_on_export = other->ignore_on_export;
for (auto& sa : other->scalar_args) {
Py_INCREF(sa.get());
this->scalar_args.emplace_back(sa.get());
)SCRIPT");
struct BuiltinFunctionRegistry {
- const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
- const static std::vector<Function*> empty;
+ const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
+ const static std::vector<Method*> empty;
// when initializing the builtin function library, we will re-enter
// getAllBuiltinFunctionsFor since it is called in the compiler to
// lookup builtins and initializing the builtin functions calls the
private:
void loadSource(const std::string& source) {
- std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
- modules.emplace_back(cu);
- cu->define(source, script::nativeResolver, /*self=*/nullptr);
- for (auto& method : cu->get_functions()) {
+ auto module = std::make_shared<script::Module>();
+ defineMethodsInModule(
+ module, source, script::nativeResolver, /*self=*/c10::nullopt);
+ modules.push_back(module);
+ for (auto& method : module->get_methods()) {
builtins_by_name[Symbol::fromQualString("aten::" + method->name())]
.push_back(method.get());
}
}
enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
std::recursive_mutex mutex;
- std::vector<std::shared_ptr<CompilationUnit>> modules;
- std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name;
+ std::vector<std::shared_ptr<Module>> modules;
+ std::unordered_map<Symbol, std::vector<Method*>> builtins_by_name;
};
-TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
+TORCH_API const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
static BuiltinFunctionRegistry registry;
return registry.getAllBuiltinFunctionsFor(name);
}
namespace jit {
namespace script {
-TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name);
+TORCH_API const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name);
}
} // namespace jit
// This file exists because we need to reference module.h, which we can't from
// c10. Sigh...
-Function* ClassType::getMethod(const std::string& name) const {
- return compilation_unit_->find_function(name).get();
+Method* ClassType::getMethod(const std::string& name) const {
+ return module_? module_->find_method(name) : nullptr;
}
-CompilationUnit& ClassType::compilation_unit() {
- return *compilation_unit_;
-}
-const CompilationUnit& ClassType::compilation_unit() const {
- return *compilation_unit_;
-}
-
-std::vector<Function*> ClassType::methods() const {
- std::vector<Function*> ret;
- for (const auto& pr : compilation_unit().get_functions()) {
+std::vector<Method*> ClassType::methods() const {
+ std::vector<Method*> ret;
+ for (const auto& pr : module_->get_methods()) {
ret.push_back(pr.get());
}
return ret;
+++ /dev/null
-#pragma once
-#include <c10/util/Exception.h>
-#include <torch/csrc/jit/graph_executor.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/source_range.h>
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/utils/memory.h>
-
-#include <ATen/core/function_schema.h>
-#include <c10/util/ArrayRef.h>
-#include <c10/util/Optional.h>
-
-#include <functional>
-#include <memory>
-#include <mutex>
-#include <ostream>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-namespace torch {
-namespace jit {
-
-namespace script {
-
-struct Def;
-struct SugaredValue;
-struct Function;
-
-using Resolver = std::function<std::shared_ptr<SugaredValue>(
- const std::string& name,
- Function& f,
- const SourceRange& loc)>;
-using Self = std::function<std::shared_ptr<SugaredValue>(Value*)>;
-
-// A Function is a pure Graph with no implicit `self` object bound.
-// It contains schema information, and the executor that manages the
-// execution of the function. script::Method is a wrapper around a
-// underlying Function that also provides a `self` object.
-struct TORCH_API Function {
- Function(
- std::string name,
- bool optimize,
- std::shared_ptr<Graph> graph,
- std::function<void(Function&)> function_creator)
- : name_(std::move(name)),
- graph_(std::move(graph)),
- optimize_(optimize),
- function_creator_(std::move(function_creator)) {}
-
- void run(Stack& stack) {
- get_executor().run(stack);
- }
-
- void run(Stack&& stack) {
- run(stack);
- }
-
- IValue operator()(std::vector<IValue> stack) {
- getSchema().checkAndNormalizeInputs(stack);
- run(stack);
- return stack.front();
- }
-
- std::shared_ptr<Graph> graph_for(Stack inputs) {
- return get_executor().graphFor(inputs);
- }
-
- std::shared_ptr<Graph> graph() const {
- return graph_;
- }
-
- const std::string& name() const {
- return name_;
- }
-
- // if this isn't yet defined, run its method_creator function
- void ensure_defined();
-
- size_t num_inputs() const {
- return graph()->inputs().size();
- }
-
- Function& setSchema(FunctionSchema schema) {
- schema_ = make_unique<FunctionSchema>(std::move(schema));
- return *this;
- }
-
- const FunctionSchema& getSchema() const {
- if (schema_ == nullptr) {
- schema_ = make_unique<FunctionSchema>(defaultSchemaFor(*this));
- }
- return *schema_;
- }
-
- std::string pretty_print_schema() const {
- AT_ASSERT(schema_);
- std::stringstream ss;
- ss << *schema_;
- return ss.str();
- }
-
- GraphExecutorState getDebugState() {
- return get_executor().getDebugState();
- }
-
- void debugDisableAutodiffSubgraphInlining() {
- return get_executor().debugDisableAutodiffSubgraphInlining();
- }
-
- bool is_optimized() const {
- return optimize_;
- }
-
- void check_single_output() {
- AT_CHECK(
- graph()->outputs().size() == 1,
- "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
- }
-
- GraphExecutor& get_executor() {
- std::call_once(executor_init_, [&] {
- check_single_output();
- executor_ = GraphExecutor(graph(), optimize_);
- });
- return executor_;
- }
-
- // returns nullptr and fills in failure_messages if the callee does not
- // match the functions schema
-
- // TODO: defined in module.cpp, move to compilation_unit.cpp
- Value* try_emit_call(
- Graph& graph,
- const SourceRange& loc,
- c10::optional<NamedValue> self,
- ArrayRef<NamedValue> args,
- ArrayRef<NamedValue> kwargs,
- std::stringstream& failure_messages,
- bool conv_tensors_to_nums);
-
- Value* emit_call(
- Graph& graph,
- const SourceRange& loc,
- ArrayRef<NamedValue> args,
- ArrayRef<NamedValue> kwargs);
-
- private:
- static FunctionSchema defaultSchemaFor(const Function& function) {
- std::vector<Argument> args;
- std::vector<Argument> returns;
- Graph& g = *function.graph();
- size_t num_inputs = function.num_inputs();
- for (size_t i = 0; i < num_inputs; ++i) {
- const Value* v = g.inputs().at(i);
- std::string name = v->hasUniqueName() ? v->uniqueNameBase()
- : ("argument_" + std::to_string(i));
- args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
- }
- for (size_t i = 0; i < g.outputs().size(); ++i) {
- returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
- }
- return {function.name(), "", std::move(args), std::move(returns)};
- }
-
- std::string name_;
- std::shared_ptr<Graph> graph_; // for debugging and for inlining
- bool optimize_;
-
- GraphExecutor executor_; // for execution
-
- std::once_flag executor_init_;
-
- // an optional function that actually creates the method when
- // emit_call_to(this,...) is first called. this is used by the compiler so
- // that it can construct methods out of order
- std::function<void(Function&)> function_creator_;
-
- // if absent, then we generate a default schema based on the graph
- // mutable because getSchema caches the default schema if one is requested
- // before a call to setSchema
- mutable std::unique_ptr<FunctionSchema> schema_;
-};
-
-
-// A CompilationUnit is a list of named script::Functions
-// with helper methods to iterate the list, or invoke the function.
-// Classes have a CompilationUnit holding the class methods
-// and Modules also have a CompilationUnit holding the Functions that
-// are used to implement their Methods
-
-struct TORCH_API CompilationUnit {
- std::shared_ptr<Function> find_function(const std::string& name) const {
- auto it = dict_.find(name);
- if (it == dict_.end())
- return nullptr;
- return functions_[it->second];
- }
-
- Function& get_function(const std::string& name) const {
- if (auto r = find_function(name))
- return *r;
- AT_ERROR("attempted to get undefined function ", name);
- }
-
- void set_optimized(bool o) {
- optimized_ = o;
- }
-
- bool is_optimized() const {
- return optimized_;
- }
-
- // for historic reasons, these are defined in compiler.cpp
- void define(
- const std::vector<Def>& definitions,
- const std::vector<Resolver>& resolvers, /* determines how we handle free
- variables in each definition*/
- // if non-null, the first argument to each def, is bound to this value
- const Self& self);
-
- // same as above but parse the definitions from source
- void define(
- const std::string& source,
- const Resolver& resolver,
- const Self& self);
-
- void clone_function(const Function& remote) {
- create_function(remote.name(), remote.graph()->copy());
- }
-
- Function& create_function(std::string name, std::shared_ptr<Graph> graph) {
- auto fn = std::make_shared<Function>(
- std::move(name), is_optimized(), std::move(graph), nullptr);
- return register_function(std::move(fn));
- }
-
- const std::vector<std::shared_ptr<Function>>& get_functions() const {
- return functions_;
- }
-
- /// Run a method from this compilation.
- ///
- /// For example:
- /// @code
- /// IValue output = module->run("relu_script", a, b);
- /// @endcode
- ///
- /// To get a compile a module from a source string, see torch::jit::compile
- ///
- /// @param method_name The name of the method to run
- /// @param args Arguments to be passed to the method
- /// @return An IValue containing the return value (or values if it is a tuple)
- /// from the method
- template <typename... Types>
- IValue run_method(const std::string& method_name, Types&&... args) {
- return get_function(method_name)({IValue(std::forward<Types>(args))...});
- }
-
- void drop_all_functions() {
- dict_.clear();
- functions_.clear();
- }
-
- private:
- Function& register_function(std::shared_ptr<Function> fn) {
- AT_CHECK(
- 0 == dict_.count(fn->name()),
- "method '",
- fn->name(),
- "' already defined.");
- functions_.emplace_back(std::move(fn));
- dict_[functions_.back()->name()] = functions_.size() - 1;
- return *functions_.back();
- }
- std::vector<std::shared_ptr<Function>> functions_;
- // for fast lookup
- std::unordered_map<std::string, size_t> dict_;
- bool optimized_ = true;
-};
-
-} // namespace script
-} // namespace jit
-} // namespace torch
namespace script {
using SugaredValuePtr = std::shared_ptr<SugaredValue>;
-using FunctionTable = std::unordered_map<std::string, Function&>;
+using FunctionTable = std::unordered_map<std::string, Method&>;
using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
using AttributeMap = std::unordered_map<std::string, Const>;
using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
// delete unnecessary ones later with replaceAllusesWith().
struct Environment {
Environment(
- Function& method,
+ Method& method,
Resolver resolver,
Block* b,
std::shared_ptr<Environment> next = nullptr)
b(b),
next(std::move(next)) {}
- Function& method;
+ Method& method;
Resolver resolver;
std::vector<std::string> captured_inputs;
std::unordered_map<std::string, std::string> error_messages;
to_ir(
const Def& def,
Resolver resolver_,
- const Self& self,
- Function& method) // method being constructed
+ const c10::optional<Self>& self,
+ Method& method) // method being constructed
: method(method),
graph(method.graph()),
resolver(std::move(resolver_)),
}
private:
- Function& method;
+ Method& method;
std::shared_ptr<Graph> graph;
Resolver resolver;
std::unordered_map<int64_t, Value*> integral_constants;
FunctionSchema emitDef(
const Def& def,
- const Self& self,
+ const c10::optional<Self>& self,
Block* block) {
auto schema = extractSchemaFromDef(def, self);
// TODO need guards on init returning none
blank_decl,
List<Stmt>::create(r, {ret}));
auto m = std::make_shared<Module>();
- CompilationUnit cu;
- cu.define({def}, {resolver}, nullptr);
+ defineMethodsInModule(m, {def}, {resolver}, c10::nullopt);
Stack stack;
- cu.get_function("defaults").run(stack);
+ m->get_method("defaults").run(stack);
return stack.at(0).toTuple()->elements();
}
std::vector<Argument> parseArgsFromDecl(
const Decl& decl,
- const Self& self) {
+ const c10::optional<Self>& self) {
auto params_begin = decl.params().begin();
auto params_end = decl.params().end();
if (self) {
}
FunctionSchema extractSchemaFromDef(
const Def& def,
- const Self& self) {
+ const c10::optional<Self>& self) {
const auto name = def.name().name();
std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
std::vector<Argument> returns = parseReturnFromDecl(def.decl());
std::vector<Argument> emitFormalArguments(
const Def& def,
- const Self& self,
+ const c10::optional<Self>& self,
const FunctionSchema& schema,
Block* block) {
-
std::vector<Argument> arguments; // for schema
// inputs
auto it = def.decl().params().begin();
if (self) {
AT_ASSERT(it != end);
const auto& name = (*it).ident().name();
- Value* new_input = block->addInput()->setUniqueName(name);
- environment_stack->setSugaredVar((*it).ident().range(), name, self(new_input));
- arguments.emplace_back(name, new_input->type());
+ if (auto type = self->asFirstClass()) {
+ Value* new_input =
+ block->addInput()->setUniqueName(name)->setType(type);
+ environment_stack->setVar((*it).ident().range(), name, new_input);
+ arguments.emplace_back(name, type);
+ } else {
+ environment_stack->setSugaredVar(def.range(), name, self->asSugared());
+ }
++it;
}
size_t arg_annotation_idx = 0;
pushFrame(block, /*starts_def=*/true);
emitDef(
def,
- nullptr,
+ c10::nullopt,
block); // ignore schema return, we just wont use it for now since we
// never create a Method for the closure
popFrame(/*ends_def=*/true);
node_output = fork_node->output()->setType(
FutureType::create(fn_simple_output->type()));
}
+
// Lambda lift block(0) into attr::Subgraph
lambdaLiftFork(fork_node);
}
};
-void CompilationUnit::define(
+void defineMethodsInModule(
+ const std::shared_ptr<Module>& m,
const std::vector<Def>& definitions,
const std::vector<Resolver>& resolvers,
- const Self& self) {
+ const c10::optional<Self>& self) {
AT_ASSERT(definitions.size() == resolvers.size());
auto resolver_it = resolvers.begin();
- std::vector<Function*> methods;
- std::unordered_map<std::string, Function*> function_table;
+ std::vector<Method*> methods;
+ std::unordered_map<std::string, Method*> function_table;
for (const Def& def : definitions) {
const std::string& name = def.name().name();
auto resolver = *resolver_it++;
// the function table so the methods can see each other
resolver = [resolver, &function_table](
const std::string& name,
- Function& m,
+ Method& m,
const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
auto it = function_table.find(name);
if (it != function_table.end()) {
- return std::make_shared<MethodValue>(c10::nullopt, *it->second);
+ return std::make_shared<MethodValue>(nullptr, *it->second);
}
return resolver(name, m, loc);
};
}
- auto creator = [def, resolver, self](Function& method) {
+ auto creator = [def, resolver, self](Method& method) {
AT_ASSERT(resolver);
to_ir(def, resolver, self, method);
};
- std::unique_ptr<Function> fn(
- new Function(name, is_optimized(), std::make_shared<Graph>(), creator));
- function_table[name] = fn.get();
- methods.push_back(fn.get());
- register_function(std::move(fn));
+ Method& method = m->create_method(name, creator);
+ function_table[name] = &method;
+ methods.push_back(&method);
}
- for (Function* method : methods) {
+ for (Method* method : methods) {
method->ensure_defined();
}
+ if (!self || !self->asFirstClass()) {
+ // Disable module hooks if the module is only used to store a class's code.
+ didFinishEmitModule(m);
+ }
}
-void CompilationUnit::define(
+void defineMethodsInModule(
+ const std::shared_ptr<Module>& m,
const std::string& source,
const Resolver& resolver,
- const Self& self) {
+ const c10::optional<Self>& self) {
Parser p(source);
std::vector<Def> definitions;
std::vector<Resolver> resolvers;
definitions.push_back(def);
resolvers.push_back(resolver);
}
- define(definitions, resolvers, self);
+ defineMethodsInModule(m, definitions, resolvers, self);
}
void lambdaLiftFork(Node* fork_node) {
fork_node->g_(attr::Subgraph, forked_graph);
fork_node->eraseBlock(0);
}
-
} // namespace script
} // namespace jit
} // namespace torch
namespace jit {
namespace script {
+using Resolver = std::function<std::shared_ptr<
+ SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>;
+
inline std::shared_ptr<SugaredValue> nativeResolver(
const std::string& name,
- Function& m,
+ Method& m,
const SourceRange& loc) {
if (name == "torch") {
return std::make_shared<BuiltinModule>("aten");
return nullptr;
}
+// Represents the `self` argument to a method. This wrapper class is necessary
+// because `self` is sometimes first class and sometimes not.
+//
+// `self` is first class when it refers to a ClassType. It will be bound as a
+// graph input argument.
+// `self` is sugared when it refers to a ModuleValue.
+class Self {
+ public:
+ explicit Self(std::shared_ptr<SugaredValue> sugared)
+ : sugared_(std::move(sugared)) {}
+ explicit Self(ClassTypePtr type) : firstClass_(std::move(type)) {}
+
+ ClassTypePtr asFirstClass() const {
+ return firstClass_;
+ }
+ std::shared_ptr<SugaredValue> asSugared() const {
+ return sugared_;
+ }
+
+ private:
+ // Used when `self` is not first-class and so we don't represent it in the
+ // graph. This is only ModuleValue.
+ std::shared_ptr<SugaredValue> sugared_ = nullptr;
+ // Used when `self` is a first-class type
+ ClassTypePtr firstClass_ = nullptr;
+};
+
+TORCH_API void defineMethodsInModule(
+ const std::shared_ptr<Module>& m,
+ const std::vector<Def>& definitions,
+ const std::vector<Resolver>& resolvers, /* determines how we handle free
+ variables in each definition*/
+    // if present, the first argument to each def is bound to this value
+ const c10::optional<Self>& self);
+
+// same as above but parse the definitions from source
+TORCH_API void defineMethodsInModule(
+ const std::shared_ptr<Module>& m,
+ const std::string& source,
+ const Resolver& resolver,
+ const c10::optional<Self>& self);
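Roughly, the three `self` flavors reach these overloads as follows. This is a sketch assembled from call sites elsewhere in this diff (the Python bindings, torch::jit::compile, and _jit_script_class_compile); "MyClass" and the local variable names are illustrative only:

// Module methods: `self` is sugared (a ModuleValue), so it never becomes a graph input.
auto m = std::make_shared<Module>();
defineMethodsInModule(m, source, nativeResolver, Self(std::make_shared<ModuleValue>(m)));

// Free functions, as in torch::jit::compile: no `self` at all.
defineMethodsInModule(m, source, nativeResolver, /*self=*/c10::nullopt);

// Class methods: `self` is first-class and is added as a typed graph input.
// methodDefs/resolvers hold the parsed Defs and their Resolvers.
auto classType = ClassType::create("MyClass", m);
defineMethodsInModule(m, methodDefs, resolvers, Self(classType));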
TORCH_API void lambdaLiftFork(Node* fork_node);
// type, *add it in this function's implementation*.
std::shared_ptr<SugaredValue> toSugaredValue(
py::object obj,
- Function& m,
+ Method& m,
SourceRange loc,
- bool is_constant = false);
+ bool is_constant = false,
+ bool is_submodule = false);
struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
PythonValue(py::object self) : self(std::move(self)) {}
// call it like a function, e.g. `outputs = this(inputs)`
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& m,
+ Method& m,
at::ArrayRef<NamedValue> inputs_,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
- Function& m,
+ Method& m,
const c10::optional<size_t>& size_hint = {}) override {
const std::string type_str = typeString(self);
std::stringstream ss;
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) override {
const std::string type_str = typeString(self);
std::stringstream ss;
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) override {
py::object member = getattr(loc, field);
// note: is_constant = true because we consider that global properties
: PythonValue(std::move(tup)) {}
std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
- Function& m,
+ Method& m,
const c10::optional<size_t>& size_hint = {}) override {
py::tuple tup = self;
std::vector<std::shared_ptr<SugaredValue>> result;
return result;
}
- Value* asValue(const SourceRange& loc, Function& m) override {
+ Value* asValue(const SourceRange& loc, Method& m) override {
std::vector<Value*> values;
for (const auto& sugared_item : asTuple(loc, m)) {
values.push_back(sugared_item->asValue(loc, m));
// Represents all the parameters of a module as a List[Tensor]
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
- ConstantParameterList(Value* the_list) : the_list_(the_list) {}
- std::string kind() const override {
- return "constant parameter list";
- }
- std::shared_ptr<SugaredValue> call(
- const SourceRange& loc,
- Function& caller,
- at::ArrayRef<NamedValue> inputs,
- at::ArrayRef<NamedValue> attributes,
- size_t n_binders) override {
- return toSimple(the_list_);
- }
-
- private:
- Value* the_list_;
-};
-
-struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
- OverloadedFunctionValue(Value* module, std::vector<std::string> method_names)
- : module_(module), method_names_(std::move(method_names)) {}
+ ConstantParameterList(std::shared_ptr<Module> module)
+ : module_(std::move(module)) {}
std::string kind() const override {
- return "overloaded function";
+ return "constant parameter list";
}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& caller,
+ Method& caller,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
- std::stringstream err;
- std::vector<NamedValue> new_inputs = inputs.vec();
- new_inputs.insert(new_inputs.begin(), module_);
-
- for (const std::string& method_name : method_names_) {
- auto cls = module_->type()->expect<ClassType>();
- Function* fn = cls->getMethod(method_name);
- auto match = tryMatchSchema(
- fn->getSchema(),
- loc,
- *caller.graph().get(),
- c10::nullopt,
- new_inputs,
- attributes,
- err,
- true);
- if (match) {
- return MethodValue(module_, *fn)
- .call(loc, caller, inputs, attributes, n_binders);
- }
+ // Add all module parameters as inputs to the graph
+ std::vector<Value*> params;
+ const auto& param_list = module_->get_parameters();
+ for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
+ auto& param = *it;
+ params.push_back(caller.get_or_add_parameter(param));
}
- throw ErrorReport(loc) << "Could not find any matching overloads\n"
- << err.str();
+ auto list = caller.graph()->createList(TensorType::get(), params);
+ caller.graph()->insertNode(list);
+ return toSimple(list->output());
}
private:
- Value* module_;
- std::vector<std::string> method_names_;
+ std::shared_ptr<Module> module_;
};
// defines how modules/methods behave inside the script subset.
// holding the actual nn.Module class.
struct ModuleValue : public SugaredValue {
- ModuleValue(Value* self, std::shared_ptr<Module> module)
- : self_(self), module_(std::move(module)) {}
+ ModuleValue(std::shared_ptr<Module> module) : module(std::move(module)) {}
std::string kind() const override {
return "module";
// select an attribute on it, e.g. `this.field`
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) override {
// workaround to make self.training work
// it adds a buffer 'training' to the model if one doesn't exist
// and then loads that parameter, casting it to bool
if (field == "training") {
- Slot* v = module_->find_buffer(field);
+ Slot* v = module->find_buffer(field);
if (!v) {
- py::object py_module = py::cast(module_);
+ py::object py_module = py::cast(module);
bool training = py::cast<bool>(py::getattr(py_module, "training"));
auto t =
autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
- module_->register_buffer("training", std::move(t));
- v = module_->find_buffer(field);
+ module->register_buffer("training", std::move(t));
+ v = module->find_buffer(field);
}
- Value* the_tensor = m.graph()->insertGetAttr(self_, "training");
+ Value* the_tensor = m.get_or_add_parameter(*v);
Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
return std::make_shared<SimpleValue>(the_bool);
}
- if (std::shared_ptr<Module> v = module_->find_module(field)) {
- return std::make_shared<ModuleValue>(
- m.graph()->insertGetAttr(self_, field), v);
- } else if (auto kind = module_->kind_of(field)) {
- // methods, parameters, attributes, and buffers are all first class
- return SimpleValue(self_).attr(loc, m, field);
+ if (std::shared_ptr<Module> v = module->find_module(field)) {
+ return std::make_shared<ModuleValue>(v);
+ } else if (Method* v = module->find_method(field)) {
+ return std::make_shared<MethodValue>(shared_from_this(), *v);
+ } else if (Slot* v = module->find_parameter(field)) {
+ return std::make_shared<SimpleValue>(m.get_or_add_parameter(*v));
+ } else if (Slot* v = module->find_attribute(field)) {
+ return std::make_shared<SimpleValue>(
+ m.get_or_add_attribute(*v));
}
// This can also be a call to a non-script module, or a plain
// python method. If so return this as a python value.
- py::object py_module = py::cast(module_);
-
- py::object overloads =
- py_module.attr("_overloads").attr("get")(field, py::none());
- if (!overloads.is_none()) {
- return std::make_shared<OverloadedFunctionValue>(
- self_, py::cast<std::vector<std::string>>(overloads));
- }
-
+ py::object py_module = py::cast(module);
if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
if (py::isinstance<py::function>(attr) &&
py::hasattr(attr, "_is_parameter_list") &&
py::cast<bool>(py::getattr(attr, "_is_parameter_list"))) {
- Graph& g = *m.graph();
- // Add all module parameters as inputs to the graph
- std::vector<Value*> params;
- const auto& param_list = module_->get_parameters();
- for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
- auto& param = *it;
- params.emplace_back(g.insertGetAttr(self_, param.name()));
- }
- auto list =
- g.insertNode(g.createTuple(params))->output();
- return std::make_shared<ConstantParameterList>(list);
+ return std::make_shared<ConstantParameterList>(module);
}
if (py::isinstance<py::function>(attr) ||
py::isinstance(attr, py::module::import("torch.nn").attr("Module")) ||
// call module.forward
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& caller,
+ Method& caller,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
- Function& m,
+ Method& m,
const c10::optional<size_t>& size_hint = {}) override {
- py::object py_module = py::cast(module_);
+ py::object py_module = py::cast(module);
if (!py::isinstance(
py_module,
py::module::import("torch.jit").attr("_ConstModuleList")))
return SugaredValue::asTuple(loc, m, size_hint);
std::vector<std::shared_ptr<SugaredValue>> result;
- for (py::handle py_submodule : py_module) {
- py::object obj = py::reinterpret_borrow<py::object>(py_submodule);
- if (py::isinstance<Module>(obj)) {
- auto sub_module = py::cast<std::shared_ptr<Module>>(obj);
- Value* module_v = m.graph()->insertGetAttr(self_, sub_module->name());
- result.emplace_back(
- std::make_shared<ModuleValue>(module_v, sub_module));
- } else {
- result.push_back(toSugaredValue(
- obj,
- m,
- loc,
- /*is_constant =*/false));
- }
+ for (py::handle module : py_module) {
+ py::object obj = py::reinterpret_borrow<py::object>(module);
+ result.push_back(toSugaredValue(
+ obj,
+ m,
+ loc,
+ /*is_constant =*/false,
+ /*is_submodule =*/true));
}
return result;
}
private:
- Value* self_;
- std::shared_ptr<Module> module_;
+ std::shared_ptr<Module> module;
};
struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& caller,
+ Method& caller,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
py::dict dispatched_fn_;
};
-std::shared_ptr<MethodValue> moduleToMethod(
- const std::shared_ptr<Module>& mod) {
- // this path only supports calling raw script functions
- // but because they are not distinguished from models, we have to check
- // that they are function-like here. They must not have state, and they
- // must have a forward method. When we expose functions to python
- // this will be replaced with a direct py::isinstance<Function> call.
-
- if (mod->get_parameters().size() != 0) {
- throw ErrorReport()
- << "Attempted to inline a Module with parameters. "
- "Stateful modules to be inlined must be submodules of the callee.";
+struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
+ OverloadedFunctionValue(py::list functions)
+ : possible_functions_(std::move(functions)) {}
+
+ std::string kind() const override {
+ return "overloaded function";
}
- Method* forward = mod->find_method("forward");
- if (!forward) {
- throw ErrorReport() << " expected this module to have a forward function.";
+
+ std::shared_ptr<SugaredValue> call(
+ const SourceRange& loc,
+ Method& caller,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) override {
+ std::stringstream err;
+ auto possible_functions =
+ py::cast<std::vector<py::object>>(possible_functions_);
+
+ for (const py::object& fn : possible_functions) {
+ auto& method = py::cast<Method&>(fn);
+ auto match = tryMatchSchema(
+ method.getSchema(),
+ loc,
+ *caller.graph().get(),
+ c10::nullopt,
+ inputs,
+ attributes,
+ err,
+ true);
+ if (match) {
+ return MethodValue(nullptr, method)
+ .call(loc, caller, inputs, attributes, n_binders);
+ }
+ }
+ throw ErrorReport(loc) << "Could not find any matching overloads\n"
+ << err.str();
}
- return std::make_shared<MethodValue>(at::nullopt, forward->function());
-}
+
+ private:
+ py::list possible_functions_;
+};
std::shared_ptr<SugaredValue> toSugaredValue(
py::object obj,
- Function& m,
+ Method& m,
SourceRange loc,
- bool is_constant) {
+ bool is_constant,
+ bool is_submodule) {
// directly create SimpleValues when possible, because they are first-class
// and can be re-assigned. Otherwise, this would be invalid:
// f = python_constant
obj = weak_obj;
}
if (py::isinstance<Module>(obj)) {
+ auto mod = py::cast<std::shared_ptr<Module>>(obj);
// In the case that this Python object is not a submodule, inline *ONLY
// PURE* ScriptModules. This allows us to call arbitrary @script functions
// within a scripting context while still enforcing that parameters from
// stateful submodules are properly accounted for.
- auto mod = py::cast<std::shared_ptr<Module>>(obj);
- return moduleToMethod(mod);
+ if (!is_submodule && mod->get_parameters().size() != 0) {
+ throw ErrorReport()
+ << "Attempted to inline a Module with parameters. "
+ "Stateful modules to be inlined must be submodules of the callee.";
+ }
+ return std::make_shared<ModuleValue>(mod);
} else if (py::isinstance<py::module>(obj)) {
return std::make_shared<PythonModuleValue>(obj);
} else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) {
py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
if (!compiled_fn.is(py::none())) {
auto mod = py::cast<std::shared_ptr<Module>>(compiled_fn);
- return moduleToMethod(mod);
+ return std::make_shared<ModuleValue>(mod);
}
}
return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
}
+ py::object overloads =
+ py::module::import("torch.jit").attr("_try_get_overloaded_fn")(obj);
+ if (!overloads.is_none()) {
+ return std::make_shared<OverloadedFunctionValue>(std::move(overloads));
+ }
+
return std::make_shared<PythonValue>(obj);
}
namespace {
Resolver pythonResolver(const ResolutionCallback& rcb) {
- return [rcb](const std::string& name, Function& m, const SourceRange& loc)
+ return [rcb](const std::string& name, Method& m, const SourceRange& loc)
-> std::shared_ptr<SugaredValue> {
AutoGIL ag;
py::object obj = rcb(name);
schema.is_varret());
}
-static Self moduleSelf(const std::shared_ptr<Module>& m) {
- return [m](Value* v) {
- v->setType(m->module_object()->type());
- return std::make_shared<ModuleValue>(v, m);
- };
-}
-
void initJitScriptBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
bool has_self) {
c10::optional<Self> self;
if (has_self) {
- m->class_compilation_unit().define(
- script, pythonResolver(rcb), moduleSelf(m));
- } else {
- m->_define_lowered(script, pythonResolver(rcb));
+ self = Self(std::make_shared<ModuleValue>(m));
}
- didFinishEmitModule(m);
+ defineMethodsInModule(m, script, pythonResolver(rcb), self);
})
.def(
"_create_methods",
for (auto& callback : rcbs) {
resolvers.push_back(pythonResolver(callback));
}
- m->class_compilation_unit().define(defs, resolvers, moduleSelf(m));
+ defineMethodsInModule(
+ m, defs, resolvers, Self(std::make_shared<ModuleValue>(m)));
+
// Stitch in default arguments for each Def if provided
auto defaults_it = defaults.begin();
auto defs_it = defs.begin();
while (defs_it != defs.end()) {
- auto& method = m->class_compilation_unit().get_function(
- (*defs_it).name().name());
+ auto& method = m->get_method((*defs_it).name().name());
method.setSchema(getSchemaWithNameAndDefaults(
defs_it->range(),
method.getSchema(),
auto& p = parameters[i];
py::tuple r(2);
result[i] = std::make_tuple(
- p.name(), autograd::as_variable_ref(p.value().toTensor()));
+ p.name(),
+ autograd::as_variable_ref(p.value().toTensor()));
}
return result;
})
[](Module& self,
const std::string& name,
std::shared_ptr<Graph> graph) {
- self._define_lowered(name, std::move(graph), {});
+ self.create_method(name, std::move(graph), {});
})
.def(
"_create_method_from_trace",
var_lookup_fn,
force_outplace,
input_tuple.size());
- self->_define_lowered(
- name, std::move(graph), std::move(parameters));
+ self->create_method(name, std::move(graph), std::move(parameters));
didFinishEmitModule(self);
})
.def(
[](Module& self) {
if (self.find_method("forward")) {
Method& m = self.get_method("forward");
- return m.get_executor().getDebugState();
+ return m.getDebugState();
}
throw std::runtime_error(
"Attempted to call get_debug_state on a Module without a compiled forward()");
[](Module& self) {
if (self.find_method("forward")) {
Method& m = self.get_method("forward");
- m.get_executor().debugDisableAutodiffSubgraphInlining();
+ m.debugDisableAutodiffSubgraphInlining();
}
})
.def(
}
Method* orig_method = orig->find_method(name);
- m->_define_lowered(
- name, orig_method->graph()->copy(), std::move(member_inputs));
+ m->create_method(name, orig_method->graph()->copy(), member_inputs);
});
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
method, tuple_slice(std::move(args), 1), std::move(kwargs));
})
.def_property_readonly("graph", [](Method& m) { return m.graph(); })
- .def(
- "propagate_shapes",
- [](Method& m, const std::vector<at::Tensor>& inputs, bool with_grad) {
- return propagate_shapes(
- *m.graph(), inputs, m.initial_ivalues(), with_grad);
- })
+ .def("propagate_shapes", &Method::propagate_shapes)
.def(
"propagate_and_assign_input_and_output_shapes",
- [](Method& m,
- const std::vector<at::Tensor>& inputs,
- std::vector<at::Tensor> outputs,
- bool with_grad,
- bool propagate) {
- return propagate_and_assign_input_and_output_shapes(
- *m.graph(),
- inputs,
- m.initial_ivalues(),
- outputs,
- with_grad,
- propagate);
- })
+ &Method::propagate_and_assign_input_and_output_shapes)
.def(
"initial_ivalues",
[](Method& m) {
})
.def(
"debug_disable_autodiff_subgraph_inlining",
- [](Method& m) {
- return m.get_executor().debugDisableAutodiffSubgraphInlining();
- })
+ &Method::debugDisableAutodiffSubgraphInlining)
.def("schema", &Method::getSchema)
- .def(
- "pretty_print_schema",
- [](Method& m) {
- const FunctionSchema& schema = m.getSchema();
- std::stringstream ss;
- ss << schema;
- return ss.str();
- })
+ .def("pretty_print_schema", &Method::pretty_print_schema)
.def(
"python_print",
[](Method& m) {
ResolutionCallback rcb,
FunctionDefaults defaults) {
auto def_f = def.withName("forward");
-
- mod->_define_lowered({def_f}, {pythonResolver(rcb)});
- auto& func = mod->lowered_methods().get_function("forward");
- func.setSchema(getSchemaWithNameAndDefaults(
- def.range(), func.getSchema(), def.name().name(), defaults));
- auto& func2 = mod->class_compilation_unit().get_function("forward");
- func2.setSchema(getSchemaWithNameAndDefaults(
- def.range(), func2.getSchema(), def.name().name(), defaults));
+ defineMethodsInModule(
+ mod, {def_f}, {pythonResolver(rcb)}, c10::nullopt);
+ auto& method = mod->get_method("forward");
+ method.setSchema(getSchemaWithNameAndDefaults(
+ def.range(), method.getSchema(), def.name().name(), defaults));
didFinishEmitModule(mod);
return mod;
});
m.def(
"_jit_script_class_compile",
- [](const ClassDef& classDef, ResolutionCallback rcb) {
- auto cu = std::make_shared<CompilationUnit>();
- auto classType = ClassType::create(classDef.name().name(), cu);
+ [](std::shared_ptr<Module> module,
+ const ClassDef& classDef,
+ ResolutionCallback rcb) {
+ auto classType = ClassType::create(classDef.name().name(), module);
std::vector<Resolver> rcbs;
std::vector<Def> methodDefs;
for (const auto& def : classDef.defs()) {
methodDefs.push_back(def);
rcbs.push_back(pythonResolver(rcb));
}
- cu->define(methodDefs, rcbs, simpleSelf(classType));
+ defineMethodsInModule(module, methodDefs, rcbs, Self(classType));
+ return module;
});
m.def("parse_type_comment", [](const std::string& comment) {
#include <c10/util/Exception.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/schema_matching.h>
namespace script {
struct RecursiveMethodCallError : public std::exception {};
-void placeholderCreator(Function&) {
+void placeholderCreator(Method&) {
throw RecursiveMethodCallError();
}
-void Function::ensure_defined() {
- try {
- if (function_creator_) {
- auto creator = function_creator_;
- function_creator_ = placeholderCreator;
- creator(*this);
- function_creator_ = nullptr;
- }
- } catch (RecursiveMethodCallError&) {
- throw ErrorReport() // TODO: once lower_first_class methods is removed
- // re-establish callsite info for debugging
- << " method '" << name() << "' is called recursively. "
- << "Recursive calls are not supported";
- }
-}
-
-Value* Function::try_emit_call(
+Value* try_emit_call_to(
Graph& graph,
const SourceRange& loc,
+ Method& callee,
c10::optional<NamedValue> self,
ArrayRef<NamedValue> args,
ArrayRef<NamedValue> kwargs,
std::stringstream& failure_messages,
+ Method* caller,
bool conv_tensors_to_nums) {
- ensure_defined();
- auto fn = this->graph();
+ try {
+ callee.ensure_defined();
+ } catch (RecursiveMethodCallError&) {
+ throw ErrorReport(loc)
+ << " method '" << callee.name()
+ << "' is called recursively involving this call site. "
+ << "Recursive calls are not supported";
+ }
+ auto fn = callee.graph();
auto matched_schema = tryMatchSchema(
- getSchema(),
+ callee.getSchema(),
loc,
graph,
std::move(self),
if (!matched_schema)
return nullptr;
- check_single_output();
- return inlineCallTo(graph, *fn, matched_schema->inputs).at(0);
+  // append the parameters captured by the callee method (which become
+  // parameters to _this_ method if they were not already)
+ for (const auto& member : callee.initial_ivalues()) {
+ if (!caller) {
+ throw ErrorReport(loc)
+ << " attempting to call a method with parameters/attributes"
+ " from a raw graph. File a bug report";
+ }
+ matched_schema->inputs.push_back(
+ caller->get_or_add_attribute(member));
+ }
+ callee.check_single_output();
+ return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
}
-Value* Function::emit_call(
- Graph& graph,
+Value* Method::emit_call_to(
const SourceRange& loc,
+ Method& callee,
ArrayRef<NamedValue> args,
ArrayRef<NamedValue> kwargs) {
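+  // calls can only be emitted while this method is still being built,
+  // i.e. before its GraphExecutor has been created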
+ AT_ASSERT(!executor);
std::stringstream failure_messages;
- if (auto result = try_emit_call(
- graph,
+ if (auto result = try_emit_call_to(
+ *graph(),
loc,
+ callee,
c10::nullopt,
args,
kwargs,
failure_messages,
+ this,
/*conv_tensors_to_nums=*/true)) {
return result;
}
throw ErrorReport(loc) << failure_messages.str();
}
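+// if the method has not been created yet, invoke its method_creator;
+// placeholderCreator is installed while it runs so that a recursive
+// definition is detected and reported by try_emit_call_to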
+void Method::ensure_defined() {
+ if (method_creator) {
+ auto creator = method_creator;
+ method_creator = placeholderCreator;
+ creator(*this);
+ method_creator = nullptr;
+ }
+}
+
void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) {
to_impl(device, dtype, non_blocking);
}
}
}
-// lower_first_class_method and lift_lowered_method are transitionary functions
-// used to translate between module-as-first-class code generation,
-// and module-as-special execution. Once module-as-first-class execution is
-// debugged, then we can remove both and remove the lowered_functions_ table.
-
-// remove the first module argument, replacing any access of its
-// parameters/attributes with extra_ivalue input Slots that hold what value to
-// pass into the graph
-std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
- const ModulePtr& self,
- Graph& g_,
- size_t self_offset = 0) {
- std::shared_ptr<Graph> g = g_.copy();
- std::vector<Slot> extra_ivalues;
- std::unordered_map<Slot, size_t> slot_to_offset;
- struct ToScan {
- ModulePtr mod;
- Node* n;
- size_t offset;
- };
- std::vector<ToScan> to_scan;
- std::vector<Node*> to_clean; // nodes that should be dead at the end
-
- auto getOrAddSlot = [&](const Slot& slot) -> Value* {
- auto it = slot_to_offset.find(slot);
- if (it != slot_to_offset.end()) {
- size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
- return g->inputs().at(ivalues_start + it->second);
- }
- extra_ivalues.emplace_back(slot);
- slot_to_offset[slot] = extra_ivalues.size() - 1;
- return g->addInput()->setType(slot.type());
- };
-
- auto self_value = g->inputs().at(self_offset);
-
- for (Use use : self_value->uses()) {
- to_scan.emplace_back(ToScan{self, use.user, use.offset});
- }
- while (to_scan.size() > 0) {
- auto e = to_scan.back();
- to_scan.pop_back();
-
- // when we lambda lift forks, first-class modules may be passed across
- // forks. This code recursively lowers the module in the fork call.
- if (e.n->kind() == prim::fork) {
- auto subgraph = e.n->g(attr::Subgraph);
- std::vector<Slot> new_slots;
- std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
- e.n->g_(attr::Subgraph, subgraph);
- for (const Slot& slot : new_slots) {
- e.n->addInput(getOrAddSlot(slot));
- }
- e.n->removeInput(e.offset);
- continue;
- }
- if (e.n->kind() != prim::GetAttr) {
- throw ErrorReport(e.n->getSourceLocation())
- << "temporary: the only valid use of a module is looking up an attribute";
- }
- Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name)));
- if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
- if (c->name() == "Module") {
- auto obj = slot.value().toObject();
- for (Use use : e.n->output()->uses()) {
- to_scan.emplace_back(ToScan{obj, use.user, use.offset});
- }
- to_clean.emplace_back(e.n);
- continue;
- }
- }
- e.n->output()->replaceAllUsesWith(getOrAddSlot(slot));
- e.n->destroy();
- }
-
- while (to_clean.size() > 0) {
- Node* n = to_clean.back();
- AT_ASSERT(!n->hasUses());
- n->destroy();
- to_clean.pop_back();
- }
- AT_ASSERT(!self_value->hasUses());
- g->eraseInput(self_offset);
-
- return std::make_pair(std::move(g), std::move(extra_ivalues));
-}
-
-Method& Module::lower_first_class_method(Function* fn) {
- fn->ensure_defined();
- auto lowered = lower_graph(module_object(), *fn->graph());
- Function& new_func =
- lowered_methods_.create_function(fn->name(), lowered.first);
-
- // generate the new schema
- // slice away the self argument
- std::vector<Argument> args(
- fn->getSchema().arguments().begin() + 1,
- fn->getSchema().arguments().end());
- size_t id = 0;
- for (const Slot& slot : lowered.second) {
- std::ostringstream ss;
- ss << "slot" << id++;
- args.emplace_back(ss.str(), slot.type());
- }
- new_func.setSchema(fn->getSchema().cloneWithArguments(std::move(args)));
- return _create_lowered_method(&new_func, std::move(lowered.second));
-}
-
-static void createFirstClassValues(
- Module* module,
- Value* self,
- std::unordered_map<Slot, Value*>& result) {
- auto& g = *self->owningGraph();
-
- std::vector<Node*> created;
- struct ToScan {
- Module* mod;
- Value* v; // value representing module in the graph
- };
- std::vector<ToScan> to_scan = {{module, self}};
-
- while (!to_scan.empty()) {
- auto s = to_scan.back();
- to_scan.pop_back();
- size_t offset = 0;
- for (const std::string& name :
- s.mod->module_object()->type()->attributeNames()) {
- Value* v = g.insertGetAttr(s.v, name);
- result[Slot(s.mod->module_object(), offset++)] = v;
- if (std::shared_ptr<Module> sub = s.mod->find_module(name)) {
- to_scan.emplace_back(ToScan{sub.get(), v});
- }
- }
- }
-}
-
-void Module::lift_lowered_method(Method& m) {
- auto graph = m.graph()->copy();
- Value* self = graph->insertInput(0, "self")->setType(module_object()->type());
- std::unordered_map<Slot, Value*> slot_to_value;
- if (!m.initial_ivalues().empty()) {
- WithInsertPoint guard(*graph->nodes().begin());
- createFirstClassValues(this, self, slot_to_value);
- }
-
- size_t orig_graph_inputs_size = graph->inputs().size();
- for (size_t i = 0; i < m.initial_ivalues().size(); ++i) {
- size_t input_offset = orig_graph_inputs_size - i - 1;
- size_t ivalue_offset = m.initial_ivalues().size() - i - 1;
- graph->inputs()
- .at(input_offset)
- ->replaceAllUsesWith(
- slot_to_value.at(m.initial_ivalues().at(ivalue_offset)));
- graph->eraseInput(input_offset);
- }
-
- if (!m.initial_ivalues().empty()) {
- // we added _all_ the submodules as first-class values but maybe did not use
- // them. So remove any dead attribute lookups
- EliminateDeadCode(graph);
- }
-
- Function& new_fn = class_cu().create_function(m.name(), std::move(graph));
- // created lifted schema
- // self argument is named '$self' to prevent accidental name collisions
- // with another input that the user named 'self'
- std::vector<Argument> new_args = {Argument("$self", module_object()->type())};
- const auto& lowered_args = m.function().getSchema().arguments();
- new_args.insert(
- new_args.end(),
- lowered_args.begin(),
- lowered_args.begin() + m.num_inputs());
- new_fn.setSchema(m.function().getSchema().cloneWithArguments(std::move(new_args)));
-}
-
-Method& Module::_create_lowered_method(
- Function* func,
- std::vector<Slot> member_inputs) {
- std::unique_ptr<Method> m(new Method(this, func, std::move(member_inputs)));
- return *insert(func->name(), methods_, EntityType::METHOD, std::move(m));
-}
-
-void Module::lift_lowered_methods(size_t start) {
- for (size_t i = start; i < lowered_methods_.get_functions().size(); ++i) {
- Method& m = _create_lowered_method(
- lowered_methods_.get_functions().at(i).get(), {});
- lift_lowered_method(m);
- }
-}
-
-void Module::_define_lowered(
- const std::vector<Def>& definitions,
- const std::vector<Resolver>& resolvers) {
- size_t start = lowered_methods_.get_functions().size();
- lowered_methods_.define(definitions, resolvers, nullptr);
- lift_lowered_methods(start);
- // call lift_lowered_method for each definition
-}
-
-void Module::_define_lowered(const std::string& src, const Resolver& resolver) {
- size_t start = lowered_methods_.get_functions().size();
- lowered_methods_.define(src, resolver, nullptr);
- lift_lowered_methods(start);
-}
-
-Method& Module::_define_lowered(
- std::string name,
- std::shared_ptr<Graph> graph,
- std::vector<Slot> slots) {
- Method& m = _create_lowered_method(
- &lowered_methods_.create_function(std::move(name), std::move(graph)),
- std::move(slots));
- lift_lowered_method(m);
- return m;
-}
-
-void Module::define(const std::string& src, const Resolver& resolver) {
- class_cu().define(
- src,
- resolver ? resolver : nativeResolver,
- simpleSelf(module_object()->type()));
-}
-
} // namespace script
} // namespace jit
} // namespace torch
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
-#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/core/function_schema.h>
// Map which stores filename to content.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
-using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>;
// A method in a module, e.g. f in:
//
// class M(ScriptModule):
using ModuleLookup =
std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
-struct TORCH_API Method {
- Method(Module* owner, Function* function, std::vector<Slot> initial_members)
+struct Method {
+ Method(
+ Module* owner,
+ std::string name,
+ bool optimize,
+ std::shared_ptr<Graph> graph,
+ std::vector<Slot> initial_members,
+ std::function<void(Method&)> method_creator)
: owner_(owner),
- function_(function),
- initial_ivalues_(std::move(initial_members)) {
- AT_ASSERT(function->num_inputs() >= initial_ivalues_.size());
- }
-
- // the module that contains this method.
- Module& owner() const {
- return *owner_;
+ name_(std::move(name)),
+ graph_(std::move(graph)),
+ optimize(optimize),
+ initial_ivalues_(std::move(initial_members)),
+ method_creator(std::move(method_creator)) {
+ AT_ASSERT(graph_->inputs().size() >= initial_ivalues_.size());
+ int i = graph_->inputs().size() - initial_ivalues_.size();
+ for (auto member : initial_ivalues_) {
+ initial_ivalue_index[member] = i++;
+ }
}
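+  // push the captured module Slots onto the stack and execute via the
+  // GraphExecutor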
void run(Stack& stack) {
for (auto input : initial_ivalues_) {
push(stack, input.value());
}
- function_->run(stack);
+ get_executor().run(stack);
}
+
void run(Stack&& stack) {
run(stack);
}
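+  // call the method like a function: check and normalize the inputs against
+  // the schema (appending any default values), then run and return the
+  // single output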
IValue operator()(std::vector<IValue> stack) {
- getSchema().checkAndNormalizeInputs(stack);
- for (auto input : initial_ivalues_) {
- push(stack, input.value());
- }
- // use run rather than operator() to skip the second schema check.
- function_->run(std::move(stack));
+ checkInputsAgainstSchema(stack);
+ run(stack);
return stack.front();
}
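+  // ask the executor which optimized graph it would run for these inputs
+  // (the captured Slots are appended to the example inputs first)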
+ std::shared_ptr<Graph> graph_for(Stack inputs) {
+ for (auto tp : initial_ivalues_) {
+ inputs.emplace_back(tp.value());
+ }
+ return get_executor().graphFor(inputs);
+ }
+ TORCH_API std::shared_ptr<Graph> graph() const {
+ return graph_;
+ }
+
+ TORCH_API const std::string& name() const {
+ return name_;
+ }
+  // emit a function call by inlining the callee's Graph into this one,
+  // adding any extra parameters necessary to make the call.
+  // defined here to keep details of member_input handling confined to this
+  // class
+ Value* emit_call_to(
+ const SourceRange& loc,
+ Method& callee,
+ ArrayRef<NamedValue> args,
+ ArrayRef<NamedValue> kwargs);
+
+ // if this isn't yet defined, run its method_creator function
+ TORCH_API void ensure_defined();
+
+ size_t num_inputs() const {
+ return graph()->inputs().size() - initial_ivalues_.size();
+ }
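+  // capture a module Slot as an extra graph input, reusing the existing
+  // input if this Slot was already captured; get_or_add_parameter
+  // additionally requires the Slot to hold a Tensor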
+ TORCH_API Value* get_or_add_parameter(Slot slot) {
+ AT_ASSERT(slot.value().isTensor());
+ return get_or_add_attribute(slot);
+ }
+ TORCH_API Value* get_or_add_attribute(Slot slot) {
+ auto it = initial_ivalue_index.find(slot);
+ if (it != initial_ivalue_index.end()) {
+ return graph()->inputs().at(it->second);
+ }
+ initial_ivalues_.push_back(slot);
+ initial_ivalue_index[slot] = graph()->inputs().size();
+ return graph()->addInput()->setType(slot.type());
+ }
+
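+  // replace each graph input's type with a DimensionedTensorType built from
+  // the corresponding value on the stack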
+ static void setInputTensorTypes(Graph& g, const Stack& stack) {
+ AT_ASSERT(stack.size() == g.inputs().size());
+ for (size_t i = 0; i < stack.size(); ++i) {
+ g.inputs().at(i)->setType(
+ DimensionedTensorType::create(stack.at(i).toTensor()));
+ }
+ }
+
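+  // copy the graph, seed the input types from the example tensors (plus the
+  // captured Slots), and run shape propagation over the copy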
+ std::shared_ptr<Graph> propagate_shapes(
+ std::vector<at::Tensor> inputs,
+ bool with_grad = false) {
+ auto retval = graph_->copy();
+ Stack stack;
+ stack.reserve(inputs.size() + initial_ivalues_.size());
+ for (at::Tensor& i : inputs) {
+ stack.emplace_back(std::move(i));
+ }
+ for (const Slot& inp : initial_ivalues_) {
+ stack.push_back(inp.value());
+ }
+ setInputTensorTypes(*retval, stack);
+ PropagateInputShapes(retval);
+ return retval;
+ }
+
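+  // like propagate_shapes, but also pins the copied graph's input and output
+  // types to the complete (sized) types of the given example tensors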
+ std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
+ std::vector<at::Tensor> inputs,
+ std::vector<at::Tensor> outputs,
+ bool with_grad = false,
+ bool propagate = true) {
+ auto retval = graph_->copy();
+ for (auto inp : initial_ivalues_) {
+ if (inp.value().isTensor()) {
+ inputs.push_back(inp.value().toTensor());
+ }
+ }
+ if (propagate) {
+ setInputTensorTypes(*retval, fmap<IValue>(inputs));
+ PropagateInputShapes(retval);
+ }
+ AT_ASSERT(retval->inputs().size() == inputs.size());
+ for (size_t i = 0; i < retval->inputs().size(); ++i) {
+ auto scalar_type = inputs[i].scalar_type();
+ auto sizes = inputs[i].sizes();
+ auto type =
+ torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+ retval->inputs()[i]->setType(type);
+ }
+ at::ArrayRef<Value*> output_values = retval->outputs();
+ // patch this to still work if we are returning a tuple of multiple values
+ if (output_values.at(0)->type()->kind() == TupleType::Kind) {
+ AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
+ output_values = output_values.at(0)->node()->inputs();
+ }
+ AT_ASSERT(output_values.size() == outputs.size());
+ for (size_t i = 0; i < retval->outputs().size(); ++i) {
+ auto scalar_type = outputs[i].scalar_type();
+ auto sizes = outputs[i].sizes();
+ auto type =
+ torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+ output_values[i]->setType(type);
+ }
+ return retval;
+ }
+
const std::vector<Slot>& initial_ivalues() const {
return initial_ivalues_;
}
- // proxies for underlying unbound Function
- std::shared_ptr<Graph> graph_for(Stack inputs) {
- for (auto tp : initial_ivalues_) {
- inputs.emplace_back(tp.value());
+ Method& setSchema(FunctionSchema schema_) {
+ schema = make_unique<FunctionSchema>(std::move(schema_));
+ return *this;
+ }
+
+ TORCH_API const FunctionSchema& getSchema() const {
+ if (schema == nullptr) {
+ schema = make_unique<FunctionSchema>(defaultSchemaFor(*this));
}
- return function_->get_executor().graphFor(inputs);
+ return *schema;
}
- std::shared_ptr<Graph> graph() const {
- return function_->graph();
+ std::string pretty_print_schema() const {
+ AT_ASSERT(schema);
+ std::stringstream ss;
+ ss << *schema;
+ return ss.str();
}
- const std::string& name() const {
- return function_->name();
+ GraphExecutorState getDebugState() {
+ return get_executor().getDebugState();
}
- size_t num_inputs() const {
- return function_->num_inputs() - initial_ivalues_.size();
+ void debugDisableAutodiffSubgraphInlining() {
+ return get_executor().debugDisableAutodiffSubgraphInlining();
}
- FunctionSchema getSchema() const {
- // we are required to slice out the slot inputs from the schema
- // we can't cache this because setSchema on the underlying function
- // will change the underlying schema
- auto sliced = ArrayRef<Argument>(function_->getSchema().arguments())
- .slice(0, num_inputs());
- return function_->getSchema().cloneWithArguments(sliced.vec());
+ bool is_optimized() const {
+ return optimize;
}
- GraphExecutor& get_executor() {
- return function_->get_executor();
+ // the module that contains this method.
+ Module& owner() const {
+ return *owner_;
}
- Function& function() const {
- return *function_;
+ void check_single_output() {
+ AT_CHECK(
+ graph()->outputs().size() == 1,
+ "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
}
private:
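+  // build a schema from the graph when none was set explicitly: one argument
+  // per non-captured graph input and one unnamed return per graph output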
+ static FunctionSchema defaultSchemaFor(const Method& method) {
+ std::vector<Argument> args;
+ std::vector<Argument> returns;
+ Graph& g = *method.graph();
+ size_t num_inputs = method.num_inputs();
+ for (size_t i = 0; i < num_inputs; ++i) {
+ const Value* v = g.inputs().at(i);
+ std::string name = v->hasUniqueName() ? v->uniqueNameBase()
+ : ("argument_" + std::to_string(i));
+ args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
+ }
+ for (size_t i = 0; i < g.outputs().size(); ++i) {
+ returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
+ }
+ return {method.name(), "", std::move(args), std::move(returns)};
+ }
+
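+  // lazily construct the GraphExecutor the first time it is needed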
+ GraphExecutor& get_executor() {
+ std::call_once(executor_init, [&] {
+ check_single_output();
+ executor = GraphExecutor(graph(), optimize);
+ });
+ return executor;
+ }
+
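+  // check the provided inputs against the schema: verify each value is a
+  // subvalue of the declared argument type, append default values for
+  // trailing arguments that were not provided, and error on any mismatch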
+ void checkInputsAgainstSchema(std::vector<IValue>& inputs) {
+ const auto& schema = getSchema();
+ // Do we have more inputs than the schema accepts?
+ AT_CHECK(
+ inputs.size() <= schema.arguments().size(),
+ "Expected at most ",
+ schema.arguments().size(),
+ " argument(s) for operator '",
+ schema.name(),
+ "', but received ",
+ inputs.size(),
+ " argument(s). Declaration: ",
+ schema);
+
+ for (size_t pos = 0; pos < schema.arguments().size(); ++pos) {
+ const auto& argument = schema.arguments()[pos];
+ if (pos < inputs.size()) {
+ if (!isSubvalueOf(inputs[pos], argument.type())) {
+ AT_ERROR(
+ "Expected value of type ",
+ *argument.type(),
+ " for argument '",
+ argument.name(),
+ "' in position ",
+ pos,
+ ", but instead got value of type ",
+ attemptToRecoverType(inputs[pos])->str(),
+ ". Declaration: ",
+ schema);
+ }
+ } else if (argument.default_value()) {
+ inputs.push_back(*argument.default_value());
+ } else {
+ AT_ERROR(
+ schema.name(),
+ "() is missing value for argument '",
+ argument.name(),
+ "'. Declaration: ",
+ schema);
+ }
+ }
+ }
+
  // Methods are uniquely owned by a single module. This raw pointer allows
// looking up the module.
Module* owner_;
- // Underlying unbound function
- Function* function_;
-
- // parameters and attributes loaded from the Module and appending
- // before calling function_
+ std::string name_;
+ std::shared_ptr<Graph> graph_; // for debugging and for inlining
+ bool optimize;
+
+ GraphExecutor executor; // for execution
+  // initial_ivalues is a list of additional arguments appended to the graph;
+  // they are inputs that come from the members of the Module or its
+  // submodules. Each is a pointer to a Slot in the module that owns the
+  // value. Parameters and submodules can only be _added_ to script Modules,
+  // which ensures these pointers always stay valid.
std::vector<Slot> initial_ivalues_;
+
+  // map from a Slot in initial_ivalues to the offset at which it appears in
+  // the graph; used to accelerate get_or_add_parameter/get_or_add_attribute
+ std::unordered_map<Slot, size_t> initial_ivalue_index;
+
+ // TODO: support that case where we allow _writes_ to parameters from
+ // compiled functions.
+ // This requires more sophisticated tracking of ssa values in Graphs so that
+ // stores to all modules can be lifted to the end of a graph execution.
+ // It also adds more complexity to adding actual module invocations
+ // to the executor, so currently it is not done.
+ // std::vector<at::Tensor*> member_outputs;
+
+ std::once_flag executor_init;
+
+  // an optional function that actually creates the method when
+  // emit_call_to(this, ...) is first called. This is used by the compiler so
+  // that it can construct methods out of order
+ std::function<void(Method&)> method_creator;
+
+ // if absent, then we generate a default schema based on the graph
+ // mutable because getSchema caches the default schema if one is requested
+ // before a call to setSchema
+ mutable std::unique_ptr<FunctionSchema> schema;
};
struct Module;
-struct TORCH_API Module {
+struct Module {
TH_DISALLOW_COPY_AND_ASSIGN(Module);
Module()
: name_("__main__"),
module_value_(c10::ivalue::Object::create(
- ClassType::createModuleType(std::make_shared<CompilationUnit>()),
- 0)) {}
+ ClassType::createModuleType(),
+ 0)),
+ optimize_(true) {}
- ~Module() {
- // ClassType own the compilation unit of their Functions, but each
- // Function has a self argument which owns the ClassType, created a
- // referernce cycle. By dropping all the methods of the module's class
- // here we break the cycle.
- class_cu().drop_all_functions();
- }
const std::string& name() const {
return name_;
}
// note this doesn't change the flags of existing methods just ones
// added afterward.
void set_optimized(bool o) {
- class_cu().set_optimized(o);
+ optimize_ = o;
}
bool is_optimized() const {
- return class_cu().is_optimized();
+ return optimize_;
}
IValue forward(std::vector<IValue> inputs) {
name,
attributes_,
EntityType::ATTRIBUTE,
- appendSlot(name, TensorType::get(), std::move(v)));
+ appendSlot(name, TensorType::get(),std::move(v)));
}
void register_parameter(
const std::string& name,
const TypePtr type,
IValue ivalue) {
- insert(
- name,
- attributes_,
- EntityType::ATTRIBUTE,
- appendSlot(name, type, ivalue));
+ insert(name, attributes_, EntityType::ATTRIBUTE, appendSlot(name, type, ivalue));
}
void register_module(
const std::string& name,
// AT_WARN(
// "Attempting to assign submodule '",
// name,
- // "' but it is already a submodule of another ScriptModule '",
- // module->parent_->name(), "'", " Modules of this form do not import
- // and export correctly. This use is deprecated and may be" " removed
- // in a future version.");
+ // "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'",
+ // " Modules of this form do not import and export correctly. This use is deprecated and may be"
+ // " removed in a future version.");
// }
module->parent_ = this;
module->name_ = name;
insert(name, modules_, EntityType::MODULE, std::move(module));
}
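+  // create a method from an existing graph (plus the module Slots it
+  // captures), or lazily via a creator callback used by the compiler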
+ Method& create_method(
+ const std::string& name,
+ std::shared_ptr<Graph> graph,
+ std::vector<Slot> member_inputs) {
+ AT_ASSERT(graph);
+ std::unique_ptr<Method> method(new Method(
+ this,
+ name,
+ optimize_,
+ std::move(graph),
+ std::move(member_inputs),
+ nullptr));
+ return *insert(name, methods_, EntityType::METHOD, std::move(method));
+ }
+
+ Method& create_method(
+ const std::string& name,
+ std::function<void(Method&)> creator) {
+ std::unique_ptr<Method> method(new Method(
+ this,
+ name,
+ optimize_,
+ std::make_shared<Graph>(),
+ {},
+ std::move(creator)));
+ return *insert(name, methods_, EntityType::METHOD, std::move(method));
+ }
+
Slot parameter_slot(const std::string& name) const {
return parameters_[get_offset(name, EntityType::PARAMETER)];
}
  // each module owns its methods. The reference returned here
  // is guaranteed to stay valid until this module has been destroyed
Method& get_method(const std::string& name) const {
- if (Method* method = find_method(name)) {
- return *method;
- }
- // temporary: force the error message
- // once the on-demand creation of Method is removed, this code
- // can be removed as well
- get_offset(name, EntityType::METHOD);
- AT_ERROR("unreachable");
+ return *methods_[get_offset(name, EntityType::METHOD)];
}
std::shared_ptr<Module> get_module(const std::string& name) const {
c10::ArrayRef<Slot> get_attributes() const {
return attributes_;
}
- const std::vector<std::unique_ptr<Method>>& get_methods() const {
- // force methods_ to be up to date by querying all
- // methods. This will go away when lowered_methods_ is deleted
- for (const auto& m : class_cu().get_functions()) {
- get_method(m->name());
- }
+ c10::ArrayRef<std::unique_ptr<Method>> get_methods() const {
return methods_;
}
auto offset = find_offset(name, EntityType::MODULE);
return offset ? modules_[*offset] : nullptr;
}
- Method* find_method(const std::string& name) const {
+ Method* find_method(const std::string& name) {
auto offset = find_offset(name, EntityType::METHOD);
- if (offset) {
- return methods_[*offset].get();
- }
-
- if (Function* fn = class_cu().find_function(name).get()) {
- // temporary lock because technically this is marked const,
- // but we have to update the internal Method cache.
- // This can be removed when class_cu() is the source of truth for
- // methods.
- std::lock_guard<std::recursive_mutex> guard(find_method_guard_);
- return &const_cast<Module*>(this)->lower_first_class_method(fn);
- }
-
- return nullptr;
+ return offset ? methods_[*offset].get() : nullptr;
}
void apply(std::function<void(Module&)> fn) {
for (auto& submod : get_modules()) {
/// destination is on the GPU or vice versa, the copy is performed
/// asynchronously with respect to the host. Otherwise, the argument has no
/// effect.
- void to(at::Device device, at::ScalarType dtype, bool non_blocking = false);
+ TORCH_API void to(
+ at::Device device,
+ at::ScalarType dtype,
+ bool non_blocking = false);
/// Recursively casts all parameters to the given dtype.
///
/// destination is on the GPU or vice versa, the copy is performed
/// asynchronously with respect to the host. Otherwise, the argument has no
/// effect.
- void to(at::ScalarType dtype, bool non_blocking = false);
+ TORCH_API void to(at::ScalarType dtype, bool non_blocking = false);
/// Recursively moves all parameters to the given device.
///
/// destination is on the GPU or vice versa, the copy is performed
/// asynchronously with respect to the host. Otherwise, the argument has no
/// effect.
- void to(at::Device device, bool non_blocking = false);
+ TORCH_API void to(at::Device device, bool non_blocking = false);
/// Run a method from this module.
///
mod->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
-
- for (auto& fn : class_cu().get_functions()) {
- curr->class_cu().clone_function(*fn);
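+  // re-create each method on the copy, remapping its captured Slots to the
+  // corresponding Slots of the destination module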
+ for (auto& method : get_methods()) {
+ std::vector<Slot> initial_ivalues;
+ for (auto& p : method->initial_ivalues()) {
+ initial_ivalues.push_back(parameter_remap.at(p));
+ }
+ curr->create_method(
+ method->name(), method->graph()->copy(), initial_ivalues);
}
}
enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD };
at::optional<EntityType> kind_of(const std::string& name) const {
- // force lazy creation of Method if needed
- // remove once lowered_methods_ is removed.
- find_method(name);
-
auto it = dict_.find(name);
if (it == dict_.end())
return at::nullopt;
return it->second.type;
}
- ModulePtr module_object() const {
- return module_value_;
- }
- CompilationUnit& class_compilation_unit() {
- return module_object()->type()->compilation_unit();
- }
- CompilationUnit& lowered_methods() const {
- return lowered_methods_;
- }
-
- // so that C++ users can easily add methods
- void define(const std::string& src, const Resolver& resolver = nullptr);
-
- void _define_lowered(
- const std::vector<Def>& definitions,
- const std::vector<Resolver>& resolvers);
- void _define_lowered(const std::string& src, const Resolver& resolver);
-
- Method& _define_lowered(
- std::string name,
- std::shared_ptr<Graph> graph,
- std::vector<Slot> slots);
-
private:
- Method& _create_lowered_method(
- Function* func,
- std::vector<Slot> member_inputs);
-
- Method& lower_first_class_method(Function* fn);
- void lift_lowered_method(Method& fn);
- void lift_lowered_methods(size_t start);
-
void to_impl(
const c10::optional<at::Device>& device,
const c10::optional<at::ScalarType>& dtype,
return Slot(module_value_, slot_index);
}
- CompilationUnit& class_cu() {
- return module_value_->type()->compilation_unit();
- }
- const CompilationUnit& class_cu() const {
- return module_value_->type()->compilation_unit();
- }
-
// modules have a single namespace, but spread over 4 different concepts:
// parameters, attributes, methods, and sub-modules
// we store individual lists of each concept, and a single map to
std::unordered_map<std::string, Entry> dict_;
std::string name_;
- ModulePtr module_value_;
-
- // back reference to parent of this Module if present
- Module* parent_ = nullptr;
- // Currently we are in a transitionary state
- // where we construct such first class functions but we lower them
- // to a form where the modules does not exist before execution.
+ c10::intrusive_ptr<at::ivalue::Object> module_value_;
- // So each Method is actually stored twice once in first-class Module
- // form and once in lowered form.
- // first-class: module_value_->type().compilation_unit() holds Functions that
- // treat modules as first class.
-
- // lowered: In this lowered form, all the attributes/parameters are appended
- // as additional inputs. lowered_methods_ holds this lowered form
- // mutable because it is a cache for class_cu() methods
- mutable CompilationUnit lowered_methods_;
- mutable std::recursive_mutex find_method_guard_;
+ // back reference to parent of this Module if present
+ Module* parent_ = nullptr;
+ bool optimize_;
};
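+// Usage sketch, assuming a std::shared_ptr<Module> m and a TorchScript
+// source string src:
+//   defineMethodsInModule(m, src, nativeResolver, c10::nullopt);
+//   Method& fwd = m->get_method("forward");
+//   IValue out = fwd({at::ones({2, 2})});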
-static void setInputTensorTypes(Graph& g, const Stack& stack) {
- AT_ASSERT(stack.size() == g.inputs().size());
- for (size_t i = 0; i < stack.size(); ++i) {
- g.inputs().at(i)->setType(
- DimensionedTensorType::create(stack.at(i).toTensor()));
- }
-}
-
-inline std::shared_ptr<Graph> propagate_shapes(
- Graph& graph,
- const std::vector<at::Tensor>& inputs,
- const std::vector<Slot>& initial_ivalues,
- bool with_grad = false) {
- auto retval = graph.copy();
- Stack stack;
- stack.reserve(inputs.size() + initial_ivalues.size());
- for (const at::Tensor& i : inputs) {
- stack.emplace_back(std::move(i));
- }
- for (const Slot& inp : initial_ivalues) {
- stack.push_back(inp.value());
- }
- setInputTensorTypes(*retval, stack);
- PropagateInputShapes(retval);
- return retval;
-}
-
-inline std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
+// returns nullptr and fills in failure_messages if the provided arguments do
+// not match the callee's schema
+Value* try_emit_call_to(
Graph& graph,
- std::vector<at::Tensor> inputs,
- const std::vector<Slot>& initial_ivalues,
- std::vector<at::Tensor> outputs,
- bool with_grad = false,
- bool propagate = true) {
- auto retval = graph.copy();
- for (auto inp : initial_ivalues) {
- if (inp.value().isTensor()) {
- inputs.push_back(inp.value().toTensor());
- }
- }
- if (propagate) {
- setInputTensorTypes(*retval, fmap<IValue>(inputs));
- PropagateInputShapes(retval);
- }
- AT_ASSERT(retval->inputs().size() == inputs.size());
- for (size_t i = 0; i < retval->inputs().size(); ++i) {
- auto scalar_type = inputs[i].scalar_type();
- auto sizes = inputs[i].sizes();
- auto type =
- torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
- retval->inputs()[i]->setType(type);
- }
- at::ArrayRef<Value*> output_values = retval->outputs();
- // patch this to still work if we are returning a tuple of multiple values
- if (output_values.at(0)->type()->kind() == TupleType::Kind) {
- AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
- output_values = output_values.at(0)->node()->inputs();
- }
- AT_ASSERT(output_values.size() == outputs.size());
- for (size_t i = 0; i < retval->outputs().size(); ++i) {
- auto scalar_type = outputs[i].scalar_type();
- auto sizes = outputs[i].sizes();
- auto type =
- torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
- output_values[i]->setType(type);
- }
- return retval;
-}
-
+ const SourceRange& loc,
+ Method& callee,
+ c10::optional<NamedValue> self,
+ ArrayRef<NamedValue> args,
+ ArrayRef<NamedValue> kwargs,
+ std::stringstream& failure_messages,
+    // when the callee captures no parameters/attributes (its initial_ivalues
+    // are empty), nullptr can be passed as caller.
+ Method* caller,
+ bool conv_tensors_to_nums);
} // namespace script
} // namespace jit
} // namespace torch
return emitBuiltinNode(*matched_schema, loc, graph, name);
}
}
- for (Function* method : builtin_functions) {
- if (auto result = method->try_emit_call(
+ for (Method* method : builtin_functions) {
+ if (auto result = try_emit_call_to(
graph,
loc,
+ *method,
self,
inputs,
attributes,
failure_messages,
+ nullptr,
allow_conversions)) {
return result;
}
c10::intrusive_ptr<c10::ivalue::Object> container_;
size_t offset_;
friend struct std::hash<Slot>;
- friend struct Module;
};
}}}
std::shared_ptr<SugaredValue> PrintValue::call(
const SourceRange& loc,
- Function& m,
+ Method& m,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
std::shared_ptr<SugaredValue> BuiltinFunction::call(
const SourceRange& loc,
- Function& m,
+ Method& m,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
// callable value that will resolve to foo(x, y, z) when called.
std::shared_ptr<SugaredValue> SimpleValue::attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) {
// Allow method-style casts on Tensor types. e.g. x.int()
if (value_->type()->isSubtypeOf(TensorType::get())) {
if (auto classType = value_->type()->cast<ClassType>()) {
// This is a class, emit the proper attribute lookup
if (auto method = classType->getMethod(field)) {
- return std::make_shared<MethodValue>(getValue(), *method);
+ return std::make_shared<MethodValue>(shared_from_this(), *method);
}
if (!classType->hasAttribute(field)) {
std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
const SourceRange& loc,
- Function& m,
+ Method& m,
const c10::optional<size_t>& size_hint) {
static const auto make_simple_value =
[](Value* v) -> std::shared_ptr<SugaredValue> {
void SimpleValue::setAttr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field,
Value* newValue) {
const auto classType = value_->type()->cast<ClassType>();
std::shared_ptr<SugaredValue> ClassValue::call(
const SourceRange& loc,
- Function& m,
+ Method& m,
// note: names for args will be 'argument 0', 'argument 1', etc..
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
// Generate a new object of the right type, then call `__init__` on it
auto& g = *m.graph();
- auto self = g.insertNode(g.createObject(type_))->output();
+ auto createNode = g.insertNode(g.createObject(type_));
+ auto self = std::make_shared<SimpleValue>(createNode->output());
auto initMethod = type_->getMethod("__init__");
AT_ASSERT(initMethod);
// Call the init function
MethodValue(self, *initMethod).call(loc, m, inputs, attributes, n_binders);
- return std::make_shared<SimpleValue>(self);
+ return self;
}
std::shared_ptr<SugaredValue> ClassValue::attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) {
if (field != "__new__") {
throw ErrorReport(loc) << "Tried to lookup unknown attribute on class";
// what can we do with this thing?
// use it as a value e.g. `this + 4`
- virtual Value* asValue(const SourceRange& loc, Function& m) {
+ virtual Value* asValue(const SourceRange& loc, Method& m) {
throw ErrorReport(loc) << kind() << " cannot be used as a value";
}
// select an attribute on it, e.g. `this.field`
virtual std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) {
throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
}
// assign an attribute on it, e.g. `this.field = newValue`
virtual void setAttr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field,
Value* newValue) {
throw ErrorReport(loc) << "attribute assignment is not defined on "
// a method invocation
virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
- Function& m,
+ Method& m,
const c10::optional<size_t>& size_hint = {}) {
throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
}
// call it like a function, e.g. `outputs = this(inputs)`
virtual std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& m,
+ Method& m,
// note: names for args will be 'argument 0', 'argument 1', etc..
at::ArrayRef<NamedValue> inputs_,
at::ArrayRef<NamedValue> attributes,
std::string kind() const override {
return "value";
}
- Value* asValue(const SourceRange& range, Function& m) override {
+ Value* asValue(const SourceRange& range, Method& m) override {
return value_;
}
NoneStatus isNone() override {
}
std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
- Function& m,
+ Method& m,
const c10::optional<size_t>& size_hint = {}) override;
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) override;
void setAttr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field,
Value* newValue) override;
}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& m,
+ Method& m,
at::ArrayRef<NamedValue> attributes,
at::ArrayRef<NamedValue> inputs,
size_t n_binders) override;
}
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) override {
return std::make_shared<BuiltinFunction>(
Symbol::fromQualString(name + "::" + field), c10::nullopt);
// n = Foo(constructor_arg)
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& m,
+ Method& m,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override;
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& field) override;
std::string kind() const override {
// defines how a method obtained from a module behaves in script
struct MethodValue : public SugaredValue {
- MethodValue(c10::optional<NamedValue> self, Function& method)
+ MethodValue(std::shared_ptr<SugaredValue> self, Method& method)
: self_(std::move(self)), method(method) {}
std::string kind() const override {
return "method";
}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& f,
+ Method& caller,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
- Graph& graph = *f.graph();
- if (self_) {
+ if (auto classType = dynamic_cast<SimpleValue*>(self_.get())) {
+ // If self_ is a class, then it will be expected as part of
+ // the schema. Add it to the front of the inputs.
std::vector<NamedValue> inputsWithSelf;
- inputsWithSelf.emplace_back(loc, self_->value(graph));
+ inputsWithSelf.emplace_back(loc, classType->getValue());
inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end());
return std::make_shared<SimpleValue>(
- method.emit_call(graph, loc, inputsWithSelf, attributes));
+ caller.emit_call_to(loc, method, inputsWithSelf, attributes));
}
return std::make_shared<SimpleValue>(
- method.emit_call(graph, loc, inputs, attributes));
+ caller.emit_call_to(loc, method, inputs, attributes));
}
private:
- c10::optional<NamedValue> self_;
- Function& method;
+ std::shared_ptr<SugaredValue> self_;
+ Method& method;
};
struct TORCH_API PrintValue : public SugaredValue {
}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& m,
+ Method& m,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override;
: BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
- Function& m,
+ Method& m,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
std::shared_ptr<SugaredValue> createObject(
const SourceRange& loc,
- Function& m,
+ Method& m,
const std::string& classname) {
if (classname != type_->name()) {
throw ErrorReport(loc)
return fmap(nvs, [&](const NamedValue& v) { return v.value(g); });
}
-static inline Self simpleSelf(const TypePtr& typ) {
- return [typ](Value* v) {
- v->setType(typ);
- return std::make_shared<SimpleValue>(v);
- };
-}
-
} // namespace script
} // namespace jit
} // namespace torch
return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
}
-void loadModule(const script::CompilationUnit& module) {
- for (const auto& method : module.get_functions()) {
+void loadModule(const std::shared_ptr<script::Module>& module) {
+ for (const auto& method : module->get_methods()) {
if (isHelperFunction(method->name()))
continue;
void loadFunctions() {
for (const std::string& str : functions) {
- script::CompilationUnit cu;
- cu.define(str, script::nativeResolver, nullptr);
+ auto cu = std::make_shared<script::Module>();
+ script::defineMethodsInModule(
+ cu, str, script::nativeResolver, c10::nullopt);
loadModule(cu);
}
}
return _jit_internal.boolean_dispatched.get(fn)
-def _try_get_overloaded_fn(mod, field):
- return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
+def _try_get_overloaded_fn(fn):
+ if not hasattr(fn, '__self__') or not isinstance(fn.__self__, ScriptModule):
+ # Only allow overloads for bound methods
+ return None
+ overloads = fn.__self__._overloads.get(fn.__name__, None)
+ if overloads is None:
+ return None
+ return [getattr(fn.__self__, overload) for overload in overloads]
def _try_compile_weak_script(fn):
return obj
if _rcb is None:
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
+ mod = ScriptModule()
if inspect.isclass(obj):
if not _is_new_style_class(obj):
raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'")
ast = get_jit_class_def(obj)
- _jit_script_class_compile(ast, _rcb)
+ _jit_script_class_compile(mod, ast, _rcb)
_add_script_class(obj, obj.__name__)
return obj
else:
- mod = ScriptModule()
ast = get_jit_def(obj)
_jit_script_compile(mod, ast, _rcb, get_default_args(obj))
- # Forward docstrings
- mod.__doc__ = obj.__doc__
- return mod
+ # Forward docstrings
+ mod.__doc__ = obj.__doc__
+ return mod
ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))