- Added support for data to be specified in RNN classes as large tensors with time...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Mar 2018 10:11:32 +0000 (03:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 10:15:36 +0000 (03:15 -0700)
- Significant refactoring of RNN classes
- Fixed a bunch of issues in the LayerCollection docstrings, especially around the 'reuse' argument.

PiperOrigin-RevId: 189716331

tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
tensorflow/contrib/kfac/python/ops/fisher_blocks.py
tensorflow/contrib/kfac/python/ops/fisher_factors.py
tensorflow/contrib/kfac/python/ops/layer_collection.py
tensorflow/contrib/kfac/python/ops/utils.py

index 16f02f1..e007f70 100644 (file)
@@ -862,8 +862,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
     with tf_ops.Graph().as_default():
       random_seed.set_random_seed(200)
       tensor = array_ops.ones((2, 3), name='a/b/c')
-      tensor_list = [tensor]
-      factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
+      factor = ff.FullyConnectedMultiKF((tensor,), has_bias=False)
       factor.instantiate_cov_variables()
       self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
 
@@ -872,8 +871,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
       dtype = dtypes.float64_ref
       random_seed.set_random_seed(200)
       tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
-      tensor_list = [tensor]
-      factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
+      factor = ff.FullyConnectedMultiKF((tensor,), has_bias=False)
       factor.instantiate_cov_variables()
       cov = factor.get_cov()
       self.assertEqual(cov.dtype, dtype)
@@ -883,8 +881,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
-      tensor_list = [tensor]
-      factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True)
+      factor = ff.FullyConnectedMultiKF((tensor,), has_bias=True)
       factor.instantiate_cov_variables()
 
       sess.run(tf_variables.global_variables_initializer())
@@ -895,8 +892,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
-      tensor_list = [tensor]
-      factor = ff.FullyConnectedMultiKF((tensor_list,))
+      factor = ff.FullyConnectedMultiKF((tensor,))
       factor.instantiate_cov_variables()
 
       sess.run(tf_variables.global_variables_initializer())
index 79d0424..f517e31 100644 (file)
@@ -106,55 +106,6 @@ def _make_partitionedtensors_grads(grads_list):
   return tuple(utils.PartitionedTensor(grads) for grads in grads_list)
 
 
