num_uses=num_uses),
reuse=reuse)
block.register_additional_tower(inputs, outputs)
-
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
def register_conv2d_multi(self,
params,
reuse=reuse)
block.register_additional_tower(inputs, outputs)
-
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
# TODO(b/74108452): change the loss registration functions names to refer
# to "loss functions" instead of distributions. Following naming convention
inputs: A list of Tensors, each of shape [batch_size, input_size] and
dtype int32. Indices into embedding matrix. The list indexes each use
in the graph (which might correspond to a "time-step" in an RNN).
- OR, can be single Tensor, of shape [num_uses, batch_size, input_size],
+ OR, can be single Tensor, of shape [num_uses*batch_size, input_size],
which is a reshaped version of a Tensor of shape [num_uses, batch_size,
input_size].
outputs: A list of Tensors, each of shape [batch_size, embedding_size].
params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
block.register_additional_tower(inputs, outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
def register_categorical_predictive_distribution(self,
logits,