Bug Fix: If num_uses > 0 the the inputs tensor need not be a list but can be reshaped to
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Apr 2018 17:00:00 +0000 (10:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 3 Apr 2018 17:03:06 +0000 (10:03 -0700)
[batch_size*num_uses, input_size]. `num_uses` should be incremented by one in this case.'

PiperOrigin-RevId: 191456184

tensorflow/contrib/kfac/python/ops/layer_collection.py

index 586a004..19608ac 100644 (file)
@@ -990,9 +990,11 @@ class LayerCollection(object):
                                                    num_uses=num_uses),
                                 reuse=reuse)
     block.register_additional_tower(inputs, outputs)
-
-    assert len(inputs) == len(outputs)
-    self._add_uses(params, len(inputs))
+    if isinstance(inputs, (tuple, list)):
+      assert len(inputs) == len(outputs)
+      self._add_uses(params, len(inputs))
+    else:
+      self._add_uses(params, 1)
 
   def register_conv2d_multi(self,
                             params,
@@ -1066,9 +1068,11 @@ class LayerCollection(object):
         reuse=reuse)
 
     block.register_additional_tower(inputs, outputs)
-
-    assert len(inputs) == len(outputs)
-    self._add_uses(params, len(inputs))
+    if isinstance(inputs, (tuple, list)):
+      assert len(inputs) == len(outputs)
+      self._add_uses(params, len(inputs))
+    else:
+      self._add_uses(params, 1)
 
   # TODO(b/74108452): change the loss registration functions names to refer
   # to "loss functions" instead of distributions.  Following naming convention
@@ -1088,7 +1092,7 @@ class LayerCollection(object):
       inputs: A list of Tensors, each of shape [batch_size, input_size] and
         dtype int32. Indices into embedding matrix. The list indexes each use
         in the graph (which might correspond to a "time-step" in an RNN).
-        OR, can be single Tensor, of shape [num_usesbatch_size, input_size],
+        OR, can be single Tensor, of shape [num_uses*batch_size, input_size],
         which is a reshaped version of a Tensor of shape [num_uses, batch_size,
         input_size].
       outputs: A list of Tensors, each of shape [batch_size, embedding_size].
@@ -1129,7 +1133,10 @@ class LayerCollection(object):
         params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
     block.register_additional_tower(inputs, outputs)
 
-    self._add_uses(params, len(inputs))
+    if isinstance(inputs, (tuple, list)):
+      self._add_uses(params, len(inputs))
+    else:
+      self._add_uses(params, 1)
 
   def register_categorical_predictive_distribution(self,
                                                    logits,