-def _make_partitionedtensors_multi_inputs(inputs):
-  """Constructs PartitionedTensors for inputs.
-
-  The purpose of this method is to package up the towers/minibatch dimension
-  of these arrays into PartitionedTensor objects.
-
-  This version of this function is for use with FisherBlocks that deal with
-  multiple uses or time-steps. One PartitionedTensor is created for each
-  use/time-step.  The FisherBlock will be responsible for concatenating
-  (or doing whatever else it wants) with the resulting lists.
-
-  Args:
-    inputs: a 2-D list of Tensors. First index is tower/mini-batch, second is
-      use/time-step.
-
-  Returns:
-    A tuple of PartitionedTensor's, one per use/time-step.
-  """
-  num_uses = len(inputs[0])
-  assert all(len(input_) == num_uses for input_ in inputs)
-
-  return tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
-
-
-def _make_partitionedtensors_multi_grads(grads_list):
-  """Constructs PartitionedTensors for grads_list.
-
-  The purpose of this method is to package up the towers/minibatch dimension
-  of these arrays into PartitionedTensor objects.
-
-  This version of this function is for use with FisherBlocks that deal with
-  multiple uses or time-steps. One PartitionedTensor is created for each
-  use/time-step.  The FisherBlock will be responsible for concatenating
-  (or doing whatever else it wants) with the resulting lists.
-
-  Args:
-    grads_list: 3-D list of Tensors. First index is for source, second is for
-      tower, third is for use/time-step.
-
-  Returns:
-    2-D tuple of PartitionedTensors. First index is for source, second is for
-    use/time-step.
-  """
-  num_uses = len(grads_list[0][0])
-  assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
-  return tuple(tuple(utils.PartitionedTensor(grad)
-                     for grad in zip(*grads)) for grads in grads_list)
-
-
 def normalize_damping(damping, num_replications):
   """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
   if NORMALIZE_DAMPING_POWER:
@@ -662,7 +613,7 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
 
 
 class KroneckerProductFB(FisherBlock):
-  """A base class for FisherBlocks with separate input and output factors.
+  """A base class for blocks with separate input and output Kronecker factors.
 
   The Fisher block is approximated as a Kronecker product of the input and
   output factors.
@@ -783,67 +734,6 @@ class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
     self._setup_damping(damping)
 
 
-class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
-  """K-FAC FisherBlock for embedding layers used multiple times in the graph.
-
-  Similar to EmbeddingKFACFB except that this version supports multiple uses
-  of the parameter within a single model. These uses could correspond to
-  "time-steps", but they don't have to.
-
-  Does not support bias parameters.
-  """
-
-  def __init__(self, layer_collection, vocab_size):
-    """Creates a EmbeddingKFACMultiIndepFB block.
-
-    Args:
-      layer_collection: The collection of all layers in the K-FAC approximate
-          Fisher information matrix to which this FisherBlock belongs.
-      vocab_size: int. Size of vocabulary for this embedding layer.
-    """
-    self._vocab_size = vocab_size
-
-    super(EmbeddingKFACMultiIndepFB, self).__init__(layer_collection)
-
-  def instantiate_factors(self, grads_list, damping):
-    """Instantiate Kronecker Factors for this FisherBlock.
-
-    Args:
-      grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
-        gradient of the loss with respect to 'outputs' from source 'i',
-        tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
-        [tower_minibatch_size, output_size].
-      damping: 0-D Tensor or float. 'damping' * identity is approximately added
-        to this FisherBlock's Fisher approximation.
-    """
-    inputs = self._inputs
-    self._num_uses = num_uses = len(inputs[0])
-
-    # Check that all mini-batches/towers have the same number of uses
-    assert all(len(input_) == num_uses for input_ in inputs)
-    # Do the same for grads_list
-    assert all(len(grad) == num_uses for grad in grads for grads in grads_list)
-    # Merge uses and towers/minibatches dimensions together so we can handle
-    # it using a non-multi factor.
-    inputs = nest.flatten(inputs)
-
-    # Note that we call the multi version of make_partitionedtensors only for
-    # grads_list here.
-    inputs = _make_partitionedtensors_inputs(inputs)
-    grads_list = _make_partitionedtensors_multi_grads(grads_list)
-
-    self._input_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.EmbeddingInputKroneckerFactor,
-        (inputs, self._vocab_size))
-    self._output_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedMultiKF, (grads_list,))
-    self._setup_damping(damping, normalization=num_uses)
-
-  @property
-  def _renorm_coeff(self):
-    return self._num_uses
-
-
 class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
   """K-FAC FisherBlock for fully-connected (dense) layers.
 
@@ -1232,7 +1122,70 @@ def num_conv_locations(input_shape, strides):
   return spatial_input_locations // spatial_strides_divisor
 
 
-class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+class InputOutputMultiMinibatchMultiUse(InputOutputMultiMinibatch):
+  """Adds methods for multi-use/time-step case to InputOutputMultiMinibatch."""
+
+  def __init__(self, num_uses=None, *args, **kwargs):
+    self._num_uses = num_uses
+    super(InputOutputMultiMinibatchMultiUse, self).__init__(*args, **kwargs)
+
+  def _process_data(self, grads_list):
+    """Process temporal/multi-use data into a standard format."""
+
+    inputs = self._inputs
+
+    # The first possible data format is where inputs is a list of tensors,
+    # one for each use/time-step.
+    if isinstance(inputs[0], (list, tuple)):
+      # The first index is tower/minibatch, the second is use/time-step
+      num_uses = len(inputs[0])
+      if self._num_uses is not None and self._num_uses != num_uses:
+        raise ValueError("num_uses argument doesn't match length of inputs.")
+      else:
+        self._num_uses = num_uses
+
+      # Check that all mini-batches/towers have the same number of uses
+      if not all(len(input_) == num_uses for input_ in inputs):
+        raise ValueError("Length of inputs argument is inconsistent across "
+                         "mini-batches/towers.")
+      # Fold uses/time-step and towers/minibatches dimensions together
+      inputs = nest.flatten(inputs)
+
+      inputs = _make_partitionedtensors_inputs(inputs)
+    # If inputs is not a tuple then we assume that inputs is a tensor
+    # with 'uses' folded into the batch dimension. (And grads_list is a list
+    # across sources of such Tensors.)  This is the native format that the
+    # factor will take as arguments.
+
+    # Now we perform the analogous processing for grads_list
+    if isinstance(grads_list[0][0], (list, tuple)):
+      num_uses = len(grads_list[0][0])
+      if self._num_uses is not None and self._num_uses != num_uses:
+        raise ValueError("num_uses argument doesn't match length of outputs, "
+                         "or length of outputs is inconsistent with length of "
+                         "inputs.")
+      else:
+        self._num_uses = num_uses
+
+      if not all(len(grad) == num_uses for grads in grads_list
+                 for grad in grads):
+        raise ValueError("Length of outputs argument is inconsistent across "
+                         "mini-batches/towers.")
+
+      grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+      grads_list = _make_partitionedtensors_grads(grads_list)
+
+    if self._num_uses is None:
+      raise ValueError("You must supply a value for the num_uses argument if "
+                       "the number of uses cannot be inferred from inputs or "
+                       "outputs arguments (e.g. if they are both given in the "
+                       "single Tensor format, instead of as lists of Tensors.")
+
+    return inputs, grads_list
+
+
+class FullyConnectedMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+                                 KroneckerProductFB):
   """FisherBlock for fully-connected layers that share parameters.
 
   This class implements the "independence across time" approximation from the
@@ -1240,42 +1193,43 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
     https://openreview.net/pdf?id=HyMTkQZAb
   """
 
