        return ge
+    def test_training_param(self):
+        class What(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                # type: (int) -> int
+                if self.training:
+                    r = x
+                else:
+                    r = x + 4
+                # check double use of training
+                if self.training:
+                    r = r + 1
+                return r
+
+        w = What()
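+        # a fresh module is in training mode: r = 3, then r = r + 1 == 4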
+        self.assertEqual(4, w(3))
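+        # train(False) zeroes the underlying 'training' buffer, so the
+        # eval branch runs: 3 + 4 == 7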
+        w.train(False)
+        self.assertEqual(7, w(3))
+
    def test_jitter_bug(self):
        @torch.jit.script
        def fn2(input, kernel_size):
  // select an attribute on it, e.g. `this.field`
  std::shared_ptr<SugaredValue> attr(SourceRange loc, Method & m, const std::string& field) override {
+    // workaround to make self.training work: if the module does not yet have
+    // a 'training' buffer, register one initialized from the Python module's
+    // training flag, then load it and cast it to bool
+    if (field == "training") {
+      NamedParameter* v = module->find_parameter(field);
+      if (!v) {
+        py::object py_module = py::cast(module);
+        bool training = py::cast<bool>(py::getattr(py_module, "training"));
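+        // store the flag as a zero-dim long tensor so it can live in a
+        // parameter slot on the module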
+        auto t = autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
+        module->register_parameter("training", std::move(t), true);
+        v = module->find_parameter(field);
+      }
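+      // expose the buffer as a graph input and convert it to a bool that
+      // script control flow (if/while) can consume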
+      Value* the_tensor = m.get_or_add_parameter(v->slot());
+      Value* the_bool = m.graph()
+                            ->insertNode(m.graph()->createTensorToBool(the_tensor))
+                            ->output();
+      return std::make_shared<SimpleValue>(the_bool);
+    }
+
    if(NamedModule* v = module->find_module(field)) {
      return std::make_shared<ModuleValue>(v->module);
    } else if(Method* v = module->find_method(field)) {
            if isinstance(value, Module) and _is_weak_type(type(value)):
                # Compile weak script module
                value = _make_strong(value)
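+            # the compiler registers a 'training' buffer the first time
+            # self.training is used in a script method; keep that buffer in
+            # sync when the attribute changes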
+            if attr == 'training':
+                if self._has_buffer('training'):
+                    self.__dict__['training'] = value
+                    self._get_parameter('training').fill_(int(value))
+                    return
            return super(ScriptModule, self).__setattr__(attr, value)
+
        if hasattr(self, attr):
            raise RuntimeError("attempting to re-assign constant '{}'".format(attr))
        if isinstance(value, ModuleList):