**defaults: Any,
):
self.params = list(params)
- self.local_optimizer = optimizer_class(iter(self.params), **defaults)
- self.param_groups = self.local_optimizer.param_groups
+ self.optim = optimizer_class(iter(self.params), **defaults)
+ self.param_groups = self.optim.param_groups
self.averager = averager
+ @property
+ def state(self):
+ return self.optim.state
+
+ def __repr__(self):
+ return self.optim.__repr__()
+
+ def state_dict(self):
+ return self.optim.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self.optim.load_state_dict(state_dict)
+
def step(self):
r"""
Performs a single optimization step (parameter update).
"""
- self.local_optimizer.step()
+ self.optim.step()
self.averager.average_parameters(iter(self.params))
def zero_grad(self):
- self.local_optimizer.zero_grad()
-
- def state_dict(self):
- raise NotImplementedError
+ self.optim.zero_grad()
- def load_state_dict(self, state_dict):
- raise NotImplementedError
+ def add_param_group(self, param_group):
+ self.optim.add_param_group(param_group)