-  def __init__(self, layer_collection, has_bias=False):
+  def __init__(self, layer_collection, has_bias=False, num_uses=None):
     """Creates a FullyConnectedMultiIndepFB block.
 
     Args:
       layer_collection: LayerCollection instance.
       has_bias: bool. If True, estimates Fisher with respect to a bias
         parameter as well as the layer's parameters.
+      num_uses: int or None. Number of uses of the layer in the model's graph.
+        Only required if the data is formatted with uses/time folded into the
+        batch dimension (instead of uses/time being a list dimension).
+        (Default: None)
     """
     self._has_bias = has_bias
 
-    super(FullyConnectedMultiIndepFB, self).__init__(layer_collection)
+    super(FullyConnectedMultiIndepFB, self).__init__(
+        layer_collection=layer_collection,
+        num_uses=num_uses)
 
   def instantiate_factors(self, grads_list, damping):
-
-    self._num_uses = float(len(self._inputs[0]))
-    inputs = _make_partitionedtensors_multi_inputs(self._inputs)
-    grads_list = _make_partitionedtensors_multi_grads(grads_list)
+    inputs, grads_list = self._process_data(grads_list)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF,
-        ((inputs,), self._has_bias))
+        ((inputs,), self._num_uses, self._has_bias))
 
     self._output_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedMultiKF, (grads_list,))
+        fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
 
     self._setup_damping(damping, normalization=self._num_uses)
 
   @property
   def _renorm_coeff(self):
-    return self._num_uses
-
-  def tensors_to_compute_grads(self):
-    return self._outputs
+    return float(self._num_uses)
 
 
-class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+                               KroneckerProductFB):
   """FisherBlock for 2D convolutional layers using the basic KFC approx.
 
   Similar to ConvKFCBasicFB except that this version supports multiple
@@ -1291,7 +1245,8 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
                strides=None,
                dilation_rate=None,
                data_format=None,
-               extract_patches_fn=None):
+               extract_patches_fn=None,
+               num_uses=None):
     """Creates a ConvKFCBasicMultiIndepFB block.
 
     Args:
@@ -1312,6 +1267,10 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
       extract_patches_fn: str or None. Name of function that extracts image
         patches. One of "extract_convolution_patches", "extract_image_patches",
         "extract_pointwise_conv2d_patches".
