From cf8c504688c5f5813c8772eb107ed3d4a1385888 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 3 Apr 2018 10:00:00 -0700 Subject: [PATCH] Bug Fix: If num_uses > 0, the inputs tensor need not be a list but can be reshaped to [batch_size*num_uses, input_size]. `num_uses` should be incremented by one in this case. PiperOrigin-RevId: 191456184 --- .../contrib/kfac/python/ops/layer_collection.py | 23 ++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 586a004..19608ac 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -990,9 +990,11 @@ class LayerCollection(object): num_uses=num_uses), reuse=reuse) block.register_additional_tower(inputs, outputs) - - assert len(inputs) == len(outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) def register_conv2d_multi(self, params, @@ -1066,9 +1068,11 @@ class LayerCollection(object): reuse=reuse) block.register_additional_tower(inputs, outputs) - - assert len(inputs) == len(outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) # TODO(b/74108452): change the loss registration functions names to refer # to "loss functions" instead of distributions. Following naming convention @@ -1088,7 +1092,7 @@ class LayerCollection(object): inputs: A list of Tensors, each of shape [batch_size, input_size] and dtype int32. Indices into embedding matrix. The list indexes each use in the graph (which might correspond to a "time-step" in an RNN). 
- OR, can be single Tensor, of shape [num_uses, batch_size, input_size], + OR, can be single Tensor, of shape [num_uses*batch_size, input_size], which is a reshaped version of a Tensor of shape [num_uses, batch_size, input_size]. outputs: A list of Tensors, each of shape [batch_size, embedding_size]. @@ -1129,7 +1133,10 @@ class LayerCollection(object): params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) block.register_additional_tower(inputs, outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) def register_categorical_predictive_distribution(self, logits, -- 2.7.4