- Default values of cov and inv variables are now 0. Zero-debiasing (as in Adam...
authorJames Martens <jamesmartens@google.com>
Thu, 26 Apr 2018 22:13:48 +0000 (15:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 22:16:46 +0000 (15:16 -0700)
- Changed the initial default approximation used for generic registrations to "diagonal"
- Convenience properties for ops and thunks have all been removed, along with "make_ops_and_vars". User should only interface with "make_vars_and_create_op_thunks" (or maybe "create_ops_and_vars_thunks").

PiperOrigin-RevId: 194461623

13 files changed:
tensorflow/contrib/kfac/examples/convnet.py
tensorflow/contrib/kfac/examples/mlp.py
tensorflow/contrib/kfac/examples/tests/convnet_test.py
tensorflow/contrib/kfac/python/kernel_tests/BUILD
tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
tensorflow/contrib/kfac/python/ops/estimator.py
tensorflow/contrib/kfac/python/ops/fisher_factors.py
tensorflow/contrib/kfac/python/ops/layer_collection.py
tensorflow/contrib/kfac/python/ops/optimizer.py
tensorflow/contrib/kfac/python/ops/placement.py

index e8e3353..b261f41 100644 (file)
@@ -223,26 +223,26 @@ def minimize_loss_single_machine(loss,
   (cov_update_thunks,
    inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
 
-  with tf.device(device):
-    train_op = optimizer.minimize(loss, global_step=g_step)
-
   def make_update_op(update_thunks):
-    update_op = [thunk() for thunk in update_thunks]
-    return tf.group(*update_op)
+    update_ops = [thunk() for thunk in update_thunks]
+    return tf.group(*update_ops)
 
   cov_update_op = make_update_op(cov_update_thunks)
-  with tf.control_dependencies([train_op, cov_update_op]):
+  with tf.control_dependencies([cov_update_op]):
     inverse_op = tf.cond(
-        tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+        tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
         lambda: make_update_op(inv_update_thunks), tf.no_op)
+    with tf.control_dependencies([inverse_op]):
+      with tf.device(device):
+        train_op = optimizer.minimize(loss, global_step=g_step)
 
   tf.logging.info("Starting training.")
   with tf.train.MonitoredTrainingSession(config=session_config) as sess:
     while not sess.should_stop():
       global_step_, loss_, accuracy_, _ = sess.run(
-          [g_step, loss, accuracy, inverse_op])
+          [g_step, loss, accuracy, train_op])
 
-      if (global_step_ + 1) % _INVERT_EVERY == 0:
+      if global_step_ % _INVERT_EVERY == 0:
         tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                         global_step_, loss_, accuracy_)
 
@@ -357,24 +357,25 @@ def distributed_grads_only_and_ops_chief_worker(
       task_id, num_worker_tasks, num_ps_tasks, layer_collection)
   (cov_update_thunks,
    inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-  train_op = sync_optimizer.minimize(loss, global_step=global_step)
 
   tf.logging.info("Starting training.")
   hooks = [sync_optimizer.make_session_run_hook(is_chief)]
 
   def make_update_op(update_thunks):
-    update_op = [thunk() for thunk in update_thunks]
-    return tf.group(*update_op)
+    update_ops = [thunk() for thunk in update_thunks]
+    return tf.group(*update_ops)
 
   if is_chief:
     cov_update_op = make_update_op(cov_update_thunks)
-    with tf.control_dependencies([train_op, cov_update_op]):
-      update_op = tf.cond(
-          tf.equal(tf.mod(global_step + 1, invert_every), 0),
+    with tf.control_dependencies([cov_update_op]):
+      inverse_op = tf.cond(
+          tf.equal(tf.mod(global_step, invert_every), 0),
           lambda: make_update_op(inv_update_thunks),
           tf.no_op)
+      with tf.control_dependencies([inverse_op]):
+        train_op = sync_optimizer.minimize(loss, global_step=global_step)
   else:
-    update_op = train_op
+    train_op = sync_optimizer.minimize(loss, global_step=global_step)
 
   with tf.train.MonitoredTrainingSession(
       master=master,
@@ -384,7 +385,7 @@ def distributed_grads_only_and_ops_chief_worker(
       stop_grace_period_secs=0) as sess:
     while not sess.should_stop():
       global_step_, loss_, accuracy_, _ = sess.run(
-          [global_step, loss, accuracy, update_op])
+          [global_step, loss, accuracy, train_op])
       tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
                       loss_, accuracy_)
   return accuracy_
@@ -577,25 +578,25 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
   (cov_update_thunks,
    inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
 
-  train_op = optimizer.minimize(loss, global_step=g_step)
-
   def make_update_op(update_thunks):
-    update_op = [thunk() for thunk in update_thunks]
-    return tf.group(*update_op)
+    update_ops = [thunk() for thunk in update_thunks]
+    return tf.group(*update_ops)
 
   cov_update_op = make_update_op(cov_update_thunks)
-  with tf.control_dependencies([train_op, cov_update_op]):
+  with tf.control_dependencies([cov_update_op]):
     inverse_op = tf.cond(
-        tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+        tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
         lambda: make_update_op(inv_update_thunks), tf.no_op)
+    with tf.control_dependencies([inverse_op]):
+      train_op = optimizer.minimize(loss, global_step=g_step)
 
   tf.logging.info("Starting training.")
   with tf.train.MonitoredTrainingSession(config=session_config) as sess:
     while not sess.should_stop():
       global_step_, loss_, accuracy_, _ = sess.run(
-          [g_step, loss, accuracy, inverse_op])
+          [g_step, loss, accuracy, train_op])
 
-      if (global_step_ + 1) % _INVERT_EVERY == 0:
+      if global_step_ % _INVERT_EVERY == 0:
         tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                         global_step_, loss_, accuracy_)
 
index 87eed03..ea2b252 100644 (file)
@@ -105,18 +105,21 @@ def build_model(examples, labels, num_labels, layer_collection):
   return loss, accuracy
 
 
-def minimize(loss, accuracy, layer_collection, session_config=None):
+def minimize(loss, accuracy, layer_collection, num_towers, session_config=None):
   """Minimize 'loss' with KfacOptimizer.
 
   Args:
     loss: 0-D Tensor. Loss to be minimized.
     accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
     layer_collection: LayerCollection instance. Describes layers in model.
+    num_towers: int. Number of CPUs to split minibatch across.
     session_config: tf.ConfigProto. Configuration for tf.Session().
 
   Returns:
     accuracy of classifier on final minibatch.
   """
+  devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))
+
   # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
   # every 10k iterations.
   tf.logging.info("Building KFAC Optimizer.")