+      num_uses: int or None. Number of uses of the layer in the model's graph.
+        Only required if the data is formatted with uses/time folded into the
+        batch dimension (instead of uses/time being a list dimension).
+        (Default: None)
     """
     self._padding = padding
     self._strides = maybe_tuple(strides)
@@ -1323,28 +1282,16 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
     fltr = params[0] if self._has_bias else params
     self._filter_shape = tuple(fltr.shape.as_list())
 
-    super(ConvKFCBasicMultiIndepFB, self).__init__(layer_collection)
+    super(ConvKFCBasicMultiIndepFB, self).__init__(
+        layer_collection=layer_collection,
+        num_uses=num_uses)
 
   def instantiate_factors(self, grads_list, damping):
-    # Infer number of locations upon which convolution is applied.
-    self._num_locations = num_locations = num_conv_locations(
-        self._inputs[0][0].shape.as_list(), self._strides)
-
-    # The first index is tower/minibatch, the second is use/time-step
-    inputs = self._inputs
-    self._num_uses = num_uses = len(inputs[0])
-
-    # Check that all mini-batches/towers have the same number of uses
-    assert all(len(input_) == num_uses for input_ in inputs)
-    assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
-
-    # Fold uses/time-step and towers/minibatches dimensions together
-    inputs = nest.flatten(inputs)
-    # And do the same for grads_list
-    grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+    inputs, grads_list = self._process_data(grads_list)
 
-    inputs = _make_partitionedtensors_inputs(inputs)
-    grads_list = _make_partitionedtensors_grads(grads_list)
+    # Infer number of locations upon which convolution is applied.
+    self._num_locations = num_conv_locations(inputs.shape.as_list(),
+                                             self._strides)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvInputKroneckerFactor,
@@ -1354,20 +1301,75 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
 
-    self._setup_damping(damping, normalization=(num_locations * num_uses))
+    self._setup_damping(damping, normalization=
+                        (self._num_locations * self._num_uses))
 
   @property
   def _renorm_coeff(self):
     return self._num_locations * self._num_uses
 
 
+class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+                                KroneckerProductFB):
+  """K-FAC FisherBlock for embedding layers used multiple times in the graph.
+
+  Similar to EmbeddingKFACFB except that this version supports multiple uses
+  of the parameter within a single model. These uses could correspond to time
+  steps in an RNN architecture, but they don't have to.
+
+  Does not support bias parameters.
+  """
+
+  def __init__(self, layer_collection, vocab_size, num_uses):
+    """Creates a EmbeddingKFACMultiIndepFB block.
+
+    Args:
+      layer_collection: The collection of all layers in the K-FAC approximate
+          Fisher information matrix to which this FisherBlock belongs.
+      vocab_size: int. Size of vocabulary for this embedding layer.
+      num_uses: int or None. Number of uses of the layer in the model's graph.
+        Only required if the data is formatted with time folded into the batch
+        dimension (instead of time being a list dimension). (Default: None)
+    """
+    self._vocab_size = vocab_size
+
+    super(EmbeddingKFACMultiIndepFB, self).__init__(
+        layer_collection=layer_collection,
+        num_uses=num_uses)
+
+  def instantiate_factors(self, grads_list, damping):
+    """Instantiate Kronecker Factors for this FisherBlock.
+
+    Args:
+      grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
+        gradient of the loss with respect to 'outputs' from source 'i',
+        tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
+        [tower_minibatch_size, output_size].
+      damping: 0-D Tensor or float. 'damping' * identity is approximately added
+        to this FisherBlock's Fisher approximation.
+    """
+    inputs, grads_list = self._process_data(grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.EmbeddingInputKroneckerFactor,
+        (inputs, self._vocab_size))
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
+    self._setup_damping(damping, normalization=self._num_uses)
+
+  @property
+  def _renorm_coeff(self):
+    return float(self._num_uses)
+
+
 class SeriesFBApproximation(enum.IntEnum):
   """See FullyConnectedSeriesFB.__init__ for description and usage."""
   option1 = 1
   option2 = 2
 
 
-class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
+class FullyConnectedSeriesFB(InputOutputMultiMinibatchMultiUse,
+                             KroneckerProductFB):
   """FisherBlock for fully-connected layers that share parameters across time.
 
   This class implements the "Option 1" and "Option 2" approximation from the
