ff.set_global_constants(colocate_cov_ops_with_inputs=False)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
- with ff._maybe_colocate_with(a):
+ with ff.maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@b'], b.op.colocation_groups())
ff.set_global_constants(colocate_cov_ops_with_inputs=True)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
- with ff._maybe_colocate_with(a):
+ with ff.maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@a'], b.op.colocation_groups())
random_seed.set_random_seed(200)
x = npr.randn(100, 3)
- cov = ff._compute_cov(array_ops.constant(x))
+ cov = ff.compute_cov(array_ops.constant(x))
np_cov = np.dot(x.T, x) / x.shape[0]
self.assertAllClose(sess.run(cov), np_cov)
normalizer = 10.
x = npr.randn(100, 3)
- cov = ff._compute_cov(array_ops.constant(x), normalizer=normalizer)
+ cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
np_cov = np.dot(x.T, x) / normalizer
self.assertAllClose(sess.run(cov), np_cov)
m, n = 3, 4
a = npr.randn(m, n)
- a_homog = ff._append_homog(array_ops.constant(a))
+ a_homog = ff.append_homog(array_ops.constant(a))
np_result = np.hstack([a, np.ones((m, 1))])
self.assertAllClose(sess.run(a_homog), np_result)
NORMALIZE_DAMPING_POWER = 1.0
# Methods for adjusting damping for FisherBlocks. See
-# _compute_pi_adjusted_damping() for details.
+# compute_pi_adjusted_damping() for details.
PI_OFF_NAME = "off"
PI_TRACENORM_NAME = "tracenorm"
PI_TYPE = PI_TRACENORM_NAME
PI_TYPE = pi_type
-def _compute_pi_tracenorm(left_cov, right_cov):
+def normalize_damping(damping, num_replications):
+ """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
+ if NORMALIZE_DAMPING_POWER:
+ return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
+ return damping
+
+
+def compute_pi_tracenorm(left_cov, right_cov):
"""Computes the scalar constant pi for Tikhonov regularization/damping.
pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
return math_ops.sqrt(left_norm / right_norm)
-def _compute_pi_adjusted_damping(left_cov, right_cov, damping):
+def compute_pi_adjusted_damping(left_cov, right_cov, damping):
if PI_TYPE == PI_TRACENORM_NAME:
- pi = _compute_pi_tracenorm(left_cov, right_cov)
+ pi = compute_pi_tracenorm(left_cov, right_cov)
return (damping * pi, damping / pi)
elif PI_TYPE == PI_OFF_NAME:
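For reference, a minimal numpy sketch of how these damping adjustments compose (the helper names with an _np suffix and all numeric values are invented for illustration; the sqrt split mirrors the damping**0.5 passed to compute_pi_adjusted_damping below):

import numpy as np

NORMALIZE_DAMPING_POWER = 1.0

def normalize_damping_np(damping, num_replications):
  # Divide damping by num_replications ** NORMALIZE_DAMPING_POWER.
  if NORMALIZE_DAMPING_POWER:
    return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
  return damping

def compute_pi_tracenorm_np(left_cov, right_cov):
  # pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
  left_norm = np.trace(left_cov) / left_cov.shape[0]
  right_norm = np.trace(right_cov) / right_cov.shape[0]
  return np.sqrt(left_norm / right_norm)

# Stand-ins for the input-side and output-side Kronecker factor covariances.
A = np.diag([4.0, 2.0, 2.0])
B = np.diag([1.0, 1.0])
damping = normalize_damping_np(0.1, num_replications=4)  # 0.1 / 4 = 0.025
pi = compute_pi_tracenorm_np(A, B)                        # sqrt((8/3) / 1)
input_damping = damping ** 0.5 * pi
output_damping = damping ** 0.5 / pi
# The product of the two per-factor dampings recovers the normalized damping.
assert np.isclose(input_damping * output_damping, damping)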
self._num_locations = (
inputs_shape[1] * inputs_shape[2] //
(self._strides[1] * self._strides[2]))
-
- if NORMALIZE_DAMPING_POWER:
- damping /= self._num_locations**NORMALIZE_DAMPING_POWER
- self._damping = damping
+ self._damping = normalize_damping(damping, self._num_locations)
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvDiagonalFactor,
Args:
damping: The base damping factor (float or Tensor) for the damped inverse.
"""
- self._input_damping, self._output_damping = _compute_pi_adjusted_damping(
+ self._input_damping, self._output_damping = compute_pi_adjusted_damping(
self._input_factor.get_cov(),
self._output_factor.get_cov(),
damping**0.5)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
# Infer number of locations upon which convolution is applied.
- self._num_locations = _num_conv_locations(inputs.shape.as_list(),
- self._strides)
+ self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._strides)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
- if NORMALIZE_DAMPING_POWER:
- damping /= self._num_locations**NORMALIZE_DAMPING_POWER
- self._damping = damping
-
+ damping = normalize_damping(damping, self._num_locations)
self._register_damped_input_and_output_inverses(damping)
+ self._damping = damping
@property
def _renorm_coeff(self):
return array_ops.concat(tensor_list, axis=0)
-def _num_conv_locations(input_shape, strides):
- """Returns the number of locations a Conv kernel is applied to."""
+def num_conv_locations(input_shape, strides):
+ """Returns the number of spatial locations a 2D Conv kernel is applied to.
+
+ Args:
+ input_shape: list representing shape of inputs to the Conv layer.
+ strides: list representing strides for the Conv kernel.
+
+ Returns:
+ A scalar |T| denoting the number of spatial locations for the Conv layer.
+ """
return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
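As a quick arithmetic check of the formula above (shapes and strides are made up; this rough count ignores padding effects at the borders):

# NHWC input of shape [batch=32, height=16, width=16, channels=8] with
# kernel strides [1, 2, 2, 1]: the kernel is applied at |T| = 64 locations.
input_shape = [32, 16, 16, 8]
strides = [1, 2, 2, 1]
num_locations = input_shape[1] * input_shape[2] // (strides[1] * strides[2])
assert num_locations == (16 * 16) // (2 * 2) == 64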
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, (grads_list,))
- if NORMALIZE_DAMPING_POWER:
- damping /= self._num_uses**NORMALIZE_DAMPING_POWER
-
+ damping = normalize_damping(damping, self._num_uses)
self._register_damped_input_and_output_inverses(damping)
@property
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, (grads_list,))
- if NORMALIZE_DAMPING_POWER:
- damping /= self._num_timesteps**NORMALIZE_DAMPING_POWER
-
- self._damping_input, self._damping_output = _compute_pi_adjusted_damping(
+ damping = normalize_damping(damping, self._num_timesteps)
+ self._damping_input, self._damping_output = compute_pi_adjusted_damping(
self._input_factor.get_cov(),
self._output_factor.get_cov(),
damping**0.5)
@contextlib.contextmanager
-def _maybe_colocate_with(op):
+def maybe_colocate_with(op):
"""Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`."""
if COLOCATE_COV_OPS_WITH_INPUTS:
if isinstance(op, (list, tuple)):
return array_ops.ones(shape, dtype)
-def _compute_cov(tensor, tensor_right=None, normalizer=None):
+def compute_cov(tensor, tensor_right=None, normalizer=None):
"""Compute the empirical second moment of the rows of a 2D Tensor.
This function is meant to be applied to random matrices for which the true row
math_ops.cast(normalizer, tensor.dtype))
-def _append_homog(tensor):
+def append_homog(tensor):
"""Appends a homogeneous coordinate to the last dimension of a Tensor.
Args:
# the future we might want to perform the cov computations on each tower,
# so that each tower will be considered a "source" (allowing us to reuse
# the existing "source" code for this).
- with _maybe_colocate_with(new_cov_contribs[0]):
+ with maybe_colocate_with(new_cov_contribs[0]):
new_cov = math_ops.add_n(new_cov_contribs)
# Synchronize value across all TPU cores.
if utils.on_tpu():
def _compute_new_cov(self, idx=0):
# This will be a very basic rank 1 estimate
- with _maybe_colocate_with(self._params_grads[idx]):
+ with maybe_colocate_with(self._params_grads[idx]):
params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
return ((params_grads_flat * array_ops.transpose(
params_grads_flat)) / math_ops.cast(self._batch_size,
return self._params_grads[0][0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._params_grads[idx]):
+ with maybe_colocate_with(self._params_grads[idx]):
params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
return (math_ops.square(params_grads_flat) / math_ops.cast(
self._batch_size, params_grads_flat.dtype))
# square of an outer product is the outer-product of the entry-wise squares.
# The gradient is the outer product of the input and the output gradients,
# so we just square both and then take their outer-product.
- with _maybe_colocate_with(self._outputs_grads[idx]):
+ with maybe_colocate_with(self._outputs_grads[idx]):
# We only need to compute squared_inputs once
if self._squared_inputs is None:
inputs = self._inputs
if self._has_bias:
- inputs = _append_homog(self._inputs)
+ inputs = append_homog(self._inputs)
self._squared_inputs = math_ops.square(inputs)
new_cov = math_ops.matmul(
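The identity mentioned in the comment above is easy to verify directly; a small numpy check with invented vector sizes (u standing in for one example's inputs, v for its output gradients):

import numpy as np

rng = np.random.RandomState(0)
u = rng.randn(5)
v = rng.randn(3)

# The entry-wise square of an outer product equals the outer product of the
# entry-wise squares: (u v^T) * (u v^T) == (u * u)(v * v)^T elementwise.
lhs = np.square(np.outer(u, v))
rhs = np.outer(np.square(u), np.square(v))
assert np.allclose(lhs, rhs)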
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._outputs_grads[idx]):
+ with maybe_colocate_with(self._outputs_grads[idx]):
if self._patches is None:
filter_height, filter_width, _, _ = self._filter_shape
padding=self._padding)
if self._has_bias:
- patches = _append_homog(patches)
+ patches = append_homog(patches)
self._patches = patches
return self._tensors[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._tensors[idx]):
+ with maybe_colocate_with(self._tensors[idx]):
tensor = self._tensors[idx]
if self._has_bias:
- tensor = _append_homog(tensor)
- return _compute_cov(tensor)
+ tensor = append_homog(tensor)
+ return compute_cov(tensor)
class ConvInputKroneckerFactor(InverseProvidingFactor):
if idx != 0:
raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
- with _maybe_colocate_with(self._inputs):
+ with maybe_colocate_with(self._inputs):
filter_height, filter_width, in_channels, _ = self._filter_shape
# TODO(b/64144716): there is potential here for a big savings in terms of
# We append a homogeneous coordinate to patches_flat if the layer has
# bias parameters. This gives us [[A_l]]_H from the paper.
if self._has_bias:
- patches_flat = _append_homog(patches_flat)
- # We call _compute_cov without passing in a normalizer. _compute_cov uses
+ patches_flat = append_homog(patches_flat)
+ # We call compute_cov without passing in a normalizer. compute_cov uses
# the first dimension of patches_flat i.e. M|T| as the normalizer by
# default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
# shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
# the paper but has a different scale here for consistency with
# ConvOutputKroneckerFactor.
# (Tilde omitted over A for clarity.)
- return _compute_cov(patches_flat)
+ return compute_cov(patches_flat)
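To make the shape bookkeeping in these comments concrete, here is a hedged numpy-only sketch with invented sizes (M examples, |T| spatial locations, and patches of size J|Delta| = filter_height * filter_width * in_channels, plus the appended homogeneous coordinate):

import numpy as np

M, T = 4, 9                    # minibatch size and number of spatial locations
fh, fw, J = 3, 3, 2            # filter height/width and input channels
patches_flat = np.random.randn(M * T, fh * fw * J)             # [[A_l]], M|T| x J|Delta|
patches_flat = np.hstack([patches_flat, np.ones((M * T, 1))])  # append_homog
cov = patches_flat.T.dot(patches_flat) / patches_flat.shape[0] # compute_cov default
assert cov.shape == (fh * fw * J + 1, fh * fw * J + 1)         # square, side J|Delta| + 1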
class ConvOutputKroneckerFactor(InverseProvidingFactor):
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._outputs_grads[idx]):
+ with maybe_colocate_with(self._outputs_grads[idx]):
# reshaped_tensor below is the matrix DS_l defined in the KFC paper
# (tilde omitted over S for clarity). It has shape M|T| x I, where
# M = minibatch size, |T| = number of spatial locations, and
reshaped_tensor = array_ops.reshape(self._outputs_grads[idx],
[-1, self._out_channels])
# Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
- # _compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
+ # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
# as defined in the paper, with shape I x I.
# (Tilde omitted over S for clarity.)
- return _compute_cov(reshaped_tensor)
+ return compute_cov(reshaped_tensor)
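Analogously for the output side, a short sketch (again with invented sizes) of how the gradients reshape to DS_l with shape M|T| x I and yield an I x I matrix:

import numpy as np

M, H, W, I = 4, 3, 3, 6                       # minibatch, spatial dims, out channels
outputs_grads = np.random.randn(M, H, W, I)   # per-location output gradients
ds = outputs_grads.reshape(-1, I)             # DS_l, shape M|T| x I with |T| = H * W
gamma_hat = ds.T.dot(ds) / ds.shape[0]        # 1/(M|T|) * DS_l^T DS_l
assert gamma_hat.shape == (I, I)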
class FullyConnectedMultiKF(InverseProvidingFactor):
new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx)
for idx in range(self._num_sources))
- with _maybe_colocate_with(new_cov_dt1_contribs[0]):
+ with maybe_colocate_with(new_cov_dt1_contribs[0]):
new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
op2 = moving_averages.assign_moving_average(
return op
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._tensor_lists[idx]):
+ with maybe_colocate_with(self._tensor_lists[idx]):
tensor = array_ops.concat(self._tensor_lists[idx], 0)
if self._has_bias:
- tensor = _append_homog(tensor)
+ tensor = append_homog(tensor)
# We save these so they can be used by _compute_new_cov_dt1
self._tensors[idx] = tensor
- return _compute_cov(tensor)
+ return compute_cov(tensor)
def _compute_new_cov_dt1(self, idx=0):
tensor = self._tensors[idx]
- with _maybe_colocate_with(tensor):
+ with maybe_colocate_with(tensor):
# Is there a more elegant way to do this computation?
tensor_present = tensor[:-self._batch_size, :]
tensor_future = tensor[self._batch_size:, :]
# block estimate. This is equivalent to padding with zeros, as was done
# in Section B.2 of the appendix.
normalizer = self._num_timesteps * self._batch_size
- return _compute_cov(
+ return compute_cov(
tensor_future, tensor_right=tensor_present, normalizer=normalizer)
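A minimal numpy sketch of this lagged covariance, assuming the rows of tensor are arranged as num_timesteps consecutive blocks of batch_size rows (all names and sizes here are invented for illustration):

import numpy as np

batch_size, num_timesteps, dim = 2, 4, 3
tensor = np.random.randn(num_timesteps * batch_size, dim)

tensor_present = tensor[:-batch_size, :]   # timesteps 0 .. T-2
tensor_future = tensor[batch_size:, :]     # timesteps 1 .. T-1

# Dividing by num_timesteps * batch_size rather than by the number of rows
# actually used is equivalent to zero-padding the missing timestep.
normalizer = num_timesteps * batch_size
cov_dt1 = tensor_future.T.dot(tensor_present) / normalizer
assert cov_dt1.shape == (dim, dim)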
@property