@@ -125,27 +128,38 @@ def minimize(loss, accuracy, layer_collection, session_config=None):
       learning_rate=tf.train.exponential_decay(
           0.00002, global_step, 10000, 0.5, staircase=True),
       cov_ema_decay=0.95,
-      damping=0.0001,
+      damping=0.0005,
       layer_collection=layer_collection,
-      momentum=0.99)
-  train_op = optimizer.minimize(loss, global_step=global_step)
+      momentum=0.99,
+      placement_strategy="round_robin",
+      cov_devices=devices,
+      inv_devices=devices)
+
+  (cov_update_thunks,
+   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+  def make_update_op(update_thunks):
+    update_ops = [thunk() for thunk in update_thunks]
+    return tf.group(*update_ops)
+
+  # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
+  # once that gets moved over?  Could still leave more advanced examples as they
+  # are (e.g. train_mnist_estimator in this file)
+
+  cov_update_op = make_update_op(cov_update_thunks)
+  with tf.control_dependencies([cov_update_op]):
+    # We update the inverses only every 20 iterations.
+    inverse_op = tf.cond(
+        tf.equal(tf.mod(global_step, 100), 0),
+        lambda: make_update_op(inv_update_thunks), tf.no_op)
+    with tf.control_dependencies([inverse_op]):
+      train_op = optimizer.minimize(loss, global_step=global_step)
 
   tf.logging.info("Starting training.")
   with tf.train.MonitoredTrainingSession(config=session_config) as sess:
     while not sess.should_stop():
-      # K-FAC has 3 primary ops,
-      # - train_op: Update the weights with the minibatch's gradient.
-      # - cov_update_op: Update statistics used for building K-FAC's
-      #   preconditioner matrix.
-      # - inv_update_op: Update preconditioner matrix using statistics.
-      #
-      # The first 2 of these are cheap and should be done with each step. The
-      # latter is more expensive, and should be updated ~100 iterations.
-      global_step_, loss_, accuracy_, _, _ = sess.run(
-          [global_step, loss, accuracy, train_op, optimizer.cov_update_op])
-
-      if global_step_ % 100 == 0:
-        sess.run(optimizer.inv_update_op)
+      global_step_, loss_, accuracy_, _ = sess.run(
+          [global_step, loss, accuracy, train_op])
 
       if global_step_ % 100 == 0:
         tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
@@ -180,7 +194,7 @@ def train_mnist(data_dir, num_epochs, use_fake_data=False):
   loss, accuracy = build_model(examples, labels, 10, layer_collection)
 
   # Fit model.
-  minimize(loss, accuracy, layer_collection)
+  minimize(loss, accuracy, layer_collection, 1)
 
 
 def train_mnist_multitower(data_dir,
@@ -238,7 +252,8 @@ def train_mnist_multitower(data_dir,
           "CPU": num_towers
       })
   return minimize(
-      loss, accuracy, layer_collection, session_config=session_config)
+      loss, accuracy, layer_collection, num_towers,
+      session_config=session_config)
 
 
 def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