@@ -1383,6 +1385,7 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
   def __init__(self,
                layer_collection,
                has_bias=False,
+               num_uses=None,
                option=SeriesFBApproximation.option2):
     """Constructs a new `FullyConnectedSeriesFB`.
 
@@ -1390,6 +1393,10 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
       layer_collection: The collection of all layers in the K-FAC approximate
         Fisher information matrix to which this FisherBlock belongs.
       has_bias: Whether the layer includes a bias parameter.
+      num_uses: int or None. Number of time-steps over which the layer
+        is used. Only required if the data is formatted with time folded into
+        the batch dimension (instead of time being a list dimension).
+        (Default: None)
       option: A `SeriesFBApproximation` specifying the simplifying assumption
         to be used in this block. `option1` approximates the cross-covariance
         over time as a symmetric matrix, while `option2` makes
@@ -1400,39 +1407,33 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
     self._has_bias = has_bias
     self._option = option
 
-    super(FullyConnectedSeriesFB, self).__init__(layer_collection)
+    super(FullyConnectedSeriesFB, self).__init__(
+        layer_collection=layer_collection,
+        num_uses=num_uses)
 
-  def instantiate_factors(self, grads_list, damping):
+  @property
+  def _num_timesteps(self):
+    return self._num_uses
 
-    self._num_timesteps = len(self._inputs[0])
-    assert len(grads_list[0][0]) == self._num_timesteps
+  @property
+  def _renorm_coeff(self):
+    # This should no longer be used since the multiply_X functions from the base
+    # class have been overridden
+    assert False
 
-    inputs = _make_partitionedtensors_multi_inputs(self._inputs)
-    grads_list = _make_partitionedtensors_multi_grads(grads_list)
+  def instantiate_factors(self, grads_list, damping):
+    inputs, grads_list = self._process_data(grads_list)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias))
+        fisher_factors.FullyConnectedMultiKF,
+        ((inputs,), self._num_uses, self._has_bias))
     self._input_factor.register_cov_dt1()
 
     self._output_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedMultiKF, (grads_list,))
+        fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
     self._output_factor.register_cov_dt1()
 
-    def compute_damping():
-      normalized_damping = normalize_damping(damping, self._num_timesteps)
-      return compute_pi_adjusted_damping(self._input_factor.get_cov(),
-                                         self._output_factor.get_cov(),
-                                         normalized_damping**0.5)
-
-    damping_id = ("compute_pi_adjusted_damping",
-                  "cov", self._input_factor.name,
-                  "cov", self._output_factor.name,
-                  "normalize_damping",
-                  damping, self._num_timesteps, "power", 0.5)
-    self._input_damping_func = _package_func(lambda: compute_damping()[0],
-                                             damping_id + ("ref", 0))
-    self._output_damping_func = _package_func(lambda: compute_damping()[1],
-                                              damping_id + ("ref", 1))
+    self._setup_damping(damping, normalization=self._num_uses)
 
   def register_matpower(self, exp):
     if exp != -1:
@@ -1562,6 +1563,3 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
     return utils.mat2d_to_layer_params(vector, Z)
 
     # pylint: enable=invalid-name
-
-  def tensors_to_compute_grads(self):
-    return self._outputs
index 6fc163e..f521363 100644 (file)
@@ -35,7 +35,6 @@ from tensorflow.python.ops import special_math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import moving_averages
-from tensorflow.python.util import nest
 
 # Whether to initialize covariance estimators at a zero matrix (or the identity
 # matrix).
@@ -1227,27 +1226,24 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
     return compute_cov(reshaped_tensor)
 
 
-class FullyConnectedMultiKF(InverseProvidingFactor):
+class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
   """Kronecker factor for a fully connected layer used multiple times."""
 
   def __init__(self,
-               tensor_lists,
+               tensors,
+               num_uses=None,
                has_bias=False):
     """Constructs a new `FullyConnectedMultiKF`.
 
     Args:
