K-FAC: Expose protected functions from fisher_blocks and fisher_factors and constant...
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 17 Jan 2018 13:03:05 +0000 (05:03 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 17 Jan 2018 13:07:01 +0000 (05:07 -0800)
PiperOrigin-RevId: 182197356

tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
tensorflow/contrib/kfac/python/ops/fisher_blocks.py
tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
tensorflow/contrib/kfac/python/ops/fisher_factors.py
tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
tensorflow/contrib/kfac/python/ops/layer_collection_lib.py

diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 2d9b28185ce0db32d5cd7d84737fdf96e2c98851..82accd57f0c37d140238f1884fce956654d14227 100644
@@ -49,7 +49,7 @@ class UtilsTest(test.TestCase):
       right_factor = array_ops.ones([2., 2.])
 
       # pi is the sqrt of the ratio of the dimension-normalized trace norms
-      pi = fb._compute_pi_tracenorm(left_factor, right_factor)
+      pi = fb.compute_pi_tracenorm(left_factor, right_factor)
 
       pi_val = sess.run(pi)
       self.assertEqual(1., pi_val)
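
For reference, a minimal NumPy sketch of the quantity the now-public
compute_pi_tracenorm computes, following the formula in its docstring
below (an illustrative standalone reimplementation, not the TF op):

    import numpy as np

    def pi_tracenorm(left_cov, right_cov):  # hypothetical helper name
      # pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
      left_norm = np.trace(left_cov) / left_cov.shape[0]
      right_norm = np.trace(right_cov) / right_cov.shape[0]
      return np.sqrt(left_norm / right_norm)

    # Both factors all-ones 2x2, as in the test above: pi == 1.
    assert pi_tracenorm(np.ones((2, 2)), np.ones((2, 2))) == 1.0
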
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index a2665b92790d7f71573fcc088205de5ba18aef1f..753378d9f4a0d8762bafbee2ec27d6c71783dda1 100644
@@ -46,7 +46,7 @@ class MaybeColocateTest(test.TestCase):
     ff.set_global_constants(colocate_cov_ops_with_inputs=False)
     with tf_ops.Graph().as_default():
       a = constant_op.constant([2.0], name='a')
-      with ff._maybe_colocate_with(a):
+      with ff.maybe_colocate_with(a):
         b = constant_op.constant(3.0, name='b')
       self.assertEqual([b'loc:@a'], a.op.colocation_groups())
       self.assertEqual([b'loc:@b'], b.op.colocation_groups())
@@ -55,7 +55,7 @@ class MaybeColocateTest(test.TestCase):
     ff.set_global_constants(colocate_cov_ops_with_inputs=True)
     with tf_ops.Graph().as_default():
       a = constant_op.constant([2.0], name='a')
-      with ff._maybe_colocate_with(a):
+      with ff.maybe_colocate_with(a):
         b = constant_op.constant(3.0, name='b')
       self.assertEqual([b'loc:@a'], a.op.colocation_groups())
       self.assertEqual([b'loc:@a'], b.op.colocation_groups())
@@ -129,7 +129,7 @@ class NumericalUtilsTest(test.TestCase):
       random_seed.set_random_seed(200)
 
       x = npr.randn(100, 3)
-      cov = ff._compute_cov(array_ops.constant(x))
+      cov = ff.compute_cov(array_ops.constant(x))
       np_cov = np.dot(x.T, x) / x.shape[0]
 
       self.assertAllClose(sess.run(cov), np_cov)
@@ -141,7 +141,7 @@ class NumericalUtilsTest(test.TestCase):
 
       normalizer = 10.
       x = npr.randn(100, 3)
-      cov = ff._compute_cov(array_ops.constant(x), normalizer=normalizer)
+      cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
       np_cov = np.dot(x.T, x) / normalizer
 
       self.assertAllClose(sess.run(cov), np_cov)
@@ -152,7 +152,7 @@ class NumericalUtilsTest(test.TestCase):
 
       m, n = 3, 4
       a = npr.randn(m, n)
-      a_homog = ff._append_homog(array_ops.constant(a))
+      a_homog = ff.append_homog(array_ops.constant(a))
       np_result = np.hstack([a, np.ones((m, 1))])
 
       self.assertAllClose(sess.run(a_homog), np_result)
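
The two helpers exercised above compose the way the Kronecker factors
below use them: append a homogeneous (bias) coordinate, then take the
empirical second moment of the rows. A NumPy sketch of that composition
(illustrative reimplementations, not the TF ops):

    import numpy as np

    def append_homog(t):
      # Append a column of ones to a 2D array.
      return np.hstack([t, np.ones((t.shape[0], 1))])

    def compute_cov(t, t_right=None, normalizer=None):
      # Empirical second moment of the rows; defaults mirror the TF version.
      t_right = t if t_right is None else t_right
      normalizer = t.shape[0] if normalizer is None else normalizer
      return np.dot(t.T, t_right) / normalizer

    x = np.random.randn(100, 3)
    cov = compute_cov(append_homog(x))  # shape (4, 4); last row/col is the bias term
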
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 1ccb9e040f2bb6bcfd217886918abd40e3cc1cfb..9436caf9618bc3d3c0dd7b3842420016b119464f 100644
@@ -54,7 +54,7 @@ from tensorflow.python.ops import math_ops
 NORMALIZE_DAMPING_POWER = 1.0
 
 # Methods for adjusting damping for FisherBlocks. See
-# _compute_pi_adjusted_damping() for details.
+# compute_pi_adjusted_damping() for details.
 PI_OFF_NAME = "off"
 PI_TRACENORM_NAME = "tracenorm"
 PI_TYPE = PI_TRACENORM_NAME
@@ -72,7 +72,14 @@ def set_global_constants(normalize_damping_power=None, pi_type=None):
     PI_TYPE = pi_type
 
 
-def _compute_pi_tracenorm(left_cov, right_cov):
+def normalize_damping(damping, num_replications):
+  """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
+  if NORMALIZE_DAMPING_POWER:
+    return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
+  return damping
+
+
+def compute_pi_tracenorm(left_cov, right_cov):
   """Computes the scalar constant pi for Tikhonov regularization/damping.
 
   pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
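
A quick worked example of the new normalize_damping helper, assuming the
default NORMALIZE_DAMPING_POWER = 1.0 defined above:

    normalize_damping(0.01, 49)  # == 0.01 / 49**1.0 ~= 2.04e-4
    # A falsy NORMALIZE_DAMPING_POWER (e.g. 0) returns damping unchanged.
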
@@ -92,10 +99,10 @@ def _compute_pi_tracenorm(left_cov, right_cov):
   return math_ops.sqrt(left_norm / right_norm)
 
 
-def _compute_pi_adjusted_damping(left_cov, right_cov, damping):
+def compute_pi_adjusted_damping(left_cov, right_cov, damping):
 
   if PI_TYPE == PI_TRACENORM_NAME:
-    pi = _compute_pi_tracenorm(left_cov, right_cov)
+    pi = compute_pi_tracenorm(left_cov, right_cov)
     return (damping * pi, damping / pi)
 
   elif PI_TYPE == PI_OFF_NAME:
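
The intended invariant behind the (damping * pi, damping / pi) split: the
two adjusted dampings multiply back to the square of the argument, which
is why the call sites below pass damping**0.5. A hedged sketch, with A and
B standing in for the input and output covariances:

    in_damp, out_damp = compute_pi_adjusted_damping(A, B, damping**0.5)
    # in_damp * out_damp == damping, for tracenorm (any pi > 0) and "off" alike.
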
@@ -450,10 +457,7 @@ class ConvDiagonalFB(FisherBlock):
     self._num_locations = (
         inputs_shape[1] * inputs_shape[2] //
         (self._strides[1] * self._strides[2]))
-
-    if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_locations**NORMALIZE_DAMPING_POWER
-    self._damping = damping
+    self._damping = normalize_damping(damping, self._num_locations)
 
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvDiagonalFactor,
@@ -506,7 +510,7 @@ class KroneckerProductFB(FisherBlock):
     Args:
       damping: The base damping factor (float or Tensor) for the damped inverse.
     """
-    self._input_damping, self._output_damping = _compute_pi_adjusted_damping(
+    self._input_damping, self._output_damping = compute_pi_adjusted_damping(
         self._input_factor.get_cov(),
         self._output_factor.get_cov(),
         damping**0.5)
@@ -691,8 +695,8 @@ class ConvKFCBasicFB(KroneckerProductFB):
     grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
 
     # Infer number of locations upon which convolution is applied.
-    self._num_locations = _num_conv_locations(inputs.shape.as_list(),
-                                              self._strides)
+    self._num_locations = num_conv_locations(inputs.shape.as_list(),
+                                             self._strides)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvInputKroneckerFactor,
@@ -701,11 +705,9 @@ class ConvKFCBasicFB(KroneckerProductFB):
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
 
-    if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_locations**NORMALIZE_DAMPING_POWER
-    self._damping = damping
-
+    damping = normalize_damping(damping, self._num_locations)
     self._register_damped_input_and_output_inverses(damping)
+    self._damping = damping
 
   @property
   def _renorm_coeff(self):
@@ -758,8 +760,16 @@ def _concat_along_batch_dim(tensor_list):
     return array_ops.concat(tensor_list, axis=0)
 
 
-def _num_conv_locations(input_shape, strides):
-  """Returns the number of locations a Conv kernel is applied to."""
+def num_conv_locations(input_shape, strides):
+  """Returns the number of spatial locations a 2D Conv kernel is applied to.
+
+  Args:
+    input_shape: list representing shape of inputs to the Conv layer.
+    strides: list representing strides for the Conv kernel.
+
+  Returns:
+    A scalar |T| denoting the number of spatial locations for the Conv layer.
+  """
   return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
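
A worked example, assuming the usual NHWC shape and stride conventions:

    # 32x32 feature map, 2x2 strides:
    num_conv_locations([128, 32, 32, 3], [1, 2, 2, 1])  # 32*32 // (2*2) == 256
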
 
 
@@ -804,9 +814,7 @@ class FullyConnectedMultiIndepFB(KroneckerProductFB):
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF, (grads_list,))
 
-    if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_uses**NORMALIZE_DAMPING_POWER
-
+    damping = normalize_damping(damping, self._num_uses)
     self._register_damped_input_and_output_inverses(damping)
 
   @property
@@ -885,10 +893,8 @@ class FullyConnectedSeriesFB(FisherBlock):
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF, (grads_list,))
 
-    if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_timesteps**NORMALIZE_DAMPING_POWER
-
-    self._damping_input, self._damping_output = _compute_pi_adjusted_damping(
+    damping = normalize_damping(damping, self._num_timesteps)
+    self._damping_input, self._damping_output = compute_pi_adjusted_damping(
         self._input_factor.get_cov(),
         self._output_factor.get_cov(),
         damping**0.5)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
index 59389f8d385c18f50914d690cfaa2825ef807ed3..ac396309206fe09af65c2b70840a513fb25b579b 100644
@@ -33,6 +33,10 @@ _allowed_symbols = [
     'ConvKFCBasicFB',
     'ConvDiagonalFB',
     'set_global_constants',
+    'compute_pi_tracenorm',
+    'compute_pi_adjusted_damping',
+    'num_conv_locations',
+    'normalize_damping'
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
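
With these names whitelisted they survive remove_undocumented, so
downstream code can import them; for example (import path assumed from
the file list above):

    from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fb

    damping = fb.normalize_damping(
        0.01, fb.num_conv_locations([128, 32, 32, 3], [1, 2, 2, 1]))
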
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 826e8b7732c1d164915baf8ce0215351c072a65f..a069f6bdd9d790ddf1c7d4b185f062ff08fa9a60 100644
@@ -58,7 +58,7 @@ COLOCATE_COV_OPS_WITH_INPUTS = True
 
 
 @contextlib.contextmanager
-def _maybe_colocate_with(op):
+def maybe_colocate_with(op):
   """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`."""
   if COLOCATE_COV_OPS_WITH_INPUTS:
     if isinstance(op, (list, tuple)):
@@ -111,7 +111,7 @@ def diagonal_covariance_initializer(shape, dtype, partition_info):  # pylint: di
   return array_ops.ones(shape, dtype)
 
 
-def _compute_cov(tensor, tensor_right=None, normalizer=None):
+def compute_cov(tensor, tensor_right=None, normalizer=None):
   """Compute the empirical second moment of the rows of a 2D Tensor.
 
   This function is meant to be applied to random matrices for which the true row
@@ -139,7 +139,7 @@ def _compute_cov(tensor, tensor_right=None, normalizer=None):
             math_ops.cast(normalizer, tensor.dtype))
 
 
-def _append_homog(tensor):
+def append_homog(tensor):
   """Appends a homogeneous coordinate to the last dimension of a Tensor.
 
   Args:
@@ -281,7 +281,7 @@ class FisherFactor(object):
     # the future we might want to perform the cov computations on each tower,
     # so that each tower will be considered a "source" (allowing us to reuse
     # the existing "source" code for this).
-    with _maybe_colocate_with(new_cov_contribs[0]):
+    with maybe_colocate_with(new_cov_contribs[0]):
       new_cov = math_ops.add_n(new_cov_contribs)
       # Synchronize value across all TPU cores.
       if utils.on_tpu():
@@ -472,7 +472,7 @@ class FullFactor(InverseProvidingFactor):
 
   def _compute_new_cov(self, idx=0):
     # This will be a very basic rank 1 estimate
-    with _maybe_colocate_with(self._params_grads[idx]):
+    with maybe_colocate_with(self._params_grads[idx]):
       params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
       return ((params_grads_flat * array_ops.transpose(
           params_grads_flat)) / math_ops.cast(self._batch_size,
@@ -528,7 +528,7 @@ class NaiveDiagonalFactor(DiagonalFactor):
     return self._params_grads[0][0].dtype
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._params_grads[idx]):
+    with maybe_colocate_with(self._params_grads[idx]):
       params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
       return (math_ops.square(params_grads_flat) / math_ops.cast(
           self._batch_size, params_grads_flat.dtype))
@@ -589,12 +589,12 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
     # square of an outer product is the outer-product of the entry-wise squares.
     # The gradient is the outer product of the input and the output gradients,
     # so we just square both and then take their outer-product.
-    with _maybe_colocate_with(self._outputs_grads[idx]):
+    with maybe_colocate_with(self._outputs_grads[idx]):
       # We only need to compute squared_inputs once
       if self._squared_inputs is None:
         inputs = self._inputs
         if self._has_bias:
-          inputs = _append_homog(self._inputs)
+          inputs = append_homog(self._inputs)
         self._squared_inputs = math_ops.square(inputs)
 
       new_cov = math_ops.matmul(
@@ -662,7 +662,7 @@ class ConvDiagonalFactor(DiagonalFactor):
     return self._outputs_grads[0].dtype
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._outputs_grads[idx]):
+    with maybe_colocate_with(self._outputs_grads[idx]):
       if self._patches is None:
         filter_height, filter_width, _, _ = self._filter_shape
 
@@ -676,7 +676,7 @@ class ConvDiagonalFactor(DiagonalFactor):
             padding=self._padding)
 
         if self._has_bias:
-          patches = _append_homog(patches)
+          patches = append_homog(patches)
 
         self._patches = patches
 
@@ -736,11 +736,11 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
     return self._tensors[0].dtype
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._tensors[idx]):
+    with maybe_colocate_with(self._tensors[idx]):
       tensor = self._tensors[idx]
       if self._has_bias:
-        tensor = _append_homog(tensor)
-      return _compute_cov(tensor)
+        tensor = append_homog(tensor)
+      return compute_cov(tensor)
 
 
 class ConvInputKroneckerFactor(InverseProvidingFactor):
@@ -803,7 +803,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
     if idx != 0:
       raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
 
-    with _maybe_colocate_with(self._inputs):
+    with maybe_colocate_with(self._inputs):
       filter_height, filter_width, in_channels, _ = self._filter_shape
 
       # TODO(b/64144716): there is potential here for a big savings in terms of
@@ -825,15 +825,15 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
       # We append a homogeneous coordinate to patches_flat if the layer has
       # bias parameters. This gives us [[A_l]]_H from the paper.
       if self._has_bias:
-        patches_flat = _append_homog(patches_flat)
-      # We call _compute_cov without passing in a normalizer. _compute_cov uses
+        patches_flat = append_homog(patches_flat)
+      # We call compute_cov without passing in a normalizer. compute_cov uses
       # the first dimension of patches_flat i.e. M|T| as the normalizer by
       # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
       # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
       # the paper but has a different scale here for consistency with
       # ConvOutputKroneckerFactor.
       # (Tilde omitted over A for clarity.)
-      return _compute_cov(patches_flat)
+      return compute_cov(patches_flat)
 
 
 class ConvOutputKroneckerFactor(InverseProvidingFactor):
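
The shape bookkeeping in the ConvInputKroneckerFactor comment above, as a
runnable NumPy sketch (illustrative sizes; M = batch, T = spatial
locations, JD = J|Delta| = flattened patch size):

    import numpy as np

    M, T, JD = 8, 49, 75  # e.g. 5x5 patches over 3 input channels
    patches_flat = np.random.randn(M * T, JD)
    # 1/M|T| * [[A_l]]^T [[A_l]], shape [JD, JD]:
    cov = np.dot(patches_flat.T, patches_flat) / (M * T)
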
@@ -876,7 +876,7 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
     return self._outputs_grads[0].dtype
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._outputs_grads[idx]):
+    with maybe_colocate_with(self._outputs_grads[idx]):
       # reshaped_tensor below is the matrix DS_l defined in the KFC paper
       # (tilde omitted over S for clarity). It has shape M|T| x I, where
       # M = minibatch size, |T| = number of spatial locations, and
@@ -884,10 +884,10 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
       reshaped_tensor = array_ops.reshape(self._outputs_grads[idx],
                                           [-1, self._out_channels])
       # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
-      # _compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
+      # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
       # as defined in the paper, with shape I x I.
       # (Tilde omitted over S for clarity.)
-      return _compute_cov(reshaped_tensor)
+      return compute_cov(reshaped_tensor)
 
 
 class FullyConnectedMultiKF(InverseProvidingFactor):
@@ -935,7 +935,7 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
       new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx)
                                    for idx in range(self._num_sources))
 
-      with _maybe_colocate_with(new_cov_dt1_contribs[0]):
+      with maybe_colocate_with(new_cov_dt1_contribs[0]):
         new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
 
         op2 = moving_averages.assign_moving_average(
@@ -951,17 +951,17 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
     return op
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._tensor_lists[idx]):
+    with maybe_colocate_with(self._tensor_lists[idx]):
       tensor = array_ops.concat(self._tensor_lists[idx], 0)
       if self._has_bias:
-        tensor = _append_homog(tensor)
+        tensor = append_homog(tensor)
       # We save these so they can be used by _compute_new_cov_dt1
       self._tensors[idx] = tensor
-      return _compute_cov(tensor)
+      return compute_cov(tensor)
 
   def _compute_new_cov_dt1(self, idx=0):
     tensor = self._tensors[idx]
-    with _maybe_colocate_with(tensor):
+    with maybe_colocate_with(tensor):
       # Is there a more elegant way to do this computation?
       tensor_present = tensor[:-self._batch_size, :]
       tensor_future = tensor[self._batch_size:, :]
@@ -969,7 +969,7 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
       # block estimate.  This is equivalent to padding with zeros, as was done
       # in Section B.2 of the appendix.
       normalizer = self._num_timesteps * self._batch_size
-      return _compute_cov(
+      return compute_cov(
           tensor_future, tensor_right=tensor_present, normalizer=normalizer)
 
   @property
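
The tensor_right argument that compute_cov takes here is what makes the
lag-1 statistic a cross-moment rather than a plain second moment; a
self-contained NumPy sketch of the same slicing (sizes illustrative):

    import numpy as np

    num_timesteps, batch_size, dim = 5, 4, 3
    tensor = np.random.randn(num_timesteps * batch_size, dim)  # time-major blocks
    present = tensor[:-batch_size, :]  # x_t
    future = tensor[batch_size:, :]    # x_{t+1}
    # approx E[x_{t+1} x_t^T]; dividing by the full T * B rather than
    # (T-1) * B is the zero-padding effect mentioned above:
    cov_dt1 = np.dot(future.T, present) / (num_timesteps * batch_size)
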
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
index 23ee93cd405bbf719939df89d525c812ee061f8b..ad93919149c287b1932dd2b6bd772c0dab26192d 100644
@@ -41,6 +41,9 @@ _allowed_symbols = [
     "ConvOutputKroneckerFactor",
     "ConvDiagonalFactor",
     "set_global_constants",
+    "maybe_colocate_with",
+    "compute_cov",
+    "append_homog"
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
index d6bf61a210203dd74d4e93b65005f660b1fab4ff..f8aa230d9ca1f542950f56b1e6cf1ab7ccd3d05f 100644
@@ -36,6 +36,9 @@ _allowed_symbols = [
     "APPROX_DIAGONAL_NAME",
     "APPROX_FULL_NAME",
     "VARIABLE_SCOPE",
+    "APPROX_KRONECKER_INDEP_NAME",
+    "APPROX_KRONECKER_SERIES_1_NAME",
+    "APPROX_KRONECKER_SERIES_2_NAME"
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)