@@ -298,13 +313,26 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
         layer_collection=layer_collection,
         momentum=0.99)
 
+    (cov_update_thunks,
+     inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+    def make_update_op(update_thunks):
+      update_ops = [thunk() for thunk in update_thunks]
+      return tf.group(*update_ops)
+
+    def make_batch_executed_op(update_thunks, batch_size=1):
+      return tf.group(*tf.contrib.kfac.utils.batch_execute(
+          global_step, update_thunks, batch_size=batch_size))
+
     # Run cov_update_op every step. Run 1 inv_update_ops per step.
-    cov_update_op = optimizer.cov_update_op
-    inv_update_op = tf.group(
-        tf.contrib.kfac.utils.batch_execute(
-            global_step, optimizer.inv_update_thunks, batch_size=1))
-    with tf.control_dependencies([cov_update_op, inv_update_op]):
-      train_op = optimizer.minimize(loss, global_step=global_step)
+    cov_update_op = make_update_op(cov_update_thunks)
+    with tf.control_dependencies([cov_update_op]):
+      # But make sure to execute all the inverse ops on the first step
+      inverse_op = tf.cond(tf.equal(global_step, 0),
+                           lambda: make_update_op(inv_update_thunks),
+                           lambda: make_batch_executed_op(inv_update_thunks))
+      with tf.control_dependencies([inverse_op]):
+        train_op = optimizer.minimize(loss, global_step=global_step)
 
     # Print metrics every 5 sec.
     hooks = [
index 6de775c..adecda7 100644 (file)
@@ -157,7 +157,7 @@ class ConvNetTest(tf.test.TestCase):
           num_ps_tasks=0,
           master="",
           data_dir=None,
-          num_epochs=1,
+          num_epochs=2,
           op_strategy="chief_worker",
           use_fake_data=True)
 
index c2436af..6e4a8d7 100644 (file)
@@ -97,6 +97,7 @@ py_test(
     srcs = ["optimizer_test.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/contrib/kfac/python/ops:fisher_factors",
         "//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
         "//tensorflow/contrib/kfac/python/ops:layer_collection",
         "//tensorflow/python:array_ops",
index f22dbcf..0e65d41 100644 (file)
@@ -81,7 +81,7 @@ class EstimatorTest(test.TestCase):
             damping=0.2,
             layer_collection=self.layer_collection
         )
-        est.make_ops_and_vars()
+        est.make_vars_and_create_op_thunks()
 
       # Check that we throw an error if we don't include registered variables,
       # i.e. self.weights
@@ -91,7 +91,7 @@ class EstimatorTest(test.TestCase):
             cov_ema_decay=0.1,
             damping=0.2,
             layer_collection=self.layer_collection)
-        est.make_ops_and_vars()
+        est.make_vars_and_create_op_thunks()
 
   @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
   def testVariableWrongNumberOfUses(self, mock_uses):
@@ -101,7 +101,7 @@ class EstimatorTest(test.TestCase):
           cov_ema_decay=0.1,
           damping=0.2,
           layer_collection=self.layer_collection)