-      tensor_lists: 2D array (list of lists) of Tensors of shape
-        [batch_size, n]. Each of these tensors is usually a layer's inputs or
-        its output's gradients. The first dimension of the array is the source,
-        and the second is the use in the graph (which is sometimes a
-        "time-step").
+      tensors: List of Tensors of shape, each of shape [batch_size, n]. Each of
+        these tensors is usually a layer's inputs or its output's gradients.
+        The list is over sources.
+      num_uses: int. The number of time-steps / uses.
       has_bias: bool. If True, '1' is appended to each row.
     """
 
-    self._tensor_lists = tensor_lists
-    self._has_bias = has_bias
-    self._num_timesteps = len(tensor_lists[0])
-    self._tensors = [None] * len(tensor_lists)
+    self._num_uses = num_uses
 
     self._cov_dt1 = None
     self._make_cov_dt1 = False
@@ -1256,20 +1252,17 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
     self._option1quants_registrations = set()
     self._option2quants_registrations = set()
 
-    super(FullyConnectedMultiKF, self).__init__()
-
-  @property
-  def _var_scope(self):
-    return "ff_fc_multi_" + scope_string_from_params(
-        tuple(nest.flatten(self._tensor_lists)) + (self._has_bias,))
+    super(FullyConnectedMultiKF, self).__init__(tensors=tensors,
+                                                has_bias=has_bias)
 
   @property
-  def _num_sources(self):
-    return len(self._tensor_lists)
+  def _num_timesteps(self):
+    return self._num_uses
 
   @property
-  def _dtype(self):
-    return self._tensor_lists[0][0].dtype
+  def _var_scope(self):
+    return "ff_fc_multi_" + scope_string_from_params(
+        tuple(self._tensors) + (self._num_timesteps, self._has_bias,))
 
   def make_covariance_update_op(self, ema_decay):
 
@@ -1291,36 +1284,28 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
 
     return op
 
-  def _compute_new_cov(self, idx=0):
-    # Concatenate across time/replications
-    tensor = array_ops.concat(self._tensor_lists[idx], 0)
+  def _compute_new_cov_dt1(self, idx=0):  # pylint: disable=missing-docstring
+    tensor = self._tensors[idx]
     if self._has_bias:
+      # This appending is technically done twice (the other time is for
+      # _compute_new_cov())
       tensor = append_homog(tensor)
-    # We save these so they can be used by _compute_new_cov_dt1
-    self._tensors[idx] = tensor
-    return compute_cov(tensor)
 
-  def _compute_new_cov_dt1(self, idx=0):  # pylint: disable=missing-docstring
-    tensor = self._tensors[idx]
-    batch_size = array_ops.shape(self._tensor_lists[idx][0])[0]
-    # Is there a more elegant way to do this computation?
+    total_len = array_ops.shape(tensor)[0]
+    batch_size = total_len // self._num_timesteps
+
     tensor_present = tensor[:-batch_size, :]
     tensor_future = tensor[batch_size:, :]
+
     # We specify a normalizer for this computation to ensure a PSD Fisher
     # block estimate.  This is equivalent to padding with zeros, as was done
     # in Section B.2 of the appendix.
-    normalizer = self._num_timesteps * batch_size
     return compute_cov(
-        tensor_future, tensor_right=tensor_present, normalizer=normalizer)
-
-  @property
-  def _cov_shape(self):
-    size = self._tensor_lists[0][0].shape[1] + self._has_bias
-    return [size, size]
+        tensor_future, tensor_right=tensor_present, normalizer=total_len)
 
   @property
   def _vec_shape(self):
-    size = self._tensor_lists[0][0].shape[1] + self._has_bias
+    size = self._tensors[0].shape[1] + self._has_bias
     return [size]
 
   def get_option1quants(self, damping_func):
index 00eae8b..7727c60 100644 (file)
@@ -572,13 +572,15 @@ class LayerCollection(object):
       params: Embedding matrix of shape [vocab_size, embedding_size].
       inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices
         into embedding matrix.
-      outputs: Tensor of shape [batch_size, output_size]. Outputs
+      outputs: Tensor of shape [batch_size, embedding_size]. Outputs
         produced by layer.
       approx: str or None. If not None must be "kron".  The Fisher
         approximation to use. If None the default value is used. (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -616,9 +618,11 @@ class LayerCollection(object):
       approx: str or None. If not None must be one of "kron" or "diagonal".
         The Fisher approximation to use. If None the default value is used.
         (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -665,9 +669,11 @@ class LayerCollection(object):
       approx: str or None. If not None must be one of "kron" or "diagonal".
         The Fisher approximation to use. If None the default value is used.
         (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -743,9 +749,11 @@ class LayerCollection(object):
       approx: str or None. If not None must be one of "kron" or "diagonal".
         The Fisher approximation to use. If None the default value is used.
         (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -796,9 +804,11 @@ class LayerCollection(object):
       data_format: str or None. Format of data.
       approx: str or None. If not None must "diagonal".  The Fisher
         approximation to use. If None the default value is used. (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -862,9 +872,11 @@ class LayerCollection(object):
       approx: str or None. If not None must be one of "kron" or "diagonal".
         The Fisher approximation to use. If None the default value is used.
         (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -905,9 +917,10 @@ class LayerCollection(object):
       approx: str or None. It not None, must be one of "full" or "diagonal".
         The Fisher approximation to use. If None the default value is used.
         (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str. If True, this adds 'batch_size' to the total
+        mini-batch size use when estimating the Fisher block for this layer
+        (which must have already been registered). If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -924,7 +937,8 @@ class LayerCollection(object):
     self._add_uses(params, float("inf"))
 
   def register_fully_connected_multi(self, params, inputs, outputs,
-                                     approx=None, reuse=VARIABLE_SCOPE):
+                                     num_uses=None, approx=None,
+                                     reuse=VARIABLE_SCOPE):
     """Register fully connected layers with shared parameters.
 
     This can handle general fully-connected layers with shared parameters, but
