        self.checkScript(foo, [torch.rand(2, 3)])
+    def test_nn_LSTM_with_layers(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M, self).__init__()
+                self.rnn = nn.LSTM(2, 3, 2)
+
+            @torch.jit.script_method
+            def forward(self, x, lengths, h0, c0):
+                return self.rnn(x, (h0, c0))[0]
+
+        class Eager(torch.nn.Module):
+            def __init__(self):
+                super(Eager, self).__init__()
+                self.rnn = nn.LSTM(2, 3, 2)
+
+            def forward(self, x, lengths, h0, c0):
+                return self.rnn(x, (h0, c0))[0]
+
+        inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3))
+        eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0]
+        script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0]
+
+        self.assertEqual(eager_out, script_out)
+
    def test_nn_LSTM(self):
        input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
    return fn
-def _parameter_list(fn):
+def _parameter_list(parameter_names_fn):
    """
    Decorator to denote that a function returns a list of all the parameters
    in a module
    """
-    fn._is_parameter_list = True
-    return fn
+    def decorator(fn):
+        fn._parameter_names_fn = parameter_names_fn
+        return fn
+
+    return decorator
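
For illustration, a minimal sketch of the new decorator protocol (the `MyCell` module and its weight names are made up here; the real consumer is `RNNBase._get_flat_weights` in torch/nn/modules/rnn.py below). The decorator no longer sets a boolean flag on the getter; it records the companion function that lists the parameter names, and the script compiler later calls that function to emit attribute reads in exactly that order:

import torch
from torch.jit import _parameter_list  # internal helper defined above; location may differ across versions

class MyCell(torch.nn.Module):  # hypothetical module, not part of this patch
    def __init__(self):
        super(MyCell, self).__init__()
        self.weight_ih = torch.nn.Parameter(torch.randn(3, 2))
        self.weight_hh = torch.nn.Parameter(torch.randn(3, 3))

    def _weight_names(self):
        # The order here is the order the consuming op expects its weights in.
        return ['weight_ih', 'weight_hh']

    @_parameter_list(_weight_names)
    def _weights(self):
        return [self.weight_ih, self.weight_hh]

cell = MyCell()
assert len(cell._weights()) == 2  # the getter still works eagerly
assert cell._weights._parameter_names_fn(cell) == ['weight_ih', 'weight_hh']
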
    if (!overloads.is_none()) {
      return std::make_shared<OverloadedFunctionValue>(
          self_, py::cast<std::vector<std::string>>(overloads));
    }
-
    if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
      if (py::isinstance<py::function>(attr) &&
-          py::hasattr(attr, "_is_parameter_list") &&
-          py::cast<bool>(py::getattr(attr, "_is_parameter_list"))) {
+          py::hasattr(attr, "_parameter_names_fn")) {
+        // Fetch the names of the parameters in the list so they're in the
+        // right order
+        auto fn_self = py::getattr(attr, "__self__");
+        auto param_names = py::getattr(attr, "_parameter_names_fn")(fn_self);
+
        Graph& g = *m.graph();
        // Add all module parameters as inputs to the graph
        std::vector<Value*> params;
-        const auto& param_list = module_->get_parameters();
-        for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
-          auto& param = *it;
-          params.emplace_back(g.insertGetAttr(self_, param.name()));
+        for (auto name : param_names) {
+          params.emplace_back(g.insertGetAttr(self_, py::str(name)));
        }
        auto list = g.insertNode(g.createTuple(params))->output();
        return std::make_shared<ConstantParameterList>(list);
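
The compiler-side contract is visible in the hunk above: instead of walking `module_->get_parameters()` (whose order need not match what the fused RNN kernels expect), the lookup takes the bound getter's `__self__`, calls the recorded `_parameter_names_fn` on it, and inserts one `GetAttr` per returned name before packing them into the tuple handed to `ConstantParameterList`. A pure-Python mirror of that lookup, using a hypothetical helper name and assuming the patched nn.LSTM (the real emission stays in C++ graph code):

import torch

def mirror_parameter_list_lookup(bound_getter):
    # Mirror of the C++ branch: read the names function that @_parameter_list
    # attached to the getter, call it on the getter's module (__self__), then
    # fetch each parameter by name in that order.
    names_fn = getattr(bound_getter, '_parameter_names_fn')
    module = bound_getter.__self__
    return tuple(getattr(module, name) for name in names_fn(module))

lstm = torch.nn.LSTM(2, 3, 2)
weights = mirror_parameter_list_lookup(lstm._get_flat_weights)
assert all(a is b for a, b in zip(weights, lstm._get_flat_weights()))
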
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)
-    @_parameter_list
-    def get_flat_weights(self):
+    def _get_flat_weights_names(self):
+        return [weight for weights in self._all_weights for weight in weights]
+
+    @_parameter_list(_get_flat_weights_names)
+    def _get_flat_weights(self):
        return self._flat_weights
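
Concretely, for the two-layer module used in the new test (`nn.LSTM(2, 3, 2)`), `_get_flat_weights_names` flattens `_all_weights` into the per-layer name order that `_VF.lstm` expects, and the decorated getter returns the matching parameters. A quick check, assuming the patched rnn.py:

import torch

lstm = torch.nn.LSTM(2, 3, 2)
print(lstm._get_flat_weights_names())
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
#  'weight_ih_l1', 'weight_hh_l1', 'bias_ih_l1', 'bias_hh_l1']
assert len(lstm._get_flat_weights()) == len(lstm._get_flat_weights_names())
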
    @weak_script_method
        self.check_forward_args(input, hx, batch_sizes)
        _impl = _rnn_impls[self.mode]
        if batch_sizes is None:
-            result = _impl(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
+            result = _impl(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
                           self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
-            result = _impl(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
+            result = _impl(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
                           self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1]
        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
-            result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
+            result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
-            result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
+            result = _VF.lstm(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
                              self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1:]