- Adds support for shared embedding layers (e.g. in RNNs), and shared Conv2D layers.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Mar 2018 14:54:08 +0000 (07:54 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Mar 2018 14:58:33 +0000 (07:58 -0700)
- Some minor refactoring of internal structure in fisher_blocks and layer_collection

PiperOrigin-RevId: 189338874

tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
tensorflow/contrib/kfac/python/ops/fisher_blocks.py
tensorflow/contrib/kfac/python/ops/layer_collection.py
tensorflow/contrib/kfac/python/ops/utils.py

index bae6bd7..ba22099 100644 (file)
@@ -135,8 +135,22 @@ class LayerCollectionTest(test.TestCase):
           array_ops.constant(6),
           16,
           approx=layer_collection.APPROX_DIAGONAL_NAME)
-
-      self.assertEqual(9, len(lc.get_blocks()))
+      lc.register_fully_connected_multi(
+          array_ops.constant(1),
+          (array_ops.constant(2), array_ops.constant(3)),
+          (array_ops.constant(4), array_ops.constant(5)))
+      lc.register_conv2d_multi(
+          params=array_ops.ones((2, 3, 4, 5)),
+          strides=[1, 1, 1, 1],
+          padding='SAME',
+          inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
+          outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
+      lc.register_embedding_multi(
+          array_ops.constant((1,)),
+          (array_ops.constant(2), array_ops.constant(3)),
+          (array_ops.constant(4), array_ops.constant(5)))
+
+      self.assertEqual(12, len(lc.get_blocks()))
 
   def testRegisterBlocksMultipleRegistrations(self):
     with ops.Graph().as_default():
index 31f4689..79d0424 100644 (file)
@@ -48,6 +48,7 @@ from tensorflow.contrib.kfac.python.ops import utils
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import nest
 
 # For blocks corresponding to convolutional layers, or any type of block where
 # the parameters can be thought of as being replicated in time or space,
@@ -74,6 +75,86 @@ def set_global_constants(normalize_damping_power=None, pi_type=None):
     PI_TYPE = pi_type
 
 
+def _make_partitionedtensors_inputs(inputs):
+  """Constructs PartitionedTensor for inputs.
+
+  The purpose of this method is to package up the towers/minibatch dimension
+  of these arrays into PartitionedTensor objects.
+
+  Args:
+    inputs: a 1-D list of Tensors. Index is tower/mini-batch.
+
+  Returns:
+    A PartitionedTensor.
+  """
+  return utils.PartitionedTensor(inputs)
+
+
+def _make_partitionedtensors_grads(grads_list):
+  """Constructs PartitionedTensor for grads_list.
+
+  The purpose of this method is to package up the towers/minibatch dimension
+  of these arrays into PartitionedTensor objects.
+
+  Args:
+    grads_list: 2-D list of Tensors. First index is for source, second
+      index for tower.
+
+  Returns:
+    Tuple of PartitionedTensors, one per source.
+  """
+  return tuple(utils.PartitionedTensor(grads) for grads in grads_list)
+
+
+def _make_partitionedtensors_multi_inputs(inputs):
+  """Constructs PartitionedTensors for inputs.
+
+  The purpose of this method is to package up the towers/minibatch dimension
+  of these arrays into PartitionedTensor objects.
+
+  This version of this function is for use with FisherBlocks that deal with
+  multiple uses or time-steps. One PartitionedTensor is created for each
+  use/time-step.  The FisherBlock will be responsible for concatenating
+  (or doing whatever else it wants) with the resulting lists.
+
+  Args:
+    inputs: a 2-D list of Tensors. First index is tower/mini-batch, second is
+      use/time-step.
+
+  Returns:
+    A tuple of PartitionedTensor's, one per use/time-step.
+  """
+  num_uses = len(inputs[0])
+  assert all(len(input_) == num_uses for input_ in inputs)
+
+  return tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
+
+
+def _make_partitionedtensors_multi_grads(grads_list):
+  """Constructs PartitionedTensors for grads_list.
+
+  The purpose of this method is to package up the towers/minibatch dimension
+  of these arrays into PartitionedTensor objects.
+
+  This version of this function is for use with FisherBlocks that deal with
+  multiple uses or time-steps. One PartitionedTensor is created for each
+  use/time-step.  The FisherBlock will be responsible for concatenating
+  (or doing whatever else it wants) with the resulting lists.
+
+  Args:
+    grads_list: 3-D list of Tensors. First index is for source, second is for
+      tower, third is for use/time-step.
+
+  Returns:
+    2-D tuple of PartitionedTensors. First index is for source, second is for
+    use/time-step.
+  """
+  num_uses = len(grads_list[0][0])
+  assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
+  return tuple(tuple(utils.PartitionedTensor(grad)
+                     for grad in zip(*grads)) for grads in grads_list)
+
+
 def normalize_damping(damping, num_replications):
   """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
   if NORMALIZE_DAMPING_POWER:
@@ -396,57 +477,6 @@ class InputOutputMultiMinibatch(object):
   def _outputs(self):
     return self.__outputs
 
-  def _package_minibatches(self, grads_list):
-    """Constructs PartitionedTensor for inputs, grads_list.
-
-    The purpose of this method is to package up the towers/minibatch dimension
-    of these arrays into PartitionedTensor objects.
-
-    Args:
-      grads_list: 2-D list of Tensors. First index is for source, second
-        index for tower.
-
-    Returns:
-      inputs: PartitionedTensor.
-      grads_list: Tuple of PartitionedTensors, one per source.
-    """
-    inputs = utils.PartitionedTensor(self._inputs)
-    grads_list = tuple(utils.PartitionedTensor(grads) for grads in grads_list)
-
-    return inputs, grads_list
-
-  def _package_minibatches_multi(self, grads_list):
-    """Constructs PartitionedTensors for inputs, grads_list.
-
-    The purpose of this method is to package up the towers/minibatch dimension
-    of these arrays into PartitionedTensor objects.
-
-    This version of this function is for use with FisherBlocks that deal with
-    multiple uses or time-steps. One PartitionedTensor is created for each
-    use/time-step.
-
-    Args:
-      grads_list: 3-D tuple of Tensors. First index is for source, second
-        index is for tower, third is for use/time-step.
-
-    Returns:
-      inputs: A tuple of PartitionedTensor's, one per use/time-step.
-      grads_list: 2-D tuple of PartitionedTensors. First index is for source,
-        second is for use/time-step.
-    """
-    # self._inputs is a 2-D tuple.  First index is tower/mini-batch, second is
-    # use/time-step.
-    inputs = self._inputs
-    num_uses = len(inputs[0])
-    assert all(len(input_) == num_uses for input_ in inputs)
-    assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
-
-    inputs = tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
-    grads_list = tuple(tuple(utils.PartitionedTensor(grad)
-                             for grad in zip(*grads)) for grads in grads_list)
-
-    return inputs, grads_list
-
 
 class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
   """FisherBlock for fully-connected (dense) layers using a diagonal approx.
@@ -485,7 +515,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
     super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    inputs, grads_list = self._package_minibatches(grads_list)
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
 
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedDiagonalFactor,
@@ -598,7 +629,8 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
     super(ConvDiagonalFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    inputs, grads_list = self._package_minibatches(grads_list)
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
 
     # Infer number of locations upon which convolution is applied.
     self._num_locations = num_conv_locations(inputs.shape.as_list(),
@@ -711,7 +743,7 @@ class KroneckerProductFB(FisherBlock):
 class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
   """K-FAC FisherBlock for embedding layers.
 
-  This FisherBlock is similar to EmbeddingKFACFB, except that its
+  This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
   input factor is approximated by a diagonal matrix. In the case that each
   example references exactly one embedding, this approximation is exact.
 
@@ -740,17 +772,78 @@ class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
       damping: 0-D Tensor or float. 'damping' * identity is approximately added
         to this FisherBlock's Fisher approximation.
     """
-    inputs, grads_list = self._package_minibatches(grads_list)
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
 
-    self._input_factor = self._layer_collection.make_or_get_factor(  #
-        fisher_factors.EmbeddingInputKroneckerFactor,  #
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.EmbeddingInputKroneckerFactor,
         (inputs, self._vocab_size))
-    self._output_factor = self._layer_collection.make_or_get_factor(  #
-        fisher_factors.FullyConnectedKroneckerFactor,  #
-        (grads_list,))
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
     self._setup_damping(damping)
 
 
+class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+  """K-FAC FisherBlock for embedding layers used multiple times in the graph.
+
+  Similar to EmbeddingKFACFB except that this version supports multiple uses
+  of the parameter within a single model. These uses could correspond to
+  "time-steps", but they don't have to.
+
+  Does not support bias parameters.
+  """
+
+  def __init__(self, layer_collection, vocab_size):
+    """Creates a EmbeddingKFACMultiIndepFB block.
+
+    Args:
+      layer_collection: The collection of all layers in the K-FAC approximate
+          Fisher information matrix to which this FisherBlock belongs.
+      vocab_size: int. Size of vocabulary for this embedding layer.
+    """
+    self._vocab_size = vocab_size
+
+    super(EmbeddingKFACMultiIndepFB, self).__init__(layer_collection)
+
+  def instantiate_factors(self, grads_list, damping):
+    """Instantiate Kronecker Factors for this FisherBlock.
+
+    Args:
+      grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
+        gradient of the loss with respect to 'outputs' from source 'i',
+        tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
+        [tower_minibatch_size, output_size].
+      damping: 0-D Tensor or float. 'damping' * identity is approximately added
+        to this FisherBlock's Fisher approximation.
+    """
+    inputs = self._inputs
+    self._num_uses = num_uses = len(inputs[0])
+
+    # Check that all mini-batches/towers have the same number of uses
+    assert all(len(input_) == num_uses for input_ in inputs)
+    # Do the same for grads_list
+    assert all(len(grad) == num_uses for grad in grads for grads in grads_list)
+    # Merge uses and towers/minibatches dimensions together so we can handle
+    # it using a non-multi factor.
+    inputs = nest.flatten(inputs)
+
+    # Note that we call the multi version of make_partitionedtensors only for
+    # grads_list here.
+    inputs = _make_partitionedtensors_inputs(inputs)
+    grads_list = _make_partitionedtensors_multi_grads(grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.EmbeddingInputKroneckerFactor,
+        (inputs, self._vocab_size))
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedMultiKF, (grads_list,))
+    self._setup_damping(damping, normalization=num_uses)
+
+  @property
+  def _renorm_coeff(self):
+    return self._num_uses
+
+
 class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
   """K-FAC FisherBlock for fully-connected (dense) layers.
 
@@ -781,13 +874,14 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
       damping: 0-D Tensor or float. 'damping' * identity is approximately added
         to this FisherBlock's Fisher approximation.
     """
-    inputs, grads_list = self._package_minibatches(grads_list)
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
 
-    self._input_factor = self._layer_collection.make_or_get_factor(  #
-        fisher_factors.FullyConnectedKroneckerFactor,  #
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedKroneckerFactor,
         ((inputs,), self._has_bias))
-    self._output_factor = self._layer_collection.make_or_get_factor(  #
-        fisher_factors.FullyConnectedKroneckerFactor,  #
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedKroneckerFactor,
         (grads_list,))
     self._setup_damping(damping)
 
@@ -858,12 +952,13 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
     super(ConvKFCBasicFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    inputs, grads_list = self._package_minibatches(grads_list)
-
     # Infer number of locations upon which convolution is applied.
     self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
                                              self._strides)
 
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
+
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvInputKroneckerFactor,
         (inputs, self._filter_shape, self._padding, self._strides,
@@ -1139,6 +1234,10 @@ def num_conv_locations(input_shape, strides):
 
 class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
   """FisherBlock for fully-connected layers that share parameters.
+
+  This class implements the "independence across time" approximation from the
+  following paper:
+    https://openreview.net/pdf?id=HyMTkQZAb
   """
 
   def __init__(self, layer_collection, has_bias=False):
@@ -1156,7 +1255,8 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
   def instantiate_factors(self, grads_list, damping):
 
     self._num_uses = float(len(self._inputs[0]))
-    inputs, grads_list = self._package_minibatches_multi(grads_list)
+    inputs = _make_partitionedtensors_multi_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_multi_grads(grads_list)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF,
@@ -1175,6 +1275,92 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
     return self._outputs
 
 
+class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+  """FisherBlock for 2D convolutional layers using the basic KFC approx.
+
+  Similar to ConvKFCBasicFB except that this version supports multiple
+  uses/time-steps via a standard independence approximation.  Similar to the
+  "independence across time" used in FullyConnectedMultiIndepFB but generalized
+  in the obvious way to conv layers.
+  """
+
+  def __init__(self,
+               layer_collection,
+               params,
+               padding,
+               strides=None,
+               dilation_rate=None,
+               data_format=None,
+               extract_patches_fn=None):
+    """Creates a ConvKFCBasicMultiIndepFB block.
+
+    Args:
+      layer_collection: The collection of all layers in the K-FAC approximate
+          Fisher information matrix to which this FisherBlock belongs.
+      params: The parameters (Tensor or tuple of Tensors) of this layer. If
+        kernel alone, a Tensor of shape [..spatial_filter_shape..,
+        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+        containing the previous and a Tensor of shape [out_channels].
+      padding: str. Padding method.
+      strides: List of ints or None. Contains [..spatial_filter_strides..] if
+        'extract_patches_fn' is compatible with tf.nn.convolution(), else
+        [1, ..spatial_filter_strides, 1].
+      dilation_rate: List of ints or None. Rate for dilation along each spatial
+        dimension if 'extract_patches_fn' is compatible with
+        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+      data_format: str or None. Format of input data.
+      extract_patches_fn: str or None. Name of function that extracts image
+        patches. One of "extract_convolution_patches", "extract_image_patches",
+        "extract_pointwise_conv2d_patches".
+    """
+    self._padding = padding
+    self._strides = maybe_tuple(strides)
+    self._dilation_rate = maybe_tuple(dilation_rate)
+    self._data_format = data_format
+    self._extract_patches_fn = extract_patches_fn
+    self._has_bias = isinstance(params, (tuple, list))
+
+    fltr = params[0] if self._has_bias else params
+    self._filter_shape = tuple(fltr.shape.as_list())
+
+    super(ConvKFCBasicMultiIndepFB, self).__init__(layer_collection)
+
+  def instantiate_factors(self, grads_list, damping):
+    # Infer number of locations upon which convolution is applied.
+    self._num_locations = num_locations = num_conv_locations(
+        self._inputs[0][0].shape.as_list(), self._strides)
+
+    # The first index is tower/minibatch, the second is use/time-step
+    inputs = self._inputs
+    self._num_uses = num_uses = len(inputs[0])
+
+    # Check that all mini-batches/towers have the same number of uses
+    assert all(len(input_) == num_uses for input_ in inputs)
+    assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
+
+    # Fold uses/time-step and towers/minibatches dimensions together
+    inputs = nest.flatten(inputs)
+    # And do the same for grads_list
+    grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+
+    inputs = _make_partitionedtensors_inputs(inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.ConvInputKroneckerFactor,
+        (inputs, self._filter_shape, self._padding, self._strides,
+         self._dilation_rate, self._data_format, self._extract_patches_fn,
+         self._has_bias))
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
+
+    self._setup_damping(damping, normalization=(num_locations * num_uses))
+
+  @property
+  def _renorm_coeff(self):
+    return self._num_locations * self._num_uses
+
+
 class SeriesFBApproximation(enum.IntEnum):
   """See FullyConnectedSeriesFB.__init__ for description and usage."""
   option1 = 1
@@ -1184,7 +1370,8 @@ class SeriesFBApproximation(enum.IntEnum):
 class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
   """FisherBlock for fully-connected layers that share parameters across time.
 
-  See the following preprint for details:
+  This class implements the "Option 1" and "Option 2" approximation from the
+  following paper:
     https://openreview.net/pdf?id=HyMTkQZAb
 
   See the end of the appendix of the paper for a pseudo-code of the
@@ -1218,7 +1405,10 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
   def instantiate_factors(self, grads_list, damping):
 
     self._num_timesteps = len(self._inputs[0])
-    inputs, grads_list = self._package_minibatches_multi(grads_list)
+    assert len(grads_list[0][0]) == self._num_timesteps
+
+    inputs = _make_partitionedtensors_multi_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_multi_grads(grads_list)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias))
index 4eb5e4c..00eae8b 100644 (file)
@@ -60,6 +60,10 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = {
     APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
 }
 
+_EMBEDDING_APPROX_TO_BLOCK_TYPES = {
+    APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB
+}
+
 APPROX_KRONECKER_INDEP_NAME = "kron_indep"
 APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1"
 APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2"
@@ -72,6 +76,14 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
                                             option=2)
 }
 
+_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = {
+    APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB
+}
+
+_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = {
+    APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB
+}
+
 # Possible value for 'reuse' keyword argument. Sets 'reuse' to
 # tf.get_variable_scope().reuse.
 VARIABLE_SCOPE = "VARIABLE_SCOPE"
@@ -169,9 +181,12 @@ class LayerCollection(object):
     self._default_generic_approximation = APPROX_FULL_NAME
     self._default_embedding_approximation = APPROX_KRONECKER_NAME
     self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
-    self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
+    self._default_conv2d_approximation = APPROX_KRONECKER_NAME
     self._default_fully_connected_multi_approximation = (
-        APPROX_KRONECKER_SERIES_2_NAME)
+        APPROX_KRONECKER_INDEP_NAME)
+    self._default_conv2d_multi_approximation = (
+        APPROX_KRONECKER_INDEP_NAME)
+    self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME
     self.loss_colocation_ops = {}
     self._vars_to_uses = defaultdict(lambda: 0)
 
@@ -245,14 +260,14 @@ class LayerCollection(object):
 
   @property
   def default_conv2d_approximation(self):
-    return self._default_convolution_2d_approximation
+    return self._default_conv2d_approximation
 
   def set_default_conv2d_approximation(self, value):
     if value not in _CONV2D_APPROX_TO_BLOCK_TYPES:
       raise ValueError(
           "{} is not a valid approximation for 2d convolutional layers.".format(
               value))
-    self._default_convolution_2d_approximation = value
+    self._default_conv2d_approximation = value
 
   @property
   def default_fully_connected_multi_approximation(self):
@@ -264,6 +279,14 @@ class LayerCollection(object):
                        "multi layer.".format(value))
     self._default_fully_connected_multi_approximation = value
 
+  @property
+  def default_conv2d_multi_approximation(self):
+    return self._default_conv2d_multi_approximation
+
+  @property
+  def default_embedding_multi_approximation(self):
+    return self._default_embedding_multi_approximation
+
   def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
     """Validates and registers the layer_key associated with the fisher_block.
 
@@ -526,13 +549,24 @@ class LayerCollection(object):
     else:
       return None
 
+  def _get_block_type(self, params, approx, default, approx_to_type):
+    if approx is None:
+      approx = self._get_linked_approx(params)
+      if approx is None:
+        approx = default
+
+    if approx not in approx_to_type:
+      raise ValueError("Bad value {} for approx.".format(approx))
+
+    return approx_to_type[approx], approx
+
   def register_embedding(self,
                          params,
                          inputs,
                          outputs,
                          approx=None,
                          reuse=VARIABLE_SCOPE):
-    """Registers a fully connnected layer.
+    """Registers an embedding layer.
 
     Args:
       params: Embedding matrix of shape [vocab_size, embedding_size].
@@ -540,7 +574,8 @@ class LayerCollection(object):
         into embedding matrix.
       outputs: Tensor of shape [batch_size, output_size]. Outputs
         produced by layer.
-      approx: str. Must be "kron".
+      approx: str or None. If not None must be "kron".  The Fisher
+        approximation to use. If None the default value is used. (Default: None)
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -550,20 +585,15 @@ class LayerCollection(object):
       KeyError: If reuse == True but no FisherBlock found for 'params'.
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
-    if approx is None:
-      approx = self._get_linked_approx(params)
-      if approx is None:
-        approx = self.default_embedding_approximation
-
-    if approx != APPROX_KRONECKER_NAME:
-      raise ValueError("Bad value {} for approx.".format(approx))
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_embedding_approximation,
+        _EMBEDDING_APPROX_TO_BLOCK_TYPES)
 
     if isinstance(params, (tuple, list)):
       raise ValueError("Bias not supported.")
-
     vocab_size = int(params.shape[0])
     block = self.register_block(
-        params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse)
+        params, block_type(self, vocab_size), reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
 
     self._add_uses(params, 1)
@@ -583,7 +613,9 @@ class LayerCollection(object):
       inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
       outputs: Tensor of shape [batch_size, output_size]. Outputs
         produced by layer.
-      approx: str. One of "kron" or "diagonal".
+      approx: str or None. If not None must be one of "kron" or "diagonal".
+        The Fisher approximation to use. If None the default value is used.
+        (Default: None)
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -593,17 +625,12 @@ class LayerCollection(object):
       KeyError: If reuse == True but no FisherBlock found for 'params'.
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
-    if approx is None:
-      approx = self._get_linked_approx(params)
-      if approx is None:
-        approx = self.default_fully_connected_approximation
 
-    if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
-      raise ValueError("Bad value {} for approx.".format(approx))
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_fully_connected_approximation,
+        _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES)
 
-    block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx]
     has_bias = isinstance(params, (tuple, list))