@@ -935,19 +949,31 @@ class LayerCollection(object):
       params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
         this layer. Weight matrix should have shape [input_size, output_size].
         Bias should have shape [output_size].
-      inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs
+      inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs
         to layer. The list indexes each use in the graph (which might
-        correspond to a "time-step" in an RNN).
-      outputs: A list of tensors, the same length as 'inputs', each of shape
+        correspond to a "time-step" in an RNN). OR, can be single Tensor, of
+        shape [batch_size * num_uses, input_size], which is a reshaped version
+        of a Tensor of shape [batch_size, num_uses, input_size].
+      outputs: A list of Tensors, the same length as 'inputs', each of shape
         [batch_size, output_size]. Outputs produced by layer. The list indexes
         each use in the graph (which might correspond to a "time-step" in an
-        RNN). Needs to correspond with the order used in 'inputs'.
+        RNN). Needs to correspond with the order used in 'inputs'.  OR, can be
+        a single Tensor of shape [batch_size * num_uses, output_size], which is
+        a reshaped version of a Tensor of shape [batch_size, num_uses,
+        output_size].
+      num_uses: int or None. The number uses/time-steps in the graph where the
+        layer appears. Only needed if both inputs and outputs are given in the
+        single Tensor format. (Default: None)
       approx: str or None. If not None, must be of "kron_indep", "kron_series_1"
         or "kron_series_2". The Fisher approximation to use. If None the default
         value is used. (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds inputs and outputs as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.  (Note that the
+        word 'use' here has a completely different meaning to "use in the graph"
+        as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.)
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -960,7 +986,8 @@ class LayerCollection(object):
     # should be added back in here (and for the other block types, arguably).
 
     has_bias = isinstance(params, (tuple, list))
-    block = self.register_block(params, block_type(self, has_bias=has_bias),
+    block = self.register_block(params, block_type(self, has_bias=has_bias,
+                                                   num_uses=num_uses),
                                 reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
 
@@ -973,6 +1000,7 @@ class LayerCollection(object):
                             padding,
                             inputs,
                             outputs,
+                            num_uses=None,
                             data_format=None,
                             dilations=None,
                             approx=None,
@@ -988,19 +1016,32 @@ class LayerCollection(object):
       padding: string. see tf.nn.conv2d for valid values.
       inputs: A list of Tensors, each of shape [batch_size, height, width,
         in_channels]. Inputs to layer. The list indexes each use in the graph
-        (which might correspond to a "time-step" in an RNN).
+        (which might correspond to a "time-step" in an RNN). OR, can be single
+        Tensor, of shape [batch_size * num_uses, height, width, in_channels],
+        which is a reshaped version of a Tensor of shape [batch_size, num_uses,
+        height, width, in_channels].
       outputs: A list of Tensors, each of shape [batch_size, height, width,
         out_channels]. Output produced by layer. The list indexes each use
         in the graph (which might correspond to a "time-step" in an RNN).
-        Needs to correspond with the order used in 'inputs'.
+        Needs to correspond with the order used in 'inputs'.  OR, can be a
+        single Tensor, of shape [batch_size*num_uses, height, width,
+        out_channels], which is a reshaped version of a Tensor of shape
+        [batch_size, num_uses, height, width, out_channels].
+      num_uses: int or None. The number uses/time-steps in the graph where the
+        layer appears. Only needed if both inputs and outputs are given in the
+        single Tensor format. (Default: None)
       data_format: str or None. Format of data.
       dilations: List of 4 ints. Dilations along each dimension.
       approx: str or None. If not None must by "kron_indep". The Fisher
         approximation to use. If None the default value is used.
         (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds inputs and outputs as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.  (Note that the
+        word 'use' here has a completely different meaning to "use in the graph"
+        as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.)
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -1020,7 +1061,8 @@ class LayerCollection(object):
             strides=strides,
             data_format=data_format,
             dilation_rate=dilations,
-            extract_patches_fn="extract_image_patches"),
+            extract_patches_fn="extract_image_patches",
+            num_uses=num_uses),
         reuse=reuse)
 
     block.register_additional_minibatch(inputs, outputs)
@@ -1036,6 +1078,7 @@ class LayerCollection(object):
                                params,
                                inputs,
                                outputs,
+                               num_uses=None,
                                approx=None,
                                reuse=VARIABLE_SCOPE):
     """Registers embedding layers with shared parameters.
@@ -1045,16 +1088,29 @@ class LayerCollection(object):
       inputs: A list of Tensors, each of shape [batch_size, input_size] and
         dtype int32. Indices into embedding matrix. The list indexes each use
         in the graph (which might correspond to a "time-step" in an RNN).
-      outputs: A list of Tensors, each of shape [batch_size, output_size].
+        OR, can be single Tensor, of shape [batch_size * num_uses, input_size],
+        which is a reshaped version of a Tensor of shape [batch_size, num_uses,
+        input_size].
+      outputs: A list of Tensors, each of shape [batch_size, embedding_size].
         Outputs produced by layer. The list indexes each use in the graph
         (which might correspond to a "time-step" in an RNN). Needs to
-        correspond with the order used in 'inputs'.
+        correspond with the order used in 'inputs'. OR, can be a
+        single Tensor, of shape [batch_size*num_uses, embedding_size], which
+        is a reshaped version of a Tensor of shape [batch_size, num_uses,
+        embedding_size].
+      num_uses: int or None. The number uses/time-steps in the graph where the
+        layer appears. Only needed if both inputs and outputs are given in the
+        single Tensor format. (Default: None)
       approx: str or None. If not None must by "kron_indep". The Fisher
         approximation to use. If None the default value is used.
         (Default: None)
-      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
-        create a new FisherBlock.  If "VARIABLE_SCOPE", use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds inputs and outputs as an
+        additional mini-batch/tower of data to use when estimating the Fisher
+        block for this layer (which must have already been registered). If
+        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.  (Note that the
+        word 'use' here has a completely different meaning to "use in the graph"
+        as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.)
+        (Default: "VARIABLE_SCOPE")
 
     Raises:
       ValueError: For improper value to 'approx'.
@@ -1070,7 +1126,7 @@ class LayerCollection(object):
     vocab_size = int(params.shape[0])
 
     block = self.register_block(
-        params, block_type(self, vocab_size), reuse=reuse)
+        params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
 
     self._add_uses(params, len(inputs))
@@ -1093,9 +1149,10 @@ class LayerCollection(object):
         (Default: None)
       name: (OPTIONAL) str or None. Unique name for this loss function. If None,
         a new name is generated. (Default: None)
-      reuse: (OPTIONAL) bool or str.  If True, reuse an existing FisherBlock.
-        If False, create a new FisherBlock.  If VARIABLE_SCOPE, use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'logits' as an additional
+        mini-batch/tower of inputs to the loss-function/predictive distribution
+        (which must have already been registered). If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
     """
     loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
                                                    seed=seed)