-      est.make_ops_and_vars()
+      est.make_vars_and_create_op_thunks()
 
   def testInvalidEstimationMode(self):
     with self.assertRaises(ValueError):
@@ -111,7 +111,7 @@ class EstimatorTest(test.TestCase):
           damping=0.2,
           layer_collection=self.layer_collection,
           estimation_mode="not_a_real_mode")
-      est.make_ops_and_vars()
+      est.make_vars_and_create_op_thunks()
 
   def testGradientsModeBuild(self):
     with self._graph.as_default():
@@ -121,7 +121,7 @@ class EstimatorTest(test.TestCase):
           damping=0.2,
           layer_collection=self.layer_collection,
           estimation_mode="gradients")
-      est.make_ops_and_vars()
+      est.make_vars_and_create_op_thunks()
 
   def testEmpiricalModeBuild(self):
     with self._graph.as_default():
@@ -131,7 +131,7 @@ class EstimatorTest(test.TestCase):
           damping=0.2,
           layer_collection=self.layer_collection,
           estimation_mode="empirical")
-      est.make_ops_and_vars()
+      est.make_vars_and_create_op_thunks()
 
   def testCurvaturePropModeBuild(self):
     with self._graph.as_default():
@@ -141,7 +141,7 @@ class EstimatorTest(test.TestCase):
           damping=0.2,
           layer_collection=self.layer_collection,
           estimation_mode="curvature_prop")
-      est.make_ops_and_vars()
+      est.make_vars_and_create_op_thunks()
 
   def testExactModeBuild(self):
     with self._graph.as_default():
@@ -151,7 +151,7 @@ class EstimatorTest(test.TestCase):
           damping=0.2,
           layer_collection=self.layer_collection,
           estimation_mode="exact")
-      est.make_ops_and_vars()
+      est.make_vars_and_create_op_thunks()
 
   def test_cov_update_thunks(self):
     """Ensures covariance update ops run once per global_step."""
@@ -215,8 +215,11 @@ class EstimatorTest(test.TestCase):
           inv_devices=["/cpu:{}".format(i) for i in range(2)])
 
       # Construct an op that executes one covariance update per step.
-      (cov_update_ops, _, inv_update_ops, _, _,
-       _) = fisher_estimator.make_ops_and_vars(scope="test")
+      (cov_update_thunks,
+       inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
+           scope="test")
+      cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+      inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
       self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
       self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
       self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
index 566d393..86ec7a0 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
 from tensorflow.contrib.kfac.python.ops import layer_collection as lc
 from tensorflow.contrib.kfac.python.ops import linear_operator as lo
 from tensorflow.contrib.kfac.python.ops import utils
@@ -35,6 +36,19 @@ from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import test
 
 
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+                        zero_debias=False,
+                        init_inverses_at_zero=False)
+
+# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
+# inverse is something other than the identity" are actually broken. They never
+# run the covariance update ops and so the inverse actually is the identity
+# (possible plus the damping term, which would still make it a multiple of the
+# identity).
+
+
 def _make_psd(dim):
   """Constructs a PSD matrix of the given dimension."""
   mat = np.ones((dim, dim), dtype=np.float32)
index 9153ddf..fad47cd 100644 (file)
@@ -35,6 +35,13 @@ from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import test
 
 
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+                        zero_debias=False,
+                        init_inverses_at_zero=False)
+
+
 def make_damping_func(damping):
   return fb._package_func(lambda: damping, damping)
 
index 9325aa1..560a9b0 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
 from tensorflow.contrib.kfac.python.ops import layer_collection as lc
 from tensorflow.contrib.kfac.python.ops import optimizer
 from tensorflow.python.framework import ops
