From: David Riazati Date: Thu, 18 Apr 2019 18:07:45 +0000 (-0700) Subject: Respect order of Parameters in rnn.py (#18198) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~159 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f5435634b4a1de4bf1c7cdf4d4f9e44418960572;p=platform%2Fupstream%2Fpytorch.git Respect order of Parameters in rnn.py (#18198) Summary: Previously to get a list of parameters this code was just putting them in the reverse order in which they were defined, which is not always right. This PR allows parameter lists to define the order themselves. To do this parameter lists need to have a corresponding function that provides the names of the parameters. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18198 Differential Revision: D14966270 Pulled By: driazati fbshipit-source-id: 59331aa59408660069785906304b2088c19534b2 --- diff --git a/test/test_jit.py b/test/test_jit.py index 4b455e7..da4a20f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -11073,6 +11073,30 @@ a") self.checkScript(foo, [torch.rand(2, 3)]) + def test_nn_LSTM_with_layers(self): + class M(torch.jit.ScriptModule): + def __init__(self): + super(M, self).__init__() + self.rnn = nn.LSTM(2, 3, 2) + + @torch.jit.script_method + def forward(self, x, lengths, h0, c0): + return self.rnn(x, (h0, c0))[0] + + class Eager(torch.nn.Module): + def __init__(self): + super(Eager, self).__init__() + self.rnn = nn.LSTM(2, 3, 2) + + def forward(self, x, lengths, h0, c0): + return self.rnn(x, (h0, c0))[0] + + inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3)) + eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0] + script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0] + + self.assertEqual(eager_out, script_out) + def test_nn_LSTM(self): input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 7bf3f60..8c2cff4 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -165,13 +165,16 @@ def ignore(fn): return fn -def _parameter_list(fn): +def _parameter_list(parameter_names_fn): """ Decorator to denote that a function returns a list of all the parameters in a module """ - fn._is_parameter_list = True - return fn + def decorator(fn): + fn._parameter_names_fn = parameter_names_fn + return fn + + return decorator try: diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index f701cf8..96d1e5b 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -375,18 +375,19 @@ struct ModuleValue : public SugaredValue { return std::make_shared( self_, py::cast>(overloads)); } - if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) { if (py::isinstance(attr) && - py::hasattr(attr, "_is_parameter_list") && - py::cast(py::getattr(attr, "_is_parameter_list"))) { + py::hasattr(attr, "_parameter_names_fn")) { + // Fetch the names of the parameters in the list so they're in the + // right order + auto fn_self = py::getattr(attr, "__self__"); + auto param_names = py::getattr(attr, "_parameter_names_fn")(fn_self); + Graph& g = *m.graph(); // Add all module parameters as inputs to the graph std::vector params; - const auto& param_list = module_->get_parameters(); - for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) { - auto& param = *it; - params.emplace_back(g.insertGetAttr(self_, param.name())); + for (auto name : param_names) { + params.emplace_back(g.insertGetAttr(self_, py::str(name))); } auto list = g.insertNode(g.createTuple(params))->output(); return std::make_shared(list); diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index e19892d..ee7a9dc 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -132,8 +132,11 @@ class RNNBase(Module): for weight in self.parameters(): init.uniform_(weight, -stdv, stdv) - @_parameter_list - def get_flat_weights(self): + def _get_flat_weights_names(self): + return [weight for weights in self._all_weights for weight in weights] + + @_parameter_list(_get_flat_weights_names) + def _get_flat_weights(self): return self._flat_weights @weak_script_method @@ -204,10 +207,10 @@ class RNNBase(Module): self.check_forward_args(input, hx, batch_sizes) _impl = _rnn_impls[self.mode] if batch_sizes is None: - result = _impl(input, hx, self.get_flat_weights(), self.bias, self.num_layers, + result = _impl(input, hx, self._get_flat_weights(), self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, self.batch_first) else: - result = _impl(input, batch_sizes, hx, self.get_flat_weights(), self.bias, + result = _impl(input, batch_sizes, hx, self._get_flat_weights(), self.bias, self.num_layers, self.dropout, self.training, self.bidirectional) output = result[0] hidden = result[1] @@ -515,10 +518,10 @@ class LSTM(RNNBase): self.check_forward_args(input, hx, batch_sizes) if batch_sizes is None: - result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers, + result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, self.batch_first) else: - result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias, + result = _VF.lstm(input, batch_sizes, hx, self._get_flat_weights(), self.bias, self.num_layers, self.dropout, self.training, self.bidirectional) output = result[0] hidden = result[1:]