@@ -1126,9 +1183,10 @@ class LayerCollection(object):
         (Default: None)
       name: (OPTIONAL) str or None. Unique name for this loss function. If None,
         a new name is generated. (Default: None)
-      reuse: (OPTIONAL) bool or str.  If True, reuse an existing FisherBlock.
-        If False, create a new FisherBlock.  If VARIABLE_SCOPE, use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'mean' and 'var' as an additional
+        mini-batch/tower of inputs to the loss-function/predictive distribution
+        (which must have already been registered). If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
     """
     loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
                                             seed=seed)
@@ -1154,9 +1212,10 @@ class LayerCollection(object):
         (Default: None)
       name: (OPTIONAL) str or None. Unique name for this loss function. If None,
         a new name is generated. (Default: None)
-      reuse: (OPTIONAL) bool or str.  If True, reuse an existing FisherBlock.
-        If False, create a new FisherBlock.  If VARIABLE_SCOPE, use
-        tf.get_variable_scope().reuse.
+      reuse: bool or str.  If True, this adds 'logits' as an additional
+        mini-batch/tower of inputs to the loss-function/predictive distribution
+        (which must have already been registered). If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
     """
     loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
                                                 seed=seed)
index c589b18..c9de0c7 100644 (file)
@@ -667,6 +667,9 @@ class PartitionedTensor(object):
   def __ne__(self, other):
     return not self == other  # pylint: disable=g-comparison-negation
 
+  def __getitem__(self, key):
+    return self.as_tensor()[key]
+
   def as_tensor(self, dtype=None, name=None, as_ref=False):
     with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
       assert not as_ref