-
     block = self.register_block(params, block_type(self, has_bias=has_bias),
                                 reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
@@ -635,7 +662,9 @@ class LayerCollection(object):
         Output produced by layer.
       data_format: str or None. Format of data.
       dilations: List of 4 ints. Dilations along each dimension.
-      approx: str. One of "kron" or "diagonal".
+      approx: str or None. If not None must be one of "kron" or "diagonal".
+        The Fisher approximation to use. If None the default value is used.
+        (Default: None)
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -646,15 +675,14 @@ class LayerCollection(object):
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
 
-    if approx is None:
-      approx = self._get_linked_approx(params)
-      if approx is None:
-        approx = self.default_conv2d_approximation
-
-    if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES:
-      raise ValueError("Bad value {} for approx.".format(approx))
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_conv2d_approximation,
+        _CONV2D_APPROX_TO_BLOCK_TYPES)
 
-    block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx]
+    # It feels bad to pass in configuration that has to do with the internal
+    # implementation.  And then we can't use the same constructor for both
+    # anymore and are thus forced to use this ugly if-statement.
+    # TODO(b/74793309): Clean this up?
     if approx == APPROX_KRONECKER_NAME:
       block = self.register_block(
           params,
@@ -680,7 +708,7 @@ class LayerCollection(object):
               data_format=data_format),
           reuse=reuse)
     else:
