Performing the finalization of the LayerCollection outside of FisherEstimator's const...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 13 Mar 2018 17:13:45 +0000 (10:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 17:17:49 +0000 (10:17 -0700)
PiperOrigin-RevId: 188889252

tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
tensorflow/contrib/kfac/python/ops/estimator.py

index c1ea296..30c5404 100644 (file)
@@ -96,49 +96,57 @@ class EstimatorTest(test.TestCase):
       # Check that we throw an error if we try to build an estimator for vars
       # that were not manually registered.
       with self.assertRaises(ValueError):
-        estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
-                                  self.layer_collection)
+        est = estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
+                                        self.layer_collection)
+        est.make_ops_and_vars()
 
       # Check that we throw an error if we don't include registered variables,
       # i.e. self.weights
       with self.assertRaises(ValueError):
-        estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+        est = estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+        est.make_ops_and_vars()
 
   @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
   def testVariableWrongNumberOfUses(self, mock_uses):
     with self.assertRaises(ValueError):
-      estimator.FisherEstimator([self.weights], 0.1, 0.2,
-                                self.layer_collection)
+      est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+                                      self.layer_collection)
+      est.make_ops_and_vars()
 
   def testInvalidEstimationMode(self):
     with self.assertRaises(ValueError):
-      estimator.FisherEstimator([self.weights], 0.1, 0.2,
-                                self.layer_collection,
-                                estimation_mode="not_a_real_mode")
+      est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+                                      self.layer_collection,
+                                      estimation_mode="not_a_real_mode")
+      est.make_ops_and_vars()
 
   def testGradientsModeBuild(self):
     with self._graph.as_default():
-      estimator.FisherEstimator([self.weights], 0.1, 0.2,
-                                self.layer_collection,
-                                estimation_mode="gradients")
+      est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+                                      self.layer_collection,
+                                      estimation_mode="gradients")
+      est.make_ops_and_vars()
 
   def testEmpiricalModeBuild(self):
     with self._graph.as_default():
-      estimator.FisherEstimator([self.weights], 0.1, 0.2,
-                                self.layer_collection,
-                                estimation_mode="empirical")
+      est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+                                      self.layer_collection,
+                                      estimation_mode="empirical")
+      est.make_ops_and_vars()
 
   def testCurvaturePropModeBuild(self):
     with self._graph.as_default():
-      estimator.FisherEstimator([self.weights], 0.1, 0.2,
-                                self.layer_collection,
-                                estimation_mode="curvature_prop")
+      est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+                                      self.layer_collection,
+                                      estimation_mode="curvature_prop")
+      est.make_ops_and_vars()
 
   def testExactModeBuild(self):
     with self._graph.as_default():
-      estimator.FisherEstimator([self.weights], 0.1, 0.2,
-                                self.layer_collection,
-                                estimation_mode="exact")
+      est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+                                      self.layer_collection,
+                                      estimation_mode="exact")
+      est.make_ops_and_vars()
 
   def test_cov_update_thunks(self):
     """Ensures covariance update ops run once per global_step."""
index fdfd959..64755be 100644 (file)
@@ -149,8 +149,6 @@ class FisherEstimator(object):
     self._damping = damping
     self._estimation_mode = estimation_mode
     self._layers = layer_collection
-    self._layers.create_subgraph()
-    self._layers.check_registration(variables)
     self._gradient_fns = {
         "gradients": self._get_grads_lists_gradients,
         "empirical": self._get_grads_lists_empirical,
@@ -164,9 +162,6 @@ class FisherEstimator(object):
 
     self._name = name
 
-    self._instantiate_factors()
-    self._register_matrix_functions()
-
   @property
   def variables(self):
     return self._variables
@@ -285,6 +280,12 @@ class FisherEstimator(object):
       for block in self.blocks:
         block.register_matpower(exp)
 
+  def _finalize_layer_collection(self):
+    self._layers.create_subgraph()
+    self._layers.check_registration(self.variables)
+    self._instantiate_factors()
+    self._register_matrix_functions()
+
   def make_ops_and_vars(self, scope=None):
     """Make ops and vars with no specific device placement.
 
@@ -467,6 +468,8 @@ class FisherEstimator(object):
     """
     self._check_vars_unmade_and_set_made_flag()
 
+    self._finalize_layer_collection()
+
     scope = self.name if scope is None else scope
 
     cov_variable_thunks = [