The training model need not be built when the kfac optimizer is initialized so the
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 15:48:16 +0000 (08:48 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 15:50:41 +0000 (08:50 -0700)
self._variables will be empty list.  So pass a function which returns list of trainable variables to estimator.

PiperOrigin-RevId: 191893084

tensorflow/contrib/kfac/python/ops/estimator.py
tensorflow/contrib/kfac/python/ops/optimizer.py

index ced1110..d11c9c8 100644 (file)
@@ -85,9 +85,9 @@ class FisherEstimator(object):
     """Create a FisherEstimator object.
 
     Args:
-      variables: A list of the variables for which to estimate the Fisher. This
-          must match the variables registered in layer_collection (if it is not
-          None).
+      variables: A `list` of variables or `callable` which returns the variables
+          for which to estimate the Fisher. This must match the variables
+          registered in layer_collection (if it is not None).
       cov_ema_decay: The decay factor used when calculating the covariance
           estimate moving averages.
       damping: float. The damping factor used to stabilize training due to
@@ -147,7 +147,10 @@ class FisherEstimator(object):
 
   @property
   def variables(self):
-    return self._variables
+    if callable(self._variables):
+      return self._variables()
+    else:
+      return self._variables
 
   @property
   def damping(self):
index 843aeef..f01c5a8 100644 (file)
@@ -108,13 +108,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
       ValueError: If momentum is non-zero and momentum_type is not 'regular'
           or 'adam'.
     """
-
-    variables = var_list
-    if variables is None:
-      variables = tf_variables.trainable_variables()
-
     # Parameters to be passed to the Fisher estimator:
-    self._variables = variables
+    self._variables = var_list or tf_variables.trainable_variables
     self._cov_ema_decay = cov_ema_decay
     self._layers = layer_collection
     self._estimation_mode = estimation_mode
@@ -235,7 +230,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
 
   @property
   def variables(self):
-    return self._variables
+    return self._fisher_est.variables
 
   @property
   def damping(self):
@@ -373,6 +368,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
     else:
       kwargs["var_list"] = kwargs.get("var_list") or self.variables
       var_list = kwargs["var_list"]
+
     if set(var_list) != set(self.variables):
       raise ValueError("var_list doesn't match with set of Fisher-estimating "
                        "variables.")