WAR for self.training (#14719)
author Zachary DeVito <zdevito@fb.com>
Tue, 4 Dec 2018 04:29:51 +0000 (20:29 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 04:32:16 +0000 (20:32 -0800)
Summary:
To enable self.training in script modules, this PR automatically adds a buffer called 'training' to a module whenever one of its script methods accesses self.training. Assignment to self.training is overloaded to update both the boolean Python attribute and the tensor value of the buffer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14719

Differential Revision: D13310569

Pulled By: zdevito

fbshipit-source-id: 406387bb602f8ce5794eeff37642863c75928be5
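
As context for the diff below, a minimal sketch of the behavior this enables. It mirrors the test_training_param test added in test/test_jit.py; the module name is illustrative, and the snippet assumes the 2018-era torch.jit.ScriptModule / script_method API used in this patch:

    import torch

    class MyModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            # type: (int) -> int
            # self.training compiles against an auto-registered 'training' buffer
            if self.training:
                return x + 1
            return x + 4

    m = MyModule()
    assert m(3) == 4   # modules start in training mode
    m.eval()           # assigns self.training = False, which also fills the buffer
    assert m(3) == 7   # the compiled branch now sees eval mode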

test/test_jit.py
torch/csrc/jit/script/init.cpp
torch/jit/__init__.py

index cf10230..8dd70d3 100644 (file)
@@ -3016,6 +3016,25 @@ class TestScript(JitTestCase):
 
         return ge
 
+    def test_training_param(self):
+        class What(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                # type: (int) -> int
+                if self.training:
+                    r = x
+                else:
+                    r = x + 4
+                # check double use of training
+                if self.training:
+                    r = r + 1
+                return r
+
+        w = What()
+        self.assertEqual(4, w(3))
+        w.train(False)
+        self.assertEqual(7, w(3))
+
     def test_jitter_bug(self):
         @torch.jit.script
         def fn2(input, kernel_size):
index f0de96d..b4612b4 100644 (file)
@@ -211,6 +211,26 @@ struct ModuleValue : public SugaredValue {
 
   // select an attribute on it, e.g. `this.field`
   std::shared_ptr<SugaredValue> attr(SourceRange loc, Method & m, const std::string& field) override {
+    // Workaround to make self.training work: if the module does not yet
+    // have a 'training' buffer, register one initialized from the Python
+    // module's training flag, then load it and cast it to bool.
+    if (field == "training") {
+      NamedParameter* v = module->find_parameter(field);
+      if (!v) {
+        py::object py_module = py::cast(module);
+        bool training = py::cast<bool>(py::getattr(py_module, "training"));
+        auto t = autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
+        module->register_parameter("training", std::move(t), /*is_buffer=*/true);
+        v = module->find_parameter(field);
+      }
+      Value* the_tensor = m.get_or_add_parameter(v->slot());
+      Value* the_bool =
+          m.graph()
+              ->insertNode(m.graph()->createTensorToBool(the_tensor))
+              ->output();
+      return std::make_shared<SimpleValue>(the_bool);
+    }
+
     if(NamedModule* v = module->find_module(field)) {
       return std::make_shared<ModuleValue>(v->module);
     } else if(Method* v = module->find_method(field)) {
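
In short: the first time a script method reads self.training, the compiler registers a zero-dim long buffer holding 0 or 1 (initialized from the Python module's training flag) and emits a tensor-to-bool cast at each use. A rough sketch of how this surfaces on the Python side, reusing the What module from the test above; the private helpers are the same ones used by the torch/jit/__init__.py hunk below:

    w = What()                        # forward() references self.training
    assert w._has_buffer('training')  # buffer was auto-registered during compilation
    assert int(w._get_parameter('training')) == 1   # 1 while in training mode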
index 7b5a98d..c9e1435 100644 (file)
@@ -1104,7 +1104,13 @@ if _enabled:
                 if isinstance(value, Module) and _is_weak_type(type(value)):
                     # Compile weak script module
                     value = _make_strong(value)
+                if attr == 'training':  # keep the attribute and the 'training' buffer in sync
+                    if self._has_buffer('training'):
+                        self.__dict__['training'] = value
+                        self._get_parameter('training').fill_(int(value))
+                        return
                 return super(ScriptModule, self).__setattr__(attr, value)
+
             if hasattr(self, attr):
                 raise RuntimeError("attempting to re-assign constant '{}'".format(attr))
             if isinstance(value, ModuleList):
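
Taken together with the C++ change above: nn.Module.train() (and hence .eval()) simply assigns to self.training, so the overloaded __setattr__ keeps the Python flag and the buffer in sync. A minimal sketch of the round trip, again reusing the What module from the test (a sketch, not part of the patch):

    w = What()
    w.training = False                # sets the attribute and fills the buffer with 0
    assert int(w._get_parameter('training')) == 0
    assert w(3) == 7                  # the script method now takes the eval branch
    w.training = True
    assert w(3) == 4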