From afc21e7149a0d146bd8db3145fe825b1f316c0a9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 6 Apr 2018 08:48:16 -0700 Subject: [PATCH] The training model need not be built when the kfac optimizer is initialized so the self._variables will be empty list. So pass a function which returns list of trainable variables to estimator. PiperOrigin-RevId: 191893084 --- tensorflow/contrib/kfac/python/ops/estimator.py | 11 +++++++---- tensorflow/contrib/kfac/python/ops/optimizer.py | 10 +++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index ced1110..d11c9c8 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -85,9 +85,9 @@ class FisherEstimator(object): """Create a FisherEstimator object. Args: - variables: A list of the variables for which to estimate the Fisher. This - must match the variables registered in layer_collection (if it is not - None). + variables: A `list` of variables or `callable` which returns the variables + for which to estimate the Fisher. This must match the variables + registered in layer_collection (if it is not None). cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages. damping: float. The damping factor used to stabilize training due to @@ -147,7 +147,10 @@ class FisherEstimator(object): @property def variables(self): - return self._variables + if callable(self._variables): + return self._variables() + else: + return self._variables @property def damping(self): diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 843aeef..f01c5a8 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -108,13 +108,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ - - variables = var_list - if variables is None: - variables = tf_variables.trainable_variables() - # Parameters to be passed to the Fisher estimator: - self._variables = variables + self._variables = var_list or tf_variables.trainable_variables self._cov_ema_decay = cov_ema_decay self._layers = layer_collection self._estimation_mode = estimation_mode @@ -235,7 +230,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): @property def variables(self): - return self._variables + return self._fisher_est.variables @property def damping(self): @@ -373,6 +368,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): else: kwargs["var_list"] = kwargs.get("var_list") or self.variables var_list = kwargs["var_list"] + if set(var_list) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") -- 2.7.4