const torch::jit::script::Module& module,
Predicate predicate) {
for (const auto& parameter : module.get_parameters()) {
- AT_ASSERT(predicate(*parameter->slot()));
+ AT_ASSERT(predicate(parameter->slot()->toTensor()));
}
for (const auto& child : module.get_modules()) {
check_all_parameters(*child->module, predicate);
def emitModuleHook(self, module):
def copy_structure_and_params(m):
c = torch.jit.ScriptModule()
- for name, v, buffer in m._get_parameters():
- c._register_parameter(name, v, buffer)
+ for name, v in m._get_parameters():
+ c._register_parameter(name, v, False)
+ for name, the_type, v in m._get_attributes():
+ c._register_attribute(name, the_type, v)
for name, s in m._get_modules():
c._register_module(name, copy_structure_and_params(s))
return c
torch.nn.BatchNorm2d(100),
torch.nn.BatchNorm2d(100, affine=False)]:
getattr(clazz, mode)()
-
input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
torch.randn(20, 100, 35, 45)
-
traced = torch.jit.trace(clazz, (input,))
imported = self.getExportImportCopy(traced)
x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
self.checkScript(fn, (a_dict, ('a', 'c')))
+ def test_module_attrs(self):
+ class M(torch.jit.ScriptModule):
+ def __init__(self, table):
+ super(M, self).__init__()
+ self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])
+ self.x = torch.nn.Parameter(torch.tensor([100.0]))
+
+ @torch.jit.script_method
+ def forward(self, key):
+ # type: (str) -> Tensor
+ return self.table[key] + self.x
+
+ with self.disableModuleHook():
+ # TODO: re-enable module hook when Python printing of attributes is
+ # supported
+ m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
+ self.assertEqual(m("c"), torch.tensor([103]))
+
def test_tensor_import_export(self):
@torch.jit.script
def foo(x):
const std::string& key,
Tensor& tensor,
bool is_buffer) {
- auto* read_tensor = module_->find_parameter(key);
- AT_CHECK(read_tensor != nullptr, "No such serialized tensor '", key, "'");
+ auto param = module_->find_parameter(key);
+ auto buffer = module_->find_buffer(key);
+ AT_CHECK(
+ param != nullptr || buffer != nullptr,
+ "No such serialized tensor '",
+ key,
+ "'");
// clang-format off
+ auto read_param = is_buffer ? buffer : param;
+ auto read_tensor = read_param->slot()->toTensor();
AT_CHECK(
- read_tensor->is_buffer == is_buffer,
+ bool(buffer) == is_buffer,
"Expected deserialized tensor for key '", key,
"' to ", is_buffer ? "not " : "", "be a buffer, but it was not");
// clang-format on
if (tensor.defined()) {
torch::NoGradGuard guard;
- if (tensor.device() != read_tensor->slot()->device()) {
- tensor.set_data(autograd::Variable(*read_tensor->slot()).data());
+ if (tensor.device() != read_tensor.device()) {
+ tensor.set_data(autograd::Variable(read_tensor).data());
} else {
- tensor.set_(*read_tensor->slot());
+ tensor.set_(read_tensor);
}
} else {
- tensor = std::move(*read_tensor->slot());
+ tensor = std::move(read_tensor);
}
}
torch::ModuleDef* module_def);
void convertParameter(
- const script::NamedParameter& param,
- torch::ParameterDef* param_def);
+ const script::NamedIValue& param,
+ torch::ParameterDef* param_def,
+ bool is_parameter);
std::ofstream ofs_;
caffe2::serialize::PyTorchStreamWriter writer_;
module_def->set_optimize(module.is_optimized());
for (const auto& elem : module.get_parameters()) {
torch::ParameterDef* param_def = module_def->add_parameters();
- convertParameter(elem.value(), param_def);
+ convertParameter(elem.value(), param_def, /*is_buffer=*/false);
+ }
+ for (const auto& elem : module.get_attributes()) {
+ if (elem.value().type->isSubtypeOf(TensorType::get())) {
+ torch::ParameterDef* param_def = module_def->add_parameters();
+ convertParameter(elem.value(), param_def, /*is_buffer=*/true);
+ }
}
std::stringstream module_name;
}
void ScriptModuleSerializer::convertParameter(
- const script::NamedParameter& param,
- torch::ParameterDef* param_def) {
- param_def->set_name(param.name);
- param_def->set_is_buffer(param.is_buffer);
- param_def->set_tensor_id(addTensor(*param.slot()));
+ const script::NamedIValue& param,
+ torch::ParameterDef* param_def,
+ bool is_parameter) {
+ param_def->set_name(param.name_);
+ param_def->set_is_buffer(is_parameter);
+ param_def->set_tensor_id(addTensor(param.slot()->toTensor()));
}
// Pretty printing for ONNX
ScriptModuleDeserializer(std::istream* is);
explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
void deserialize(
- ModuleLookup module_lookup,
+ script::ModuleLookup module_lookup,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files);
caffe2::serialize::PyTorchStreamReader reader_;
// this is a hack to make sure the script module created in C++ is the
// same as created in Python
- ModuleLookup moduleLookup_;
+ script::ModuleLookup moduleLookup_;
c10::optional<at::Device> device_;
std::vector<std::string> moduleStack_;
: reader_(std::move(rai)) {}
void ScriptModuleDeserializer::deserialize(
- ModuleLookup module_lookup,
+ script::ModuleLookup module_lookup,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
torch::ModelDef model_def;
for (int i = 0; i < module_def.parameters_size(); ++i) {
const torch::ParameterDef& param_def = module_def.parameters(i);
at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
- module->register_parameter(param_def.name(), tensor, param_def.is_buffer());
+ if (param_def.is_buffer()) {
+ module->register_buffer(param_def.name(), tensor);
+ } else {
+ module->register_parameter(param_def.name(), tensor, /*is_buffer=*/false);
+ }
}
if (module_def.has_torchscript_arena()) {
at::DataPtr data;
} // namespace
void import_ir_module(
- ModuleLookup module_lookup,
+ script::ModuleLookup module_lookup,
std::istream& in,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
}
void import_ir_module(
- ModuleLookup module_lookup,
+ script::ModuleLookup module_lookup,
const std::string& filename,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
}
void import_ir_module(
- ModuleLookup module_lookup,
+ script::ModuleLookup module_lookup,
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device,
script::ExtraFilesMap& extra_files) {
namespace torch {
namespace jit {
-using ModuleLookup = std::function<std::shared_ptr<script::Module>(
- const std::vector<std::string>&)>;
-
static script::ExtraFilesMap default_extra_files;
TORCH_API void import_ir_module(
- ModuleLookup module_lookup,
+ script::ModuleLookup module_lookup,
const std::string& filename,
c10::optional<c10::Device> device = c10::nullopt,
script::ExtraFilesMap& extra_files = default_extra_files);
TORCH_API void import_ir_module(
- ModuleLookup module_lookup,
+ script::ModuleLookup module_lookup,
std::istream& in,
c10::optional<c10::Device> device = c10::nullopt,
script::ExtraFilesMap& extra_files = default_extra_files);
TORCH_API void import_ir_module(
- ModuleLookup module_lookup,
+ script::ModuleLookup module_lookup,
std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
c10::optional<c10::Device> device = c10::nullopt,
script::ExtraFilesMap& extra_files = default_extra_files);
const std::string& field) override {
if (script::NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleAccessorValue>(v->module);
- } else if (script::NamedParameter* v = module->find_parameter(field)) {
+ } else if (script::NamedIValue* v = module->find_parameter(field)) {
+ return std::make_shared<script::SimpleValue>(
+ m.get_or_add_parameter(v->slot()));
+ } else if (script::NamedIValue* v = module->find_buffer(field)) {
return std::make_shared<script::SimpleValue>(
m.get_or_add_parameter(v->slot()));
} else if (script::Method* m = module->find_method(field)) {
void createTensorToParameterNameMap(
const script::Module& module,
const QualifiedNamePtr& prefix,
- std::unordered_map<at::Tensor*, QualifiedNamePtr>& result) {
+ std::unordered_map<IValue*, QualifiedNamePtr>& result) {
for (const auto& elem : module.get_parameters()) {
- const script::NamedParameter& param = elem.value();
- result[param.slot()] = QualifiedName::create(prefix, param.name);
+ const script::NamedIValue& param = elem.value();
+ result[param.slot()] = QualifiedName::create(prefix, param.name_);
+ }
+ for (const auto& elem : module.get_attributes()) {
+ const script::NamedIValue& param = elem.value();
+ result[param.slot()] = QualifiedName::create(prefix, param.name_);
}
for (const auto& elem : module.get_modules()) {
createTensorToParameterNameMap(
}
}
void printMethod(script::Method& method) {
- std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+ std::unordered_map<IValue*, QualifiedNamePtr> parameter_names;
createTensorToParameterNameMap(
method.owner(), QualifiedName::create("self"), parameter_names);
printMethod(method, parameter_names);
}
void printMethod(
script::Method& method,
- const std::unordered_map<at::Tensor*, QualifiedNamePtr>&
- parameter_names) {
- std::vector<std::string> param_names = fmap(
- method.params(),
- [&](at::Tensor* slot) { return parameter_names.at(slot)->str(); });
+ const std::unordered_map<IValue*, QualifiedNamePtr>&
+ extra_ivalue_names) {
+ std::vector<std::string> ivalue_names = fmap(
+ method.initial_ivalues(),
+ [&](IValue* slot) { return extra_ivalue_names.at(slot)->str(); });
const std::string& name = method.name();
Graph& graph = *method.graph();
auto defaults = fmap(
method.getSchema().arguments(),
[](const Argument& arg) { return arg.default_value(); });
- printFunction(graph, name, defaults, param_names);
+ printFunction(graph, name, defaults, ivalue_names);
}
void printModule(script::Module& module) {
- std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+ std::unordered_map<IValue*, QualifiedNamePtr> parameter_names;
createTensorToParameterNameMap(
module, QualifiedName::create("self"), parameter_names);
for (auto& method : module.get_methods()) {
py_dict[toPyObject(IValue{pair.first})] = toPyObject(IValue{pair.second});
}
return std::move(py_dict);
-
} else {
AT_ERROR("Missing cases in 'toPyObject'! File a bug report.");
}
const auto& param_list = module_->get_parameters().items();
for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
auto& param = *it;
- if (!param->is_buffer) {
- params.push_back(caller.get_or_add_parameter(param->slot()));
- }
+ params.push_back(caller.get_or_add_parameter(param->slot()));
}
auto list = caller.graph()->createList(TensorType::get(), params);
caller.graph()->insertNode(list);
std::shared_ptr<Module> module_;
};
+
// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
// it adds a buffer 'training' to the model if one doesn't exist
// and then loads that parameter, casting it to bool
if (field == "training") {
- NamedParameter* v = module->find_parameter(field);
+ NamedIValue* v = module->find_buffer(field);
if (!v) {
py::object py_module = py::cast(module);
bool training = py::cast<bool>(py::getattr(py_module, "training"));
auto t =
autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
- module->register_parameter("training", std::move(t), true);
- v = module->find_parameter(field);
+ module->register_buffer("training", std::move(t));
+ v = module->find_buffer(field);
}
Value* the_tensor = m.get_or_add_parameter(v->slot());
Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
return std::make_shared<ModuleValue>(v->module);
} else if (Method* v = module->find_method(field)) {
return std::make_shared<MethodValue>(shared_from_this(), *v);
- } else if (NamedParameter* v = module->find_parameter(field)) {
+ } else if (NamedIValue* v = module->find_parameter(field)) {
return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
+ } else if (NamedIValue* v = module->find_attribute(field)) {
+ return std::make_shared<SimpleValue>(
+ m.get_or_add_attribute(v->type, v->slot()));
}
+
// This can also be a call to a non-script module, or a plain
// python method. If so return this as a python value.
py::object py_module = py::cast(module);
}
static void gatherParametersAndBuffers(
- std::vector<at::Tensor*>& values,
+ std::vector<IValue*>& values,
const Module& m) {
for (auto& param : m.get_parameters()) {
values.push_back(param->slot());
}
+ for (auto& param : m.get_attributes()) {
+ if (param->type->isSubtypeOf(TensorType::get())) {
+ values.push_back(param->slot());
+ }
+ }
for (const auto& sub : m.get_modules()) {
gatherParametersAndBuffers(values, *sub->module);
}
},
py::return_value_policy::reference_internal)
.def("_register_parameter", &Module::register_parameter)
+ .def(
+ "_register_attribute",
+ [](Module& self, std::string name, TypePtr type, py::object value) {
+ self.register_attribute(name, type, toIValue(value, type));
+ })
.def("_register_module", &Module::register_module)
+ .def("_register_buffer", &Module::register_buffer)
.def("_set_parameter", &Module::set_parameter)
.def("_get_parameter", &Module::get_parameter)
+ .def("_get_buffer", &Module::get_buffer)
.def("_get_module", &Module::get_module)
.def(
"_get_modules",
py::tuple result(parameters.size());
for (size_t i = 0; i < parameters.size(); ++i) {
auto& p = parameters[i];
+ py::tuple r(2);
+ result[i] = std::make_tuple(
+ p.key(),
+ autograd::as_variable_ref(p->slot()->toTensor()));
+ }
+ return result;
+ })
+ .def(
+ "_get_attributes",
+ [](Module& self) -> py::tuple {
+ auto& attributes = self.get_attributes();
+ py::tuple result(attributes.size());
+ for (size_t i = 0; i < attributes.size(); ++i) {
+ auto& buffer = attributes[i];
py::tuple r(3);
+ IValue v = *buffer->slot();
result[i] = std::make_tuple(
- p.key(), autograd::as_variable_ref(*p->slot()), p->is_buffer);
+ buffer.key(),
+ buffer->type,
+ toPyObject(std::move(v)));
}
return result;
})
.def(
"_has_parameter",
- [](Module& self, const std::string& name) {
- if (auto r = self.find_parameter(name)) {
- return !r->is_buffer;
- }
- return false;
+ [](Module& self, const std::string& name) -> bool {
+ return self.find_parameter(name);
})
.def(
"_has_buffer",
- [](Module& self, const std::string& name) {
- if (auto r = self.find_parameter(name)) {
- return r->is_buffer;
- }
- return false;
+ [](Module& self, const std::string& name) -> bool {
+ return self.find_buffer(name);
})
.def(
"_has_module",
bool force_outplace) {
// prereq: Module's buffers and parameters are unique
// this was ensured in python before calling this function
- std::vector<at::Tensor*> parameters;
+ std::vector<IValue*> parameters;
gatherParametersAndBuffers(parameters, *self);
Stack inputs = toStack(input_tuple);
- for (at::Tensor* param : parameters) {
+ for (IValue* param : parameters) {
inputs.emplace_back(*param);
}
auto graph = tracer::createGraphByTracing(
std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
params,
std::shared_ptr<Module> orig) {
- std::vector<at::Tensor*> member_inputs;
+ std::vector<IValue*> member_inputs;
for (auto& p : params) {
- NamedParameter* np =
+ NamedIValue* np =
std::get<0>(p)->find_parameter(std::get<1>(p));
+ if (np == nullptr) {
+ np = std::get<0>(p)->find_buffer(std::get<1>(p));
+ }
AT_ASSERT(np != nullptr);
member_inputs.push_back(np->slot());
}
Method* orig_method = orig->find_method(name);
- m->create_method(name, orig_method->graph()->copy(), member_inputs);
+ m->create_method(
+ name, orig_method->graph()->copy(), member_inputs);
});
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
.def(
"propagate_and_assign_input_and_output_shapes",
&Method::propagate_and_assign_input_and_output_shapes)
- .def("params", &Method::params)
+ .def("initial_ivalues", &Method::initial_ivalues)
.def(
"graph_for",
[](py::args args, py::kwargs kwargs) {
// parameters to callee method (which become parameters to _this_ method
// if they were not already)
- for (at::Tensor* member : callee.params()) {
+ for (auto member : callee.initial_ivalues()) {
if (!caller) {
throw ErrorReport(loc)
- << " attempting to call a method with parameters from a raw graph. File a bug report";
+ << " attempting to call a method with parameters/attributes"
+ " from a raw graph. File a bug report";
}
matched_schema->inputs.push_back(caller->get_or_add_parameter(member));
}
// Then convert every of our parameters.
for (auto& parameter : parameters) {
// Need to access the `at::Tensor` as a `Variable` here.
- autograd::Variable variable = *parameter->slot();
+ autograd::Variable variable = parameter.value().slot()->toTensor();
at::Tensor data = variable.data();
// Use the data's original device or dtype if not supplied here.
auto new_data = data.to(
struct Module;
+using ModuleLookup = std::function<std::shared_ptr<Module>(
+ const std::vector<std::string>&)>;
+
struct Method {
Method(
Module* owner,
std::string name,
bool optimize,
std::shared_ptr<Graph> graph,
- std::vector<at::Tensor*> initial_members,
+ std::vector<IValue*> initial_members,
std::function<void(Method&)> method_creator)
: owner_(owner),
name_(std::move(name)),
graph_(std::move(graph)),
optimize(optimize),
- member_inputs(std::move(initial_members)),
+ initial_ivalues_(std::move(initial_members)),
method_creator(std::move(method_creator)) {
- AT_ASSERT(graph_->inputs().size() >= member_inputs.size());
- int i = graph_->inputs().size() - member_inputs.size();
- for (at::Tensor* member : member_inputs) {
- member_input_index[member] = i++;
+ AT_ASSERT(graph_->inputs().size() >= initial_ivalues_.size());
+ int i = graph_->inputs().size() - initial_ivalues_.size();
+ for (auto member : initial_ivalues_) {
+ initial_ivalue_index[member] = i++;
}
}
void run(Stack& stack) {
- for (at::Tensor* tp : member_inputs) {
- stack.emplace_back(*tp);
+ for (auto input : initial_ivalues_) {
+ push(stack, *input);
}
get_executor().run(stack);
}
}
std::shared_ptr<Graph> graph_for(Stack inputs) {
- for (at::Tensor* tp : member_inputs) {
+ for (auto tp : initial_ivalues_) {
inputs.emplace_back(*tp);
}
return get_executor().graphFor(inputs);
TORCH_API void ensure_defined();
size_t num_inputs() const {
- return graph()->inputs().size() - member_inputs.size();
+ return graph()->inputs().size() - initial_ivalues_.size();
+ }
+ TORCH_API Value* get_or_add_parameter(IValue* slot) {
+ AT_ASSERT(slot->isTensor());
+ return get_or_add_attribute(TensorType::get(), slot);
}
- TORCH_API Value* get_or_add_parameter(at::Tensor* slot) {
- auto it = member_input_index.find(slot);
- if (it != member_input_index.end()) {
+
+ TORCH_API Value* get_or_add_attribute(TypePtr type, IValue* slot) {
+ auto it = initial_ivalue_index.find(slot);
+ if (it != initial_ivalue_index.end()) {
return graph()->inputs().at(it->second);
}
- // add it as a new parameter
- member_inputs.push_back(slot);
- member_input_index[slot] = graph()->inputs().size();
- return graph()->addInput();
+ initial_ivalues_.push_back(slot);
+ initial_ivalue_index[slot] = graph()->inputs().size();
+ return graph()->addInput()->setType(type);
}
std::shared_ptr<Graph> propagate_shapes(
bool with_grad = false) {
auto retval = graph_->copy();
Stack stack;
- stack.reserve(inputs.size() + member_inputs.size());
+ stack.reserve(inputs.size() + initial_ivalues_.size());
for (at::Tensor& i : inputs) {
stack.emplace_back(std::move(i));
}
- for (at::Tensor* inp : member_inputs) {
+ for (IValue* inp : initial_ivalues_) {
stack.push_back(*inp);
}
const auto size = stack.size();
bool with_grad = false,
bool propagate = true) {
auto retval = graph_->copy();
- for (auto inp : member_inputs) {
- inputs.push_back(*inp);
+ for (auto inp : initial_ivalues_) {
+ if (inp->isTensor()) {
+ inputs.push_back(inp->toTensor());
+ }
}
if (propagate) {
setInputTypes(
return retval;
}
- std::vector<at::Tensor*> params() const {
- return member_inputs;
+ const std::vector<IValue*>& initial_ivalues() const {
+ return initial_ivalues_;
}
Method& setSchema(FunctionSchema schema_) {
bool optimize;
GraphExecutor executor; // for execution
- // member_inputs are a list of additional arguments appended to graph that are
- // inputs that come from the members of the Module or its submodules.
+ // initial_ivalues are a list of additional arguments appended to graph
+ // that are inputs that come from the members of the Module or its submodules.
// each is a pointer to a slot in the module that owns this parameter
// parameters and submodules can only be _added_ to script Modules to ensure
// these pointers always stay valid
- std::vector<at::Tensor*> member_inputs;
+ std::vector<IValue*> initial_ivalues_;
- // map from a at::Tensor* in member_inputs to the offset it appears at
+ // map from a IValue* in initial_ivalues to the offset it appears at
// in graph. used to accelerate get_or_add_parameter
- std::unordered_map<at::Tensor*, size_t> member_input_index;
+ std::unordered_map<IValue*, size_t> initial_ivalue_index;
// TODO: support that case where we allow _writes_ to parameters from
// compiled functions.
std::shared_ptr<Module> module;
};
-struct NamedParameter {
- NamedParameter(std::string name, at::Tensor tensor, bool is_buffer)
- : name(std::move(name)),
- is_buffer(is_buffer),
- parameter(torch::make_unique<at::Tensor>(std::move(tensor))) {}
+struct NamedIValue {
+ NamedIValue(std::string name, TypePtr type, IValue ivalue)
+ : name_(name),
+ type(type),
+ ivalue(torch::make_unique<IValue>(std::move(ivalue))) {}
- const std::string name;
- bool is_buffer; // buffers are part of the module state but
- // are not modified by optimizers during SGD
- at::Tensor* slot() const {
- return parameter.get();
+ IValue* slot() const {
+ return ivalue.get();
}
-
- private:
- // the extra level of indirection allows Methods to safely store pointers
- // to the slots where parameters are kept while also allow parameters
- // to be reassigned
- std::unique_ptr<at::Tensor> parameter;
+ const std::string name_;
+ const TypePtr type;
+ std::unique_ptr<IValue> ivalue;
};
struct Module {
Module()
: modules("Module"),
parameters("Parameter"),
+ attributes("Attributes"),
methods("Method"),
optimize(true) {}
return get_method("forward")(std::move(inputs));
}
+ void register_buffer(const std::string& name, autograd::Variable v) {
+ if (auto b = attributes.find(name)) {
+ AT_ASSERT(b->type->isSubtypeOf(TensorType::get()));
+ *b->slot() = v;
+ return;
+ }
+ attributes.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
+ }
void register_parameter(
const std::string& name,
autograd::Variable v,
bool is_buffer) {
+ if (is_buffer) {
+ register_buffer(name, std::move(v));
+ return;
+ }
if (auto p = parameters.find(name)) {
*p->slot() = v;
- p->is_buffer = is_buffer;
return;
}
- parameters.insert(name, NamedParameter(name, std::move(v), is_buffer));
+ parameters.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
+ }
+ void register_attribute(
+ const std::string& name,
+ const TypePtr type,
+ IValue ivalue) {
+ attributes.insert(name, NamedIValue(name, type, ivalue));
}
void register_module(
const std::string& name,
Method& create_method(
const std::string& name,
std::shared_ptr<Graph> graph,
- std::vector<at::Tensor*> member_inputs) {
+ std::vector<IValue*> member_inputs) {
AT_ASSERT(graph);
std::unique_ptr<Method> method(new Method(
this,
return *methods.insert(name, std::move(method));
}
- at::Tensor* parameter_slot(const std::string& name) const {
+ IValue* parameter_slot(const std::string& name) const {
return parameters[name].slot();
}
}
autograd::Variable get_parameter(const std::string& name) const {
- return autograd::as_variable_ref(*parameter_slot(name));
+ return autograd::as_variable_ref(parameter_slot(name)->toTensor());
+ }
+ autograd::Variable get_buffer(const std::string& name) const {
+ return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor());
}
// each module owns its method. The reference returned here
const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
return modules;
}
- const torch::OrderedDict<std::string, NamedParameter>& get_parameters()
+ const torch::OrderedDict<std::string, NamedIValue>& get_parameters()
const {
return parameters;
}
+ const torch::OrderedDict<std::string, NamedIValue>& get_attributes()
+ const {
+ return attributes;
+ }
const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
const {
return methods;
}
- NamedParameter* find_parameter(const std::string& name) {
+ NamedIValue* find_parameter(const std::string& name) {
return parameters.find(name);
}
+ NamedIValue* find_attribute(const std::string& name) {
+ return attributes.find(name);
+ }
+ NamedIValue* find_buffer(const std::string& name) {
+ auto b = attributes.find(name);
+ if (b && b->type->isSubtypeOf(TensorType::get())) {
+ return b;
+ }
+ return nullptr;
+ }
NamedModule* find_module(const std::string& name) {
return modules.find(name);
}
for (auto& submod : get_modules()) {
submod->module->train(on);
}
- register_parameter("training", torch::tensor(on ? 1 : 0, at::kLong), /*is_buffer=*/true);
+ register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
}
/// Calls train(false) to enable "eval" mode.
/// Do not override this method, override `train()` instead.
}
/// True if the module is in training mode.
bool is_training() {
- if (auto p = find_parameter("training")) {
- return p->slot()->item<int64_t>() == 1;
+ if (auto p = find_buffer("training")) {
+ return p->slot()->toTensor().item<int64_t>() == 1;
}
// We are in training mode by default
return true;
const ExtraFilesMap& extra_files = ExtraFilesMap());
void copy_into(
- std::function<std::shared_ptr<Module>(std::vector<std::string>)>
- module_lookup,
+ ModuleLookup module_lookup,
// parameter_remap is needed when a parent module uses a parameter of a
// submodule
- std::unordered_map<at::Tensor*, at::Tensor*>& parameter_remap,
+ std::unordered_map<IValue*, IValue*>& parameter_remap,
std::vector<std::string> names = {}) const {
auto curr = module_lookup(names);
for (auto& kv : parameters) {
curr->register_parameter(
- kv.key(), *kv.value().slot(), kv.value().is_buffer);
+ kv.key(),
+ kv.value().slot()->toTensor(),
+ /*is_buffer=*/false);
parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
}
+ for (auto& kv : attributes) {
+ if (!kv.value().type->isSubtypeOf(TensorType::get())) {
+ continue;
+ }
+ curr->register_buffer(
+ kv.key(),
+ kv.value().slot()->toTensor());
+ parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
+ }
for (auto& kv : modules) {
names.push_back(kv.key());
// Submodules must be translated first, otherwise parameter_remap entries
names.pop_back();
}
for (auto& kv : methods) {
- std::vector<at::Tensor*> params;
- for (auto& p : kv.value()->params()) {
- params.push_back(parameter_remap.at(p));
+ std::vector<IValue*> initial_ivalues;
+ for (auto& p : kv.value()->initial_ivalues()) {
+ initial_ivalues.push_back(parameter_remap.at(p));
}
- curr->create_method(kv.key(), kv.value()->graph()->copy(), params);
+ curr->create_method(
+ kv.key(), kv.value()->graph()->copy(), initial_ivalues);
}
}
const c10::optional<at::ScalarType>& dtype,
bool non_blocking);
- // invariant: to ensure member_inputs of Methods stay valid,
+ // invariant: to ensure initial_ivalues of Methods stay valid,
// it is only legal to _add_ new modules and parameters.
- // removing them will allow member_inputs to point to invalid parameters
+ // removing them will allow initial_ivalues to point to invalid parameters
// no such restriction exists for methods
torch::OrderedDict<std::string, NamedModule> modules;
- torch::OrderedDict<std::string, NamedParameter> parameters;
+ torch::OrderedDict<std::string, NamedIValue> parameters;
+ torch::OrderedDict<std::string, NamedIValue> attributes;
torch::OrderedDict<std::string, std::unique_ptr<Method>> methods;
bool optimize;
};
super(OrderedParameterDict, self).__init__(module)
def items(self):
- return [(name, param) for name, param, is_buffer
- in self.module._get_parameters()
- if not is_buffer]
+ return [(name, param) for name, param in self.module._get_parameters()]
def __setitem__(self, k, v):
self.module._register_parameter(k, v, False)
super(OrderedBufferDict, self).__init__(module)
def items(self):
- return [(name, param) for name, param, is_buffer
- in self.module._get_parameters()
- if is_buffer]
+ return [(name, param) for name, _, param in
+ self.module._get_attributes() if isinstance(param, torch.Tensor)]
def __setitem__(self, k, v):
- self.module._register_parameter(k, v, True)
+ self.module._register_buffer(k, v)
def __contains__(self, k):
return self.module._has_buffer(k)
def __getitem__(self, k):
if k not in self:
raise KeyError(k)
- return self.module._get_parameter(k)
+ return self.module._get_buffer(k)
# base types that can be constants
# in addition, tuples and lists of these base types are also considered constants
if attr == 'training':
if self._has_buffer('training'):
self.__dict__['training'] = value
- self._get_parameter('training').fill_(int(value))
+ self._get_buffer('training').fill_(int(value))
return
+ if isinstance(value, Attribute):
+ the_type = torch.jit.annotations.ann_to_type(value.type)
+ self._register_attribute(attr, the_type, value.value)
+ return
return super(ScriptModule, self).__setattr__(attr, value)
if hasattr(self, attr):
# noop in python
return the_value
+
+class Attribute(object):
+ def __init__(self, value, the_type):
+ self.value = value
+ self.type = the_type
+
+
if not torch._C._jit_init():
raise RuntimeError("JIT initialization failed")
for method_name in module._method_names():
method = module._get_method(method_name)
param_list = []
- for param in method.params():
+ for param in method.initial_ivalues():
param_list.append(param_dict[param])
replica._copy_method(method_name, param_list, module)
graph = method.propagate_and_assign_input_and_output_shapes(
args, example_outputs, False, propagate)
# Erase number types to bring the graph to a pre-NumberType state
- params = method.params()
+ params = method.initial_ivalues()
except AttributeError:
# TODO: just trace it
raise RuntimeError('\'forward\' method must be a script method')