@@ -32,6 +33,13 @@ from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import test
 
 
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+                        zero_debias=False,
+                        init_inverses_at_zero=False)
+
+
 def dummy_layer_collection():
   lcoll = lc.LayerCollection()
   dummy = array_ops.constant([1., 2.])
@@ -186,6 +194,11 @@ class OptimizerTest(test.TestCase):
           layer_collection,
           momentum=0.5,
           momentum_type='regular')
+      (cov_update_thunks,
+       inv_update_thunks) = opt.make_vars_and_create_op_thunks()
+      cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+      inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
+
       grads_and_vars = opt.compute_gradients(output, [weights, bias])
       all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
 
@@ -193,6 +206,8 @@ class OptimizerTest(test.TestCase):
 
       sess.run(tf_variables.global_variables_initializer())
       old_vars = sess.run(all_vars)
+      sess.run(cov_update_ops)
+      sess.run(inv_update_ops)
       sess.run(op)
       new_vars = sess.run(all_vars)
 
index 84ebf5e..854f885 100644 (file)
@@ -181,44 +181,6 @@ class FisherEstimator(object):
     return self._name
 
   @abc.abstractmethod
-  def make_ops_and_vars(self, scope=None):
-    """Make ops and vars with a specific placement strategy.
-
-    For each factor, all of that factor's cov variables and their associated
-    update ops will be placed on a particular device.  For example in case of
-    round robin placement a new device is chosen for each factor by cycling
-    through list of devices in the cov_devices argument. If cov_devices is None
-    then no explicit device placement occurs.
-
-    An analogous strategy is followed for inverse update ops, with the list of
-    devices being given by the inv_devices argument.
-
-    Inverse variables on the other hand are not placed on any specific device
-    (they will just use the current the device placement context, whatever
-    that happens to be).  The idea is that the inverse variable belong where
-    they will be accessed most often, which is the device that actually applies
-    the preconditioner to the gradient. The user will be responsible for setting
-    the device context for this.
-
-    Args:
-      scope: A string or None.  If None it will be set to the name of this
-        estimator (given by the name property). All variables will be created,
-        and all ops will execute, inside of a variable scope of the given
-        name. (Default: None)
-
-    Returns:
-      cov_update_ops: List of ops that compute the cov updates. Corresponds
-        one-to-one with the list of factors given by the "factors" property.
-      cov_update_op: cov_update_ops grouped into a single op.
-      inv_update_ops: List of ops that compute the inv updates. Corresponds
-        one-to-one with the list of factors given by the "factors" property.
-      inv_update_op: inv_update_ops grouped into a single op.
-      cov_update_thunks: Thunks that make the ops in cov_update_ops.
-      inv_update_thunks: Thunks that make the ops in inv_update_ops.
-    """
-    pass
-
-  @abc.abstractmethod
   def make_vars_and_create_op_thunks(self, scope=None):
     """Make vars and create op thunks with a specific placement strategy.
 
index 30f8a2a..b43232d 100644 (file)
@@ -43,10 +43,14 @@ from tensorflow.python.util import nest
 
 # Whether to initialize covariance estimators at a zero matrix (or the identity
 # matrix).
-INIT_COVARIANCES_AT_ZERO = False
+INIT_COVARIANCES_AT_ZERO = True
 
 # Whether to zero-debias the moving averages.
-ZERO_DEBIAS = False
+ZERO_DEBIAS = True
+
+# Whether to initialize inverse (and other such matrices computed from the cov
+# matrices) to the zero matrix (or the identity matrix).
+INIT_INVERSES_AT_ZERO = True
 
 # When the number of inverses requested from a FisherFactor exceeds this value,
 # the inverses are computed using an eigenvalue decomposition.
@@ -83,6 +87,7 @@ TOWER_STRATEGY = "concat"
 
 def set_global_constants(init_covariances_at_zero=None,
                          zero_debias=None,
+                         init_inverses_at_zero=None,
                          eigenvalue_decomposition_threshold=None,
                          eigenvalue_clipping_threshold=None,
                          max_num_outer_products_per_cov_row=None,
@@ -93,6 +98,7 @@ def set_global_constants(init_covariances_at_zero=None,
   """Sets various global constants used by the classes in this module."""
   global INIT_COVARIANCES_AT_ZERO
   global ZERO_DEBIAS
+  global INIT_INVERSES_AT_ZERO
   global EIGENVALUE_DECOMPOSITION_THRESHOLD
   global EIGENVALUE_CLIPPING_THRESHOLD
   global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
@@ -105,6 +111,8 @@ def set_global_constants(init_covariances_at_zero=None,
     INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
   if zero_debias is not None:
     ZERO_DEBIAS = zero_debias
+  if init_inverses_at_zero is not None:
+    INIT_INVERSES_AT_ZERO = init_inverses_at_zero
   if eigenvalue_decomposition_threshold is not None:
     EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
   if eigenvalue_clipping_threshold is not None:
@@ -122,19 +130,21 @@ def set_global_constants(init_covariances_at_zero=None,
 
 
 def inverse_initializer(shape, dtype, partition_info=None):  # pylint: disable=unused-argument
-  return array_ops.diag(array_ops.ones(shape[0], dtype))
+  if INIT_INVERSES_AT_ZERO:
+    return array_ops.zeros(shape, dtype=dtype)
+  return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
 
 
 def covariance_initializer(shape, dtype, partition_info=None):  # pylint: disable=unused-argument
   if INIT_COVARIANCES_AT_ZERO:
-    return array_ops.diag(array_ops.zeros(shape[0], dtype))
-  return array_ops.diag(array_ops.ones(shape[0], dtype))
+    return array_ops.zeros(shape, dtype=dtype)
+  return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
 
 
-def diagonal_covariance_initializer(shape, dtype, partition_info):  # pylint: disable=unused-argument
+def diagonal_covariance_initializer(shape, dtype, partition_info=None):  # pylint: disable=unused-argument
   if INIT_COVARIANCES_AT_ZERO:
-    return array_ops.zeros(shape, dtype)
-  return array_ops.ones(shape, dtype)
+    return array_ops.zeros(shape, dtype=dtype)
+  return array_ops.ones(shape, dtype=dtype)
 
 
 @contextlib.contextmanager
index 366e2a8..cbbfe72 100644 (file)
@@ -182,7 +182,7 @@ class LayerCollection(object):
     self._graph = graph or ops.get_default_graph()
     self._loss_dict = {}  # {str: LossFunction}
     self._subgraph = None
-    self._default_generic_approximation = APPROX_FULL_NAME
+    self._default_generic_approximation = APPROX_DIAGONAL_NAME
     self._default_embedding_approximation = APPROX_KRONECKER_NAME
     self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
     self._default_conv2d_approximation = APPROX_KRONECKER_NAME
index f01c5a8..45a760c 100644 (file)
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import warnings
 # pylint disable=long-line
 from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
 from tensorflow.contrib.kfac.python.ops import estimator as est
@@ -243,62 +242,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
   def damping_adaptation_interval(self):
     return self._damping_adaptation_interval
 
-  @property
-  def cov_update_thunks(self):
-    self._maybe_make_and_save_everything()
-    return self._cov_update_thunks
-
-  @property
-  def cov_update_ops(self):
-    self._maybe_make_and_save_everything()
-    return self._cov_update_ops
-
-  @property
-  def cov_update_op(self):
-    self._maybe_make_and_save_everything()
-    return self._cov_update_op
-
-  @property
-  def inv_update_thunks(self):
-    self._maybe_make_and_save_everything()
-    return self._inv_update_thunks
-
-  @property
-  def inv_update_ops(self):
-    self._maybe_make_and_save_everything()
-    return self._inv_update_ops
-
-  @property
-  def inv_update_op(self):
-    self._maybe_make_and_save_everything()
-    return self._inv_update_op
-
-  def _maybe_make_and_save_everything(self):
-    if not self._fisher_est.made_vars():
-      warnings.warn("These convenience properties will be depcrecated soon. "
-                    "Please use explicit op/thunk creation methods instead "
-                    "(e.g. make_ops_and_vars, etc).",
-                    DeprecationWarning)
-      (self._cov_update_ops, self._cov_update_op, self._inv_update_ops,
-       self._inv_update_op, self._cov_update_thunks,
-       self._inv_update_thunks) = self.make_ops_and_vars()
-
-  def make_ops_and_vars(self):
-    """Make ops and vars with device placement `self._placement_strategy`.
-
-    See `FisherEstimator.make_ops_and_vars` for details.
-
-    Returns:
-      cov_update_ops: List of ops that compute the cov updates. Corresponds
-        one-to-one with the list of factors given by the "factors" property.
-      cov_update_op: cov_update_ops grouped into a single op.
-      inv_update_ops: List of ops that compute the inv updates. Corresponds
-        one-to-one with the list of factors given by the "factors" property.
-      cov_update_op: cov_update_ops grouped into a single op.
-      inv_update_op: inv_update_ops grouped into a single op.
-    """
-    return self._fisher_est.make_ops_and_vars(scope=self.get_name())
-
   def make_vars_and_create_op_thunks(self):
     """Make vars and create op thunks.
 
@@ -385,7 +328,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
     Returns:
       An `Operation` that applies the specified gradients.
     """
-    self._maybe_make_and_save_everything()
     # In Python 3, grads_and_vars can be a zip() object which can only be
     # iterated over once. By converting it to a list, we ensure that it can be
     # iterated over more than once.
index 38a0e28..8a20ebe 100644 (file)
@@ -21,8 +21,6 @@ from __future__ import print_function
 import itertools
 
 from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import variable_scope
 
 
 def _make_thunk_on_device(func, device):
@@ -52,56 +50,6 @@ class RoundRobinPlacementMixin(object):
     self._cov_devices = cov_devices
     self._inv_devices = inv_devices
 
-  def make_ops_and_vars(self, scope=None):
-    """Make ops and vars with a round-robin device placement strategy.
-
-    For each factor, all of that factor's cov variables and their associated
-    update ops will be placed on a particular device.  A new device is chosen
-    for each factor by cycling through list of devices in the
-    `self._cov_devices` attribute. If `self._cov_devices` is `None` then no
-    explicit device placement occurs.
-
-    An analogous strategy is followed for inverse update ops, with the list of
-    devices being given by the `self._inv_devices` attribute.
-
-    Inverse variables on the other hand are not placed on any specific device
-    (they will just use the current the device placement context, whatever
-    that happens to be).  The idea is that the inverse variable belong where
-    they will be accessed most often, which is the device that actually applies
-    the preconditioner to the gradient. The user will be responsible for setting
-    the device context for this.
-
-    Args:
-      scope: A string or None.  If None it will be set to the name of this
-        estimator (given by the name property). All variables will be created,
-        and all ops will execute, inside of a variable scope of the given
-        name. (Default: None)
-
-    Returns:
-      cov_update_ops: List of ops that compute the cov updates. Corresponds
-        one-to-one with the list of factors given by the "factors" property.
-      cov_update_op: cov_update_ops grouped into a single op.
-      inv_update_ops: List of ops that compute the inv updates. Corresponds
-        one-to-one with the list of factors given by the "factors" property.
-      inv_update_op: inv_update_ops grouped into a single op.
-      cov_update_thunks: Thunks that make the ops in cov_update_ops.
-      inv_update_thunks: Thunks that make the ops in inv_update_ops.
-    """
-    (cov_update_thunks,
-     inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope)
-    cov_update_ops = [thunk() for thunk in cov_update_thunks]
-    inv_update_ops = [thunk() for thunk in inv_update_thunks]
-
-    scope = self.name if scope is None else scope
-    with variable_scope.variable_scope(scope):
-      cov_update_op = control_flow_ops.group(cov_update_ops,
-                                             name="cov_update_op")
-      inv_update_op = control_flow_ops.group(inv_update_ops,
-                                             name="inv_update_op")
-
-    return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
-            cov_update_thunks, inv_update_thunks)
-
   def make_vars_and_create_op_thunks(self, scope=None):
     """Make vars and create op thunks w/ a round-robin device placement strat.