self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testAttentionCellWrapperFailures(self):
- with self.assertRaisesRegexp(TypeError,
- "The parameter cell is not RNNCell."):
+ with self.assertRaisesRegexp(
+ TypeError, rnn_cell_impl.ASSERT_LIKE_RNNCELL_ERROR_REGEXP):
contrib_rnn_cell.AttentionCellWrapper(None, 0)
num_units = 8
h1 = array_ops.zeros([1, 2])
state1 = rnn_cell.LSTMStateTuple(c1, h1)
state = (state0, state1)
- single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
+ single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False) # pylint: disable=line-too-long
cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
self.assertAllClose(expected_state1_h, actual_state1_h, 1e-5)
with variable_scope.variable_scope(
- "other", initializer=init_ops.constant_initializer(0.5)) as vs:
+ "other", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros(
[1, 3]) # Test BasicLSTMCell with input_size != num_units.
c = array_ops.zeros([1, 2])
# pylint: disable=protected-access,invalid-name
RNNCell = rnn_cell_impl.RNNCell
-_like_rnncell = rnn_cell_impl._like_rnncell
_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
# pylint: enable=protected-access,invalid-name
ValueError: if embedding_classes is not positive.
"""
super(EmbeddingWrapper, self).__init__(_reuse=reuse)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if embedding_classes <= 0 or embedding_size <= 0:
raise ValueError("Both embedding_classes and embedding_size must be > 0: "
"%d, %d." % (embedding_classes, embedding_size))
super(InputProjectionWrapper, self).__init__(_reuse=reuse)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
self._cell = cell
self._num_proj = num_proj
self._activation = activation
ValueError: if output_size is not positive.
"""
super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not RNNCell.")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if output_size < 1:
raise ValueError("Parameter output_size must be > 0: %d." % output_size)
self._cell = cell
`state_is_tuple` is `False` or if attn_length is zero or less.
"""
super(AttentionCellWrapper, self).__init__(_reuse=reuse)
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError("The parameter cell is not RNNCell.")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if nest.is_sequence(cell.state_size) and not state_is_tuple:
raise ValueError(
"Cell returns tuple of states, but the flag "
is a list, and its length does not match that of `attention_layer_size`.
"""
super(AttentionWrapper, self).__init__(name=name)
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError(
- "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if isinstance(attention_mechanism, (list, tuple)):
self._is_multi = True
attention_mechanisms = attention_mechanism
Raises:
TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
"""
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if not isinstance(helper, helper_py.Helper):
raise TypeError("helper must be a Helper, received: %s" % type(helper))
if (output_layer is not None
ValueError: If `start_tokens` is not a vector or
`end_token` is not a scalar.
"""
- if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
- raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+ rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access
if (output_layer is not None and
not isinstance(output_layer, layers_base.Layer)):
raise TypeError(
# pylint: disable=protected-access
_concat = rnn_cell_impl._concat
-_like_rnncell = rnn_cell_impl._like_rnncell
# pylint: enable=protected-access
Raises:
TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
"""
-
- if not _like_rnncell(cell_fw):
- raise TypeError("cell_fw must be an instance of RNNCell")
- if not _like_rnncell(cell_bw):
- raise TypeError("cell_bw must be an instance of RNNCell")
+ rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
+ rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
TypeError: If `cell` is not an instance of RNNCell.
ValueError: If inputs is None or an empty list.
"""
- if not _like_rnncell(cell):
- raise TypeError("cell must be an instance of RNNCell")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
with vs.variable_scope(scope or "rnn") as varscope:
# Create a new scope in which the caching device is either
TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
a `callable`.
"""
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
- if not _like_rnncell(cell):
- raise TypeError("cell must be an instance of RNNCell")
if not callable(loop_fn):
raise TypeError("loop_fn must be a callable")
ValueError: If `inputs` is `None` or an empty list, or if the input depth
(column size) cannot be inferred from inputs via shape inference.
"""
-
- if not _like_rnncell(cell):
- raise TypeError("cell must be an instance of RNNCell")
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
if not nest.is_sequence(inputs):
raise TypeError("inputs must be a sequence")
if not inputs:
TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
ValueError: If inputs is None or an empty list.
"""
-
- if not _like_rnncell(cell_fw):
- raise TypeError("cell_fw must be an instance of RNNCell")
- if not _like_rnncell(cell_bw):
- raise TypeError("cell_bw must be an instance of RNNCell")
+ rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
+ rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
if not nest.is_sequence(inputs):
raise TypeError("inputs must be a sequence")
if not inputs:
_WEIGHTS_VARIABLE_NAME = "kernel"
+# TODO(jblespiau): Remove this function when we are sure there are no longer
+# any usage (even if protected, it is being used). Prefer assert_like_rnncell.
def _like_rnncell(cell):
"""Checks that a given object is an RNNCell by using duck typing."""
conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
return all(conditions)
+# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
+ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
+
+
+def assert_like_rnncell(cell_name, cell):
+ """Raises a TypeError if cell is not like an RNNCell.
+
+ NOTE: Do not rely on the error message (in particular in tests) which can be
+ subject to change to increase readability. Use
+ ASSERT_LIKE_RNNCELL_ERROR_REGEXP.
+
+ Args:
+ cell_name: A string to give a meaningful error referencing to the name
+ of the functionargument.
+ cell: The object which should behave like an RNNCell.
+
+ Raises:
+ TypeError: A human-friendly exception.
+ """
+ conditions = [
+ hasattr(cell, "output_size"),
+ hasattr(cell, "state_size"),
+ hasattr(cell, "zero_state"),
+ callable(cell),
+ ]
+ errors = [
+ "'output_size' property is missing",
+ "'state_size' property is missing",
+ "'zero_state' method is missing",
+ "is not callable"
+ ]
+
+ if not all(conditions):
+
+ errors = [error for error, cond in zip(errors, conditions) if not cond]
+ raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format(
+ cell_name, cell, ", ".join(errors)))
+
+
def _concat(prefix, suffix, static=False):
"""Concat that enables int, Tensor, or TensorShape values.
but not `callable`.
ValueError: if any of the keep_probs are not between 0 and 1.
"""
- if not _like_rnncell(cell):
- raise TypeError("The parameter cell is not a RNNCell.")
+ assert_like_rnncell("cell", cell)
+
if (dropout_state_filter_visitor is not None
and not callable(dropout_state_filter_visitor)):
raise TypeError("dropout_state_filter_visitor must be callable")