# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
- self.layer_collection)
+ est = estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
+ self.layer_collection)
+ est.make_ops_and_vars()
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
with self.assertRaises(ValueError):
- estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+ est = estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+ est.make_ops_and_vars()
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection)
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection)
+ est.make_ops_and_vars()
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="not_a_real_mode")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="not_a_real_mode")
+ est.make_ops_and_vars()
def testGradientsModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="gradients")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="gradients")
+ est.make_ops_and_vars()
def testEmpiricalModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="empirical")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="empirical")
+ est.make_ops_and_vars()
def testCurvaturePropModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="curvature_prop")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="curvature_prop")
+ est.make_ops_and_vars()
def testExactModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="exact")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="exact")
+ est.make_ops_and_vars()
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
self._damping = damping
self._estimation_mode = estimation_mode
self._layers = layer_collection
- self._layers.create_subgraph()
- self._layers.check_registration(variables)
self._gradient_fns = {
"gradients": self._get_grads_lists_gradients,
"empirical": self._get_grads_lists_empirical,
self._name = name
- self._instantiate_factors()
- self._register_matrix_functions()
-
@property
def variables(self):
return self._variables
for block in self.blocks:
block.register_matpower(exp)
+ def _finalize_layer_collection(self):
+ self._layers.create_subgraph()
+ self._layers.check_registration(self.variables)
+ self._instantiate_factors()
+ self._register_matrix_functions()
+
def make_ops_and_vars(self, scope=None):
"""Make ops and vars with no specific device placement.
"""
self._check_vars_unmade_and_set_made_flag()
+ self._finalize_layer_collection()
+
scope = self.name if scope is None else scope
cov_variable_thunks = [