Respect order of Parameters in rnn.py (#18198)
author    David Riazati <davidriazati@fb.com>
          Thu, 18 Apr 2019 18:07:45 +0000 (11:07 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Thu, 18 Apr 2019 18:18:20 +0000 (11:18 -0700)
Summary:
Previously, to get a list of parameters, this code simply returned them in the reverse order in which they were defined, which is not always correct. This PR allows parameter lists to define their own order. To do this, each parameter list needs a corresponding function that provides the names of its parameters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18198

Differential Revision: D14966270

Pulled By: driazati

fbshipit-source-id: 59331aa59408660069785906304b2088c19534b2
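
For illustration, here is a minimal, self-contained sketch (not part of the commit) of the pattern this change introduces: the decorated parameter-list function carries a reference to a names function, so a consumer such as the script compiler can fetch the parameters by name in the order the module defines. ToyModule and its weight names are hypothetical stand-ins for nn.RNNBase.

def _parameter_list(parameter_names_fn):
    # Attach the names function to the decorated parameter-list function.
    def decorator(fn):
        fn._parameter_names_fn = parameter_names_fn
        return fn
    return decorator


class ToyModule(object):  # hypothetical stand-in for an nn.Module subclass
    def __init__(self):
        # Per-layer weight names, flattened in definition order.
        self._all_weights = [['weight_ih_l0', 'weight_hh_l0'],
                             ['weight_ih_l1', 'weight_hh_l1']]
        for name in self._get_flat_weights_names():
            setattr(self, name, 0.0)  # placeholder parameter values

    def _get_flat_weights_names(self):
        return [w for weights in self._all_weights for w in weights]

    @_parameter_list(_get_flat_weights_names)
    def _get_flat_weights(self):
        return [getattr(self, name) for name in self._get_flat_weights_names()]


m = ToyModule()
# A consumer (e.g. the script compiler) can recover the intended order:
print(m._get_flat_weights._parameter_names_fn(m))
# ['weight_ih_l0', 'weight_hh_l0', 'weight_ih_l1', 'weight_hh_l1']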

test/test_jit.py
torch/_jit_internal.py
torch/csrc/jit/script/init.cpp
torch/nn/modules/rnn.py

test/test_jit.py
index 4b455e7..da4a20f 100644 (file)
@@ -11073,6 +11073,30 @@ a")
 
         self.checkScript(foo, [torch.rand(2, 3)])
 
+    def test_nn_LSTM_with_layers(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M, self).__init__()
+                self.rnn = nn.LSTM(2, 3, 2)
+
+            @torch.jit.script_method
+            def forward(self, x, lengths, h0, c0):
+                return self.rnn(x, (h0, c0))[0]
+
+        class Eager(torch.nn.Module):
+            def __init__(self):
+                super(Eager, self).__init__()
+                self.rnn = nn.LSTM(2, 3, 2)
+
+            def forward(self, x, lengths, h0, c0):
+                return self.rnn(x, (h0, c0))[0]
+
+        inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3))
+        eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0]
+        script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0]
+
+        self.assertEqual(eager_out, script_out)
+
     def test_nn_LSTM(self):
         input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
 
torch/_jit_internal.py
index 7bf3f60..8c2cff4 100644 (file)
@@ -165,13 +165,16 @@ def ignore(fn):
     return fn
 
 
-def _parameter_list(fn):
+def _parameter_list(parameter_names_fn):
     """
     Decorator to denote that a function returns a list of all the parameters
     in a module
     """
-    fn._is_parameter_list = True
-    return fn
+    def decorator(fn):
+        fn._parameter_names_fn = parameter_names_fn
+        return fn
+
+    return decorator
 
 
 try:
torch/csrc/jit/script/init.cpp
index f701cf8..96d1e5b 100644 (file)
@@ -375,18 +375,19 @@ struct ModuleValue : public SugaredValue {
       return std::make_shared<OverloadedFunctionValue>(
           self_, py::cast<std::vector<std::string>>(overloads));
     }
-
     if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
       if (py::isinstance<py::function>(attr) &&
-          py::hasattr(attr, "_is_parameter_list") &&
-          py::cast<bool>(py::getattr(attr, "_is_parameter_list"))) {
+          py::hasattr(attr, "_parameter_names_fn")) {
+        // Fetch the names of the parameters in the list so they're in the
+        // right order
+        auto fn_self = py::getattr(attr, "__self__");
+        auto param_names = py::getattr(attr, "_parameter_names_fn")(fn_self);
+
         Graph& g = *m.graph();
         // Add all module parameters as inputs to the graph
         std::vector<Value*> params;
-        const auto& param_list = module_->get_parameters();
-        for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
-          auto& param = *it;
-          params.emplace_back(g.insertGetAttr(self_, param.name()));
+        for (auto name : param_names) {
+          params.emplace_back(g.insertGetAttr(self_, py::str(name)));
         }
         auto list = g.insertNode(g.createTuple(params))->output();
         return std::make_shared<ConstantParameterList>(list);
torch/nn/modules/rnn.py
index e19892d..ee7a9dc 100644 (file)
@@ -132,8 +132,11 @@ class RNNBase(Module):
         for weight in self.parameters():
             init.uniform_(weight, -stdv, stdv)
 
-    @_parameter_list
-    def get_flat_weights(self):
+    def _get_flat_weights_names(self):
+        return [weight for weights in self._all_weights for weight in weights]
+
+    @_parameter_list(_get_flat_weights_names)
+    def _get_flat_weights(self):
         return self._flat_weights
 
     @weak_script_method
@@ -204,10 +207,10 @@ class RNNBase(Module):
         self.check_forward_args(input, hx, batch_sizes)
         _impl = _rnn_impls[self.mode]
         if batch_sizes is None:
-            result = _impl(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
+            result = _impl(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
                            self.dropout, self.training, self.bidirectional, self.batch_first)
         else:
-            result = _impl(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
+            result = _impl(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
                            self.num_layers, self.dropout, self.training, self.bidirectional)
         output = result[0]
         hidden = result[1]
@@ -515,10 +518,10 @@ class LSTM(RNNBase):
 
         self.check_forward_args(input, hx, batch_sizes)
         if batch_sizes is None:
-            result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
+            result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
                               self.dropout, self.training, self.bidirectional, self.batch_first)
         else:
-            result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
+            result = _VF.lstm(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
                               self.num_layers, self.dropout, self.training, self.bidirectional)
         output = result[0]
         hidden = result[1:]
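
As a quick sanity check (illustrative, assuming a build that includes this patch), the flattened weight names for a two-layer LSTM follow definition order, layer by layer, which is the order _get_flat_weights() now uses:

import torch.nn as nn

rnn = nn.LSTM(2, 3, 2)
print(rnn._get_flat_weights_names())
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
#  'weight_ih_l1', 'weight_hh_l1', 'bias_ih_l1', 'bias_hh_l1']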