-      raise NotImplementedError
+      raise NotImplementedError(approx)
 
     block.register_additional_minibatch(inputs, outputs)
 
@@ -712,7 +740,9 @@ class LayerCollection(object):
       dilation_rate: List of ints of length len(..input_spatial_size..).
         Dilations along spatial dimension.
       data_format: str or None. Format of data.
-      approx: str. One of "kron" or "diagonal".
+      approx: str or None. If not None must be one of "kron" or "diagonal".
+        The Fisher approximation to use. If None the default value is used.
+        (Default: None)
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -722,6 +752,8 @@ class LayerCollection(object):
       KeyError: If reuse == True but no FisherBlock found for 'params'.
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
+    # TODO(b/74793309): Have this use _get_block_type like the other
+    # registration functions?
     assert approx is None or approx == APPROX_KRONECKER_NAME
 
     block = self.register_block(
@@ -762,7 +794,8 @@ class LayerCollection(object):
       rate: None or List of ints of length 2. Dilation rates in spatial
         dimensions.
       data_format: str or None. Format of data.
-      approx: None or str. Must be "diagonal" if non-None.
+      approx: str or None. If not None must "diagonal".  The Fisher
+        approximation to use. If None the default value is used. (Default: None)
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -772,6 +805,8 @@ class LayerCollection(object):
       KeyError: If reuse == True but no FisherBlock found for 'params'.
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
+    # TODO(b/74793309): Have this use _get_block_type like the other
+    # registration functions?
     assert approx is None or approx == APPROX_DIAGONAL_NAME
     assert data_format in [None, "NHWC"]
 
@@ -803,7 +838,7 @@ class LayerCollection(object):
                                 reuse=VARIABLE_SCOPE):
     """Register a call to tf.nn.separable_conv2d().
 
-    Note: This requires access to intermediate outputs betwee depthwise and
+    Note: This requires access to intermediate outputs between depthwise and
     pointwise convolutions.
 
     Args:
@@ -824,7 +859,9 @@ class LayerCollection(object):
       rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
         kernel in spatial dimensions.
       data_format: str or None. Format of data.
-      approx: None or str. Must be "kron" if non-None.
+      approx: str or None. If not None must be one of "kron" or "diagonal".
+        The Fisher approximation to use. If None the default value is used.
+        (Default: None)
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -865,7 +902,9 @@ class LayerCollection(object):
     Args:
       params: Tensor or tuple of Tensors corresponding to the parameters.
       batch_size: 0-D Tensor. Size of the minibatch.
-      approx: str. One of "full" or "diagonal".
+      approx: str or None. It not None, must be one of "full" or "diagonal".
+        The Fisher approximation to use. If None the default value is used.
+        (Default: None)
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -875,16 +914,10 @@ class LayerCollection(object):
       KeyError: If reuse == True but no FisherBlock found for 'params'.
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_generic_approximation,
+        _GENERIC_APPROX_TO_BLOCK_TYPES)
 
-    if approx is None:
-      approx = self._get_linked_approx(params)
-      if approx is None:
-        approx = self.default_generic_approximation
-
-    if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES:
-      raise ValueError("Bad value {} for approx.".format(approx))
-
-    block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx]
     block = self.register_block(params, block_type(self, params), reuse=reuse)
     block.register_additional_minibatch(batch_size)
 
@@ -903,11 +936,15 @@ class LayerCollection(object):
         this layer. Weight matrix should have shape [input_size, output_size].
         Bias should have shape [output_size].
       inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs
-        to layer. In the case of RNNs, one Tensor per time step.
+        to layer. The list indexes each use in the graph (which might
+        correspond to a "time-step" in an RNN).
       outputs: A list of tensors, the same length as 'inputs', each of shape
-        [batch_size, output_size]. Outputs produced by layer. In the case of
-        RNNs, one Tensor per time step.
-      approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2".
+        [batch_size, output_size]. Outputs produced by layer. The list indexes
+        each use in the graph (which might correspond to a "time-step" in an
+        RNN). Needs to correspond with the order used in 'inputs'.
+      approx: str or None. If not None, must be of "kron_indep", "kron_series_1"
+        or "kron_series_2". The Fisher approximation to use. If None the default
+        value is used. (Default: None)
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -915,28 +952,129 @@ class LayerCollection(object):
     Raises:
       ValueError: For improper value to 'approx'.
     """
-    if approx is None:
-      approx = self._get_linked_approx(params)
-      if approx is None:
-        approx = self.default_fully_connected_multi_approximation
-    has_bias = isinstance(params, (tuple, list))
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_fully_connected_multi_approximation,
+        _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES)
 
     # TODO(b/70283649): something along the lines of find_canonical_output
     # should be added back in here (and for the other block types, arguably).
 
-    if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES:
-      raise ValueError("Bad value {} for approx.".format(approx))
-    block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx]
-
+    has_bias = isinstance(params, (tuple, list))
     block = self.register_block(params, block_type(self, has_bias=has_bias),
                                 reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
+
+    assert len(inputs) == len(outputs)
+    self._add_uses(params, len(inputs))
+
+  def register_conv2d_multi(self,
+                            params,
+                            strides,
+                            padding,
+                            inputs,
+                            outputs,
+                            data_format=None,
+                            dilations=None,
+                            approx=None,
+                            reuse=VARIABLE_SCOPE):
+    """Registers convolutional layers with shared parameters.
+
+    Args:
+      params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+        this layer. Weight matrix should have shape [kernel_height,
+        kernel_width, in_channels, out_channels].  Bias should have shape
+        [out_channels].
+      strides: 1-D Tensor of length 4. Strides for convolution kernel.
+      padding: string. see tf.nn.conv2d for valid values.
+      inputs: A list of Tensors, each of shape [batch_size, height, width,
+        in_channels]. Inputs to layer. The list indexes each use in the graph
+        (which might correspond to a "time-step" in an RNN).
+      outputs: A list of Tensors, each of shape [batch_size, height, width,
+        out_channels]. Output produced by layer. The list indexes each use
+        in the graph (which might correspond to a "time-step" in an RNN).
+        Needs to correspond with the order used in 'inputs'.
+      data_format: str or None. Format of data.
+      dilations: List of 4 ints. Dilations along each dimension.
+      approx: str or None. If not None must by "kron_indep". The Fisher
+        approximation to use. If None the default value is used.
+        (Default: None)
+      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock.  If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_conv2d_multi_approximation,
+        _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES)
+
+    block = self.register_block(
+        params,
+        block_type(
+            layer_collection=self,
+            params=params,
+            padding=padding,
+            strides=strides,
+            data_format=data_format,
+            dilation_rate=dilations,
+            extract_patches_fn="extract_image_patches"),
+        reuse=reuse)
+
+    block.register_additional_minibatch(inputs, outputs)
+
+    assert len(inputs) == len(outputs)
     self._add_uses(params, len(inputs))
 
   # TODO(b/74108452): change the loss registration functions names to refer
   # to "loss functions" instead of distributions.  Following naming convention
   # of the loss function classes themselves.
 
+  def register_embedding_multi(self,
+                               params,
+                               inputs,
+                               outputs,
+                               approx=None,
+                               reuse=VARIABLE_SCOPE):
+    """Registers embedding layers with shared parameters.
+
+    Args:
+      params: Embedding matrix of shape [vocab_size, embedding_size].
+      inputs: A list of Tensors, each of shape [batch_size, input_size] and
+        dtype int32. Indices into embedding matrix. The list indexes each use
+        in the graph (which might correspond to a "time-step" in an RNN).
+      outputs: A list of Tensors, each of shape [batch_size, output_size].
+        Outputs produced by layer. The list indexes each use in the graph
+        (which might correspond to a "time-step" in an RNN). Needs to
+        correspond with the order used in 'inputs'.
+      approx: str or None. If not None must by "kron_indep". The Fisher
+        approximation to use. If None the default value is used.
+        (Default: None)
+      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock.  If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_embedding_multi_approximation,
+        _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES)
+
+    if isinstance(params, (tuple, list)):
+      raise ValueError("Bias not supported.")
+    vocab_size = int(params.shape[0])
+
+    block = self.register_block(
+        params, block_type(self, vocab_size), reuse=reuse)
+    block.register_additional_minibatch(inputs, outputs)
+
+    self._add_uses(params, len(inputs))
+
   def register_categorical_predictive_distribution(self,
                                                    logits,
                                                    seed=None,
index af26f5e..c589b18 100644 (file)
@@ -659,6 +659,14 @@ class PartitionedTensor(object):
   def __hash__(self):
     return hash(tuple(self.tensors))
 
+  def __eq__(self, other):
+    if not isinstance(other, PartitionedTensor):
+      return False
+    return self.tensors == other.tensors
+
+  def __ne__(self, other):
+    return not self == other  # pylint: disable=g-comparison-negation
+
   def as_tensor(self, dtype=None, name=None, as_ref=False):
     with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
       assert not as_ref