From 3252d0d41c15c1b26376c9c86c537aa275a1bb65 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 17 Jan 2018 05:03:05 -0800
Subject: [PATCH] K-FAC: Expose protected functions from fisher_blocks and
 fisher_factors and constant strings from layer_collection in their
 respective library modules.

This allows consistent development of blocks and factors outside
tensorflow.contrib.kfac.

PiperOrigin-RevId: 182197356
---
 .../python/kernel_tests/fisher_blocks_test.py |  2 +-
 .../kernel_tests/fisher_factors_test.py       | 10 ++--
 .../contrib/kfac/python/ops/fisher_blocks.py  | 54 ++++++++++---------
 .../kfac/python/ops/fisher_blocks_lib.py      |  4 ++
 .../contrib/kfac/python/ops/fisher_factors.py | 52 +++++++++---------
 .../kfac/python/ops/fisher_factors_lib.py     |  3 ++
 .../kfac/python/ops/layer_collection_lib.py   |  3 ++
 7 files changed, 72 insertions(+), 56 deletions(-)

diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 2d9b28185c..82accd57f0 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -49,7 +49,7 @@ class UtilsTest(test.TestCase):
       right_factor = array_ops.ones([2., 2.])
 
       # pi is the sqrt of the left trace norm divided by the right trace norm
-      pi = fb._compute_pi_tracenorm(left_factor, right_factor)
+      pi = fb.compute_pi_tracenorm(left_factor, right_factor)
 
       pi_val = sess.run(pi)
       self.assertEqual(1., pi_val)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index a2665b9279..753378d9f4 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -46,7 +46,7 @@ class MaybeColocateTest(test.TestCase):
     ff.set_global_constants(colocate_cov_ops_with_inputs=False)
     with tf_ops.Graph().as_default():
       a = constant_op.constant([2.0], name='a')
-      with ff._maybe_colocate_with(a):
+      with ff.maybe_colocate_with(a):
         b = constant_op.constant(3.0, name='b')
       self.assertEqual([b'loc:@a'], a.op.colocation_groups())
       self.assertEqual([b'loc:@b'], b.op.colocation_groups())
@@ -55,7 +55,7 @@
     ff.set_global_constants(colocate_cov_ops_with_inputs=True)
     with tf_ops.Graph().as_default():
       a = constant_op.constant([2.0], name='a')
-      with ff._maybe_colocate_with(a):
+      with ff.maybe_colocate_with(a):
         b = constant_op.constant(3.0, name='b')
       self.assertEqual([b'loc:@a'], a.op.colocation_groups())
       self.assertEqual([b'loc:@a'], b.op.colocation_groups())
@@ -129,7 +129,7 @@ class NumericalUtilsTest(test.TestCase):
       random_seed.set_random_seed(200)
 
       x = npr.randn(100, 3)
-      cov = ff._compute_cov(array_ops.constant(x))
+      cov = ff.compute_cov(array_ops.constant(x))
       np_cov = np.dot(x.T, x) / x.shape[0]
 
       self.assertAllClose(sess.run(cov), np_cov)
@@ -141,7 +141,7 @@
       normalizer = 10.
 
       x = npr.randn(100, 3)
-      cov = ff._compute_cov(array_ops.constant(x), normalizer=normalizer)
+      cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
       np_cov = np.dot(x.T, x) / normalizer
 
       self.assertAllClose(sess.run(cov), np_cov)
@@ -152,7 +152,7 @@
 
       m, n = 3, 4
       a = npr.randn(m, n)
-      a_homog = ff._append_homog(array_ops.constant(a))
+      a_homog = ff.append_homog(array_ops.constant(a))
       np_result = np.hstack([a, np.ones((m, 1))])
 
       self.assertAllClose(sess.run(a_homog), np_result)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 1ccb9e040f..9436caf961 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -54,7 +54,7 @@ from tensorflow.python.ops import math_ops
 NORMALIZE_DAMPING_POWER = 1.0
 
 # Methods for adjusting damping for FisherBlocks. See
-# _compute_pi_adjusted_damping() for details.
+# compute_pi_adjusted_damping() for details.
 PI_OFF_NAME = "off"
 PI_TRACENORM_NAME = "tracenorm"
 PI_TYPE = PI_TRACENORM_NAME
@@ -72,7 +72,14 @@ def set_global_constants(normalize_damping_power=None, pi_type=None):
     PI_TYPE = pi_type
 
 
-def _compute_pi_tracenorm(left_cov, right_cov):
+def normalize_damping(damping, num_replications):
+  """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
+  if NORMALIZE_DAMPING_POWER:
+    return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
+  return damping
+
+
+def compute_pi_tracenorm(left_cov, right_cov):
   """Computes the scalar constant pi for Tikhonov regularization/damping.
 
   pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
@@ -92,10 +99,10 @@
   return math_ops.sqrt(left_norm / right_norm)
 
 
-def _compute_pi_adjusted_damping(left_cov, right_cov, damping):
+def compute_pi_adjusted_damping(left_cov, right_cov, damping):
 
   if PI_TYPE == PI_TRACENORM_NAME:
-    pi = _compute_pi_tracenorm(left_cov, right_cov)
+    pi = compute_pi_tracenorm(left_cov, right_cov)
     return (damping * pi, damping / pi)
 
   elif PI_TYPE == PI_OFF_NAME:
@@ -450,10 +457,7 @@ class ConvDiagonalFB(FisherBlock):
     self._num_locations = (
         inputs_shape[1] * inputs_shape[2]
         // (self._strides[1] * self._strides[2]))
-
-    if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_locations**NORMALIZE_DAMPING_POWER
-    self._damping = damping
+    self._damping = normalize_damping(damping, self._num_locations)
 
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvDiagonalFactor,
@@ -506,7 +510,7 @@ class KroneckerProductFB(FisherBlock):
     Args:
       damping: The base damping factor (float or Tensor) for the damped inverse.
     """
-    self._input_damping, self._output_damping = _compute_pi_adjusted_damping(
+    self._input_damping, self._output_damping = compute_pi_adjusted_damping(
         self._input_factor.get_cov(),
         self._output_factor.get_cov(),
         damping**0.5)
@@ -691,8 +695,8 @@ class ConvKFCBasicFB(KroneckerProductFB):
     grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
 
     # Infer number of locations upon which convolution is applied.
-    self._num_locations = _num_conv_locations(inputs.shape.as_list(),
-                                              self._strides)
+    self._num_locations = num_conv_locations(inputs.shape.as_list(),
+                                             self._strides)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvInputKroneckerFactor,
@@ -701,11 +705,9 @@
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
 
-    if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_locations**NORMALIZE_DAMPING_POWER
-    self._damping = damping
-
+    damping = normalize_damping(damping, self._num_locations)
     self._register_damped_input_and_output_inverses(damping)
+    self._damping = damping
 
   @property
   def _renorm_coeff(self):
@@ -758,8 +760,16 @@ def _concat_along_batch_dim(tensor_list):
   return array_ops.concat(tensor_list, axis=0)
 
 
-def _num_conv_locations(input_shape, strides):
-  """Returns the number of locations a Conv kernel is applied to."""
+def num_conv_locations(input_shape, strides):
+  """Returns the number of spatial locations a 2D Conv kernel is applied to.
+
+  Args:
+    input_shape: list representing shape of inputs to the Conv layer.
+    strides: list representing strides for the Conv kernel.
+
+  Returns:
+    A scalar |T| denoting the number of spatial locations for the Conv layer.
+  """
   return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
 
 
@@ -804,9 +814,7 @@ class FullyConnectedMultiIndepFB(KroneckerProductFB):
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF, (grads_list,))
 
-    if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_uses**NORMALIZE_DAMPING_POWER
-
+    damping = normalize_damping(damping, self._num_uses)
     self._register_damped_input_and_output_inverses(damping)
 
   @property
@@ -885,10 +893,8 @@ class FullyConnectedSeriesFB(FisherBlock):
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF, (grads_list,))
 
-    if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_timesteps**NORMALIZE_DAMPING_POWER
-
-    self._damping_input, self._damping_output = _compute_pi_adjusted_damping(
+    damping = normalize_damping(damping, self._num_timesteps)
+    self._damping_input, self._damping_output = compute_pi_adjusted_damping(
         self._input_factor.get_cov(),
         self._output_factor.get_cov(),
         damping**0.5)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
index 59389f8d38..ac39630920 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
@@ -33,6 +33,10 @@ _allowed_symbols = [
     'ConvKFCBasicFB',
     'ConvDiagonalFB',
     'set_global_constants',
+    'compute_pi_tracenorm',
+    'compute_pi_adjusted_damping',
+    'num_conv_locations',
+    'normalize_damping'
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 826e8b7732..a069f6bdd9 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -58,7 +58,7 @@
 COLOCATE_COV_OPS_WITH_INPUTS = True
 
 @contextlib.contextmanager
-def _maybe_colocate_with(op):
+def maybe_colocate_with(op):
   """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`."""
   if COLOCATE_COV_OPS_WITH_INPUTS:
     if isinstance(op, (list, tuple)):
@@ -111,7 +111,7 @@ def diagonal_covariance_initializer(shape, dtype, partition_info):  # pylint: disable=unused-argument
   return array_ops.ones(shape, dtype)
 
 
-def _compute_cov(tensor, tensor_right=None, normalizer=None):
+def compute_cov(tensor, tensor_right=None, normalizer=None):
   """Compute the empirical second moment of the rows of a 2D Tensor.
 
   This function is meant to be applied to random matrices for which the true row
@@ -139,7 +139,7 @@
       math_ops.cast(normalizer, tensor.dtype))
 
 
-def _append_homog(tensor):
+def append_homog(tensor):
   """Appends a homogeneous coordinate to the last dimension of a Tensor.
 
   Args:
@@ -281,7 +281,7 @@ class FisherFactor(object):
     # the future we might want to perform the cov computations on each tower,
     # so that each tower will be considered a "source" (allowing us to reuse
     # the existing "source" code for this).
-    with _maybe_colocate_with(new_cov_contribs[0]):
+    with maybe_colocate_with(new_cov_contribs[0]):
      new_cov = math_ops.add_n(new_cov_contribs)
      # Synchronize value across all TPU cores.
      if utils.on_tpu():
@@ -472,7 +472,7 @@ class FullFactor(InverseProvidingFactor):
 
   def _compute_new_cov(self, idx=0):
     # This will be a very basic rank 1 estimate
-    with _maybe_colocate_with(self._params_grads[idx]):
+    with maybe_colocate_with(self._params_grads[idx]):
       params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
       return ((params_grads_flat * array_ops.transpose(
           params_grads_flat)) / math_ops.cast(self._batch_size,
@@ -528,7 +528,7 @@ class NaiveDiagonalFactor(DiagonalFactor):
     return self._params_grads[0][0].dtype
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._params_grads[idx]):
+    with maybe_colocate_with(self._params_grads[idx]):
       params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
       return (math_ops.square(params_grads_flat) / math_ops.cast(
           self._batch_size, params_grads_flat.dtype))
@@ -589,12 +589,12 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
     # square of an outer product is the outer-product of the entry-wise squares.
     # The gradient is the outer product of the input and the output gradients,
     # so we just square both and then take their outer-product.
-    with _maybe_colocate_with(self._outputs_grads[idx]):
+    with maybe_colocate_with(self._outputs_grads[idx]):
       # We only need to compute squared_inputs once
       if self._squared_inputs is None:
         inputs = self._inputs
         if self._has_bias:
-          inputs = _append_homog(self._inputs)
+          inputs = append_homog(self._inputs)
         self._squared_inputs = math_ops.square(inputs)
 
       new_cov = math_ops.matmul(
@@ -662,7 +662,7 @@ class ConvDiagonalFactor(DiagonalFactor):
     return self._outputs_grads[0].dtype
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._outputs_grads[idx]):
+    with maybe_colocate_with(self._outputs_grads[idx]):
       if self._patches is None:
         filter_height, filter_width, _, _ = self._filter_shape
 
@@ -676,7 +676,7 @@
             padding=self._padding)
 
         if self._has_bias:
-          patches = _append_homog(patches)
+          patches = append_homog(patches)
 
         self._patches = patches
 
@@ -736,11 +736,11 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
     return self._tensors[0].dtype
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._tensors[idx]):
+    with maybe_colocate_with(self._tensors[idx]):
       tensor = self._tensors[idx]
       if self._has_bias:
-        tensor = _append_homog(tensor)
-      return _compute_cov(tensor)
+        tensor = append_homog(tensor)
+      return compute_cov(tensor)
 
 
 class ConvInputKroneckerFactor(InverseProvidingFactor):
@@ -803,7 +803,7 @@
     if idx != 0:
       raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
 
-    with _maybe_colocate_with(self._inputs):
+    with maybe_colocate_with(self._inputs):
       filter_height, filter_width, in_channels, _ = self._filter_shape
 
       # TODO(b/64144716): there is potential here for a big savings in terms of
@@ -825,15 +825,15 @@
       # We append a homogenous coordinate to patches_flat if the layer has
       # bias parameters. This gives us [[A_l]]_H from the paper.
       if self._has_bias:
-        patches_flat = _append_homog(patches_flat)
-      # We call _compute_cov without passing in a normalizer. _compute_cov uses
+        patches_flat = append_homog(patches_flat)
+      # We call compute_cov without passing in a normalizer. compute_cov uses
       # the first dimension of patches_flat i.e. M|T| as the normalizer by
       # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
       # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
       # the paper but has a different scale here for consistency with
       # ConvOutputKroneckerFactor.
       # (Tilde omitted over A for clarity.)
-      return _compute_cov(patches_flat)
+      return compute_cov(patches_flat)
 
 
 class ConvOutputKroneckerFactor(InverseProvidingFactor):
@@ -876,7 +876,7 @@
     return self._outputs_grads[0].dtype
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._outputs_grads[idx]):
+    with maybe_colocate_with(self._outputs_grads[idx]):
       # reshaped_tensor below is the matrix DS_l defined in the KFC paper
       # (tilde omitted over S for clarity). It has shape M|T| x I, where
       # M = minibatch size, |T| = number of spatial locations, and
@@ -884,10 +884,10 @@
       reshaped_tensor = array_ops.reshape(self._outputs_grads[idx],
                                           [-1, self._out_channels])
       # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
-      # _compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
+      # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
       # as defined in the paper, with shape I x I.
       # (Tilde omitted over S for clarity.)
-      return _compute_cov(reshaped_tensor)
+      return compute_cov(reshaped_tensor)
 
 
 class FullyConnectedMultiKF(InverseProvidingFactor):
@@ -935,7 +935,7 @@
       new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx)
                                    for idx in range(self._num_sources))
 
-      with _maybe_colocate_with(new_cov_dt1_contribs[0]):
+      with maybe_colocate_with(new_cov_dt1_contribs[0]):
        new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
 
        op2 = moving_averages.assign_moving_average(
@@ -951,17 +951,17 @@
     return op
 
   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._tensor_lists[idx]):
+    with maybe_colocate_with(self._tensor_lists[idx]):
       tensor = array_ops.concat(self._tensor_lists[idx], 0)
       if self._has_bias:
-        tensor = _append_homog(tensor)
+        tensor = append_homog(tensor)
       # We save these so they can be used by _compute_new_cov_dt1
       self._tensors[idx] = tensor
-      return _compute_cov(tensor)
+      return compute_cov(tensor)
 
   def _compute_new_cov_dt1(self, idx=0):
     tensor = self._tensors[idx]
-    with _maybe_colocate_with(tensor):
+    with maybe_colocate_with(tensor):
       # Is there a more elegant way to do this computation?
       tensor_present = tensor[:-self._batch_size, :]
       tensor_future = tensor[self._batch_size:, :]
@@ -969,7 +969,7 @@
       # block estimate. This is equivalent to padding with zeros, as was done
       # in Section B.2 of the appendix.
       normalizer = self._num_timesteps * self._batch_size
-      return _compute_cov(
+      return compute_cov(
           tensor_future, tensor_right=tensor_present, normalizer=normalizer)
 
   @property
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
index 23ee93cd40..ad93919149 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
@@ -41,6 +41,9 @@ _allowed_symbols = [
    "ConvOutputKroneckerFactor",
    "ConvDiagonalFactor",
    "set_global_constants",
+    "maybe_colocate_with",
+    "compute_cov",
+    "append_homog"
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
index d6bf61a210..f8aa230d9c 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
@@ -36,6 +36,9 @@ _allowed_symbols = [
    "APPROX_DIAGONAL_NAME",
    "APPROX_FULL_NAME",
    "VARIABLE_SCOPE",
+    "APPROX_KRONECKER_INDEP_NAME",
+    "APPROX_KRONECKER_SERIES_1_NAME",
+    "APPROX_KRONECKER_SERIES_2_NAME"
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
-- 
2.34.1
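Usage note: with these symbols whitelisted, the helpers can be exercised from
outside tensorflow.contrib.kfac. A minimal sketch, assuming a TensorFlow 1.x
graph-mode session and that the *_lib modules re-export the names added to
_allowed_symbols above; the tensor values are illustrative, not from the patch:

import numpy as np
import tensorflow as tf

from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fb
from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as ff

x = tf.constant(np.random.randn(100, 3), dtype=tf.float32)

# compute_cov: empirical second moment of the rows. With no normalizer it
# divides by the first dimension, i.e. (1/100) * x^T x here.
cov = ff.compute_cov(x)

# append_homog: appends a column of ones, mapping shape [100, 3] to [100, 4].
x_homog = ff.append_homog(x)

# compute_pi_tracenorm: pi = sqrt((trace(A) / dim(A)) / (trace(B) / dim(B))).
# Here trace/dim is 1.0 on the left and 2.0 on the right, so pi = sqrt(0.5).
pi = fb.compute_pi_tracenorm(tf.eye(4), 2.0 * tf.eye(6))

# num_conv_locations: |T| = H * W // (stride_h * stride_w) for NHWC shapes,
# here 32 * 32 // (2 * 2) = 256.
num_locations = fb.num_conv_locations([128, 32, 32, 3], [1, 2, 2, 1])

with tf.Session() as sess:
  print(sess.run([cov, x_homog, pi]))
  print(num_locations)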
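The damping helpers compose the same way the blocks use them internally:
normalize_damping divides by |T|**NORMALIZE_DAMPING_POWER (1.0 by default),
and compute_pi_adjusted_damping splits the result into per-factor terms. A
sketch with stand-in covariance matrices (the values are assumptions):

import tensorflow as tf

from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fb

# Normalize the damping by the number of spatial locations, as the Conv
# blocks do: 0.1 / 1024 with these shapes and unit strides.
num_locations = fb.num_conv_locations([128, 32, 32, 3], [1, 1, 1, 1])
damping = fb.normalize_damping(0.1, num_locations)

# Stand-ins for input_factor.get_cov() and output_factor.get_cov().
input_cov = tf.eye(4)
output_cov = 2.0 * tf.eye(6)

# Under the default "tracenorm" PI_TYPE this returns
# (damping * pi, damping / pi); since the blocks pass damping**0.5, the two
# per-factor dampings multiply back to the normalized damping.
input_damping, output_damping = fb.compute_pi_adjusted_damping(
    input_cov, output_cov, damping**0.5)

with tf.Session() as sess:
  print(sess.run([input_damping, output_damping]))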