Better error when module attr is used (#18164)
authorElias Ellison <eellison@fb.com>
Sat, 23 Mar 2019 03:13:02 +0000 (20:13 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 23 Mar 2019 03:22:27 +0000 (20:22 -0700)
Summary:
Adds a suggestion to add to __constants__ when a torch.nn.Module attr is accessed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18164

Differential Revision: D14580060

Pulled By: eellison

fbshipit-source-id: 0c5adc21d7341a5691d4b45930947cb1ba84c8e8

test/test_jit.py
torch/csrc/jit/script/init.cpp

index 4a9e555..dc10320 100644 (file)
@@ -6429,6 +6429,21 @@ a")
         with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
             a = M(nn.ModuleList([nn.ReLU()]))
 
+    def test_attr_module_constants_error(self):
+        class M2(torch.jit.ScriptModule):
+            def __init__(self, mod_list):
+                super(M2, self).__init__(False)
+                self.mods = mod_list
+
+            @torch.jit.script_method
+            def forward(self, v):
+                return self.mods.forward(x)
+
+        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
+            M2(nn.Sequential(nn.ReLU()))
+        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
+            M2(nn.ModuleList([nn.ReLU()]))
+
     def test_script_sequential_for(self):
         class Sub(torch.jit.ScriptModule):
             def __init__(self):
index 86ff58b..5ecebf3 100644 (file)
@@ -169,6 +169,14 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
     return ss.str();
   }
 
+  void checkForAddToConstantsError(std::stringstream& ss) {
+    auto nn = py::module::import("torch.nn");
+    if (py::isinstance(self, nn.attr("ModuleList")) ||
+        py::isinstance(self, nn.attr("Sequential"))) {
+      ss << ". Did you forget to add it to __constants__? ";
+    }
+  }
+
   std::vector<std::shared_ptr<SugaredValue>> asTuple(
       const SourceRange& loc,
       Method& m,
@@ -176,11 +184,18 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
     const std::string type_str = typeString(self);
     std::stringstream ss;
     ss << kind() << " cannot be used as a tuple";
-    auto nn = py::module::import("torch.nn");
-    if (py::isinstance(self, nn.attr("ModuleList")) ||
-        py::isinstance(self, nn.attr("Sequential"))) {
-      ss << ". Did you forget to add it to __constants__? ";
-    }
+    checkForAddToConstantsError(ss);
+    throw ErrorReport(loc) << ss.str();
+  }
+
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field) override {
+    const std::string type_str = typeString(self);
+    std::stringstream ss;
+    ss << "attribute lookup is not defined on " << kind();
+    checkForAddToConstantsError(ss);
     throw ErrorReport(loc) << ss.str();
   }