"""Create a FisherEstimator object.
Args:
- variables: A list of the variables for which to estimate the Fisher. This
- must match the variables registered in layer_collection (if it is not
- None).
+ variables: A `list` of variables or `callable` which returns the variables
+ for which to estimate the Fisher. This must match the variables
+ registered in layer_collection (if it is not None).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
damping: float. The damping factor used to stabilize training due to
@property
def variables(self):
- return self._variables
+ if callable(self._variables):
+ return self._variables()
+ else:
+ return self._variables
@property
def damping(self):
ValueError: If momentum is non-zero and momentum_type is not 'regular'
or 'adam'.
"""
-
- variables = var_list
- if variables is None:
- variables = tf_variables.trainable_variables()
-
# Parameters to be passed to the Fisher estimator:
- self._variables = variables
+ self._variables = var_list or tf_variables.trainable_variables
self._cov_ema_decay = cov_ema_decay
self._layers = layer_collection
self._estimation_mode = estimation_mode
@property
def variables(self):
- return self._variables
+ return self._fisher_est.variables
@property
def damping(self):
else:
kwargs["var_list"] = kwargs.get("var_list") or self.variables
var_list = kwargs["var_list"]
+
if set(var_list) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")