import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data
from third_party.examples.eager.spinn import spinn
+from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
right_in.append(tf.random_normal((1, size * 2)))
tracking.append(tf.random_normal((1, tracker_size * 2)))
- out = reducer(left_in, right_in, tracking=tracking)
+ out = reducer(left_in, right_in=right_in, tracking=tracking)
self.assertEqual(batch_size, len(out))
self.assertEqual(tf.float32, out[0].dtype)
self.assertEqual((1, size * 2), out[0].shape)
self.assertEqual((batch_size, size * 2), stacks[0][0].shape)
for _ in range(2):
- out1, out2 = tracker(bufs, stacks)
+ out1, out2 = tracker(bufs, stacks=stacks)
self.assertIsNone(out2)
self.assertEqual(batch_size, len(out1))
self.assertEqual(tf.float32, out1[0].dtype)
self.assertEqual(tf.int64, transitions.dtype)
self.assertEqual((num_transitions, 1), transitions.shape)
- out = s(buffers, transitions, training=True)
+ out = s(buffers, transitions=transitions, training=True)
self.assertEqual(tf.float32, out.dtype)
self.assertEqual((1, embedding_dims), out.shape)
vocab_size)
# Invoke model under non-training mode.
- logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+ logits = model(
+ prem, premise_transition=prem_trans, hypothesis=hypo,
+ hypothesis_transition=hypo_trans, training=False)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
# Invoke model under training mode.
- logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
+ logits = model(prem, premise_transition=prem_trans, hypothesis=hypo,
+ hypothesis_transition=hypo_trans, training=True)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
# 5. Verify that checkpoints exist and contain all the expected variables.
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
- ckpt_variable_names = [
- item[0] for item in checkpoint_utils.list_variables(config.logdir)]
+ object_graph_string = checkpoint_utils.load_variable(
+ config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
+ object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
+ object_graph.ParseFromString(object_graph_string)
+ ckpt_variable_names = set()
+ for node in object_graph.nodes:
+ for attribute in node.attributes:
+ ckpt_variable_names.add(attribute.full_name)
self.assertIn("global_step", ckpt_variable_names)
for v in trainer.variables:
variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
from tensorflow.contrib.eager.python.examples.spinn import data
+layers = tf.keras.layers
+
+
def _bundle(lstm_iter):
"""Concatenate a list of Tensors along 1st axis and split result into two.
return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0)
-class Reducer(tfe.Network):
+# pylint: disable=not-callable
+class Reducer(tf.keras.Model):
"""A module that applies reduce operation on left and right vectors."""
def __init__(self, size, tracker_size=None):
super(Reducer, self).__init__()
- self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None))
- self.right = self.track_layer(
- tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ self.left = layers.Dense(5 * size, activation=None)
+ self.right = layers.Dense(5 * size, activation=None, use_bias=False)
if tracker_size is not None:
- self.track = self.track_layer(
- tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ self.track = layers.Dense(5 * size, activation=None, use_bias=False)
else:
self.track = None
return h, c
-class Tracker(tfe.Network):
+class Tracker(tf.keras.Model):
"""A module that tracks the history of the sentence with an LSTM."""
def __init__(self, tracker_size, predict):
predict: (`bool`) Whether prediction mode is enabled.
"""
super(Tracker, self).__init__()
- self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size))
+ self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size)
self._state_size = tracker_size
if predict:
- self._transition = self.track_layer(tf.layers.Dense(4))
+ self._transition = layers.Dense(4)
else:
self._transition = None
return unbundled, None
-class SPINN(tfe.Network):
+class SPINN(tf.keras.Model):
"""Stack-augmented Parser-Interpreter Neural Network.
See https://arxiv.org/abs/1603.06021 for more details.
"""
super(SPINN, self).__init__()
self.config = config
- self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker))
+ self.reducer = Reducer(config.d_hidden, config.d_tracker)
if config.d_tracker is not None:
- self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict))
+ self.tracker = Tracker(config.d_tracker, config.predict)
else:
self.tracker = None
trans = transitions[i]
if self.tracker:
# Invoke tracker to obtain the current tracker states for the sentences.
- tracker_states, trans_hypothesis = self.tracker(buffers, stacks)
+ tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks)
if trans_hypothesis:
trans = tf.argmax(trans_hypothesis, axis=-1)
else:
trackings.append(tracking)
if rights:
- reducer_output = self.reducer(lefts, rights, trackings)
+ reducer_output = self.reducer(
+ lefts, right_in=rights, tracking=trackings)
reduced = iter(reducer_output)
for transition, stack in zip(trans, stacks):
return _bundle([stack.pop() for stack in stacks])[0]
-class SNLIClassifier(tfe.Network):
+class Perceptron(tf.keras.Model):  # one MLP stage; each instance wraps the stage that runs before it
+ """One layer of the SNLIClassifier multi-layer perceptron."""
+
+ def __init__(self, dimension, dropout_rate, previous_layer):
+ """Configure the Perceptron."""
+ super(Perceptron, self).__init__()
+ self.dense = tf.keras.layers.Dense(dimension, activation=tf.nn.elu)  # ELU is applied inside the Dense layer
+ self.batchnorm = layers.BatchNormalization()
+ self.dropout = layers.Dropout(rate=dropout_rate)
+ self.previous_layer = previous_layer  # callable invoked in call() as previous_layer(x, training=training)
+
+ def call(self, x, training):
+ """Run previous Perceptron layers, then this one."""
+ x = self.previous_layer(x, training=training)  # run the wrapped earlier stage(s) first
+ x = self.dense(x)
+ x = self.batchnorm(x, training=training)  # `training` is forwarded to batch norm and dropout
+ x = self.dropout(x, training=training)  # stage order: dense (ELU) -> batch norm -> dropout
+ return x
+
+
+class SNLIClassifier(tf.keras.Model):
"""SNLI Classifier Model.
A model aimed at solving the SNLI (Stanford Natural Language Inference)
self.config = config
self.embed = tf.constant(embed)
- self.projection = self.track_layer(tf.layers.Dense(config.d_proj))
- self.embed_bn = self.track_layer(tf.layers.BatchNormalization())
- self.embed_dropout = self.track_layer(
- tf.layers.Dropout(rate=config.embed_dropout))
- self.encoder = self.track_layer(SPINN(config))
-
- self.feature_bn = self.track_layer(tf.layers.BatchNormalization())
- self.feature_dropout = self.track_layer(
- tf.layers.Dropout(rate=config.mlp_dropout))
-
- self.mlp_dense = []
- self.mlp_bn = []
- self.mlp_dropout = []
- for _ in xrange(config.n_mlp_layers):
- self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp)))
- self.mlp_bn.append(
- self.track_layer(tf.layers.BatchNormalization()))
- self.mlp_dropout.append(
- self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout)))
- self.mlp_output = self.track_layer(tf.layers.Dense(
+ self.projection = layers.Dense(config.d_proj)
+ self.embed_bn = layers.BatchNormalization()
+ self.embed_dropout = layers.Dropout(rate=config.embed_dropout)
+ self.encoder = SPINN(config)
+
+ self.feature_bn = layers.BatchNormalization()
+ self.feature_dropout = layers.Dropout(rate=config.mlp_dropout)
+
+ current_mlp = lambda result, training: result
+ for _ in range(config.n_mlp_layers):
+ current_mlp = Perceptron(dimension=config.d_mlp,
+ dropout_rate=config.mlp_dropout,
+ previous_layer=current_mlp)
+ self.mlp = current_mlp
+ self.mlp_output = layers.Dense(
config.d_out,
kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
- maxval=5e-3)))
+ maxval=5e-3))
def call(self,
premise,
# Run the batch-normalized and dropout-processed word vectors through the
# SPINN encoder.
- premise = self.encoder(premise_embed, premise_transition,
- training=training)
- hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
- training=training)
+ premise = self.encoder(
+ premise_embed, transitions=premise_transition, training=training)
+ hypothesis = self.encoder(
+ hypothesis_embed, transitions=hypothesis_transition, training=training)
# Combine encoder outputs for premises and hypotheses into logits.
# Then apply batch normalization and dropout on the logits.
self.feature_bn(logits, training=training), training=training)
# Apply the multi-layer perceptron on the logits.
- for dense, bn, dropout in zip(
- self.mlp_dense, self.mlp_bn, self.mlp_dropout):
- logits = tf.nn.elu(dense(logits))
- logits = dropout(bn(logits, training=training), training=training)
+ logits = self.mlp(logits, training=training)
logits = self.mlp_output(logits)
return logits
-class SNLIClassifierTrainer(object):
+class SNLIClassifierTrainer(tfe.Checkpointable):
"""A class that coordinates the training of an SNLIClassifier."""
def __init__(self, snli_classifier, lr):
"""
with tfe.GradientTape() as tape:
tape.watch(self._model.variables)
+ # TODO(allenl): Allow passing Layer inputs as positional arguments.
logits = self._model(premise,
- premise_transition,
- hypothesis,
- hypothesis_transition,
+ premise_transition=premise_transition,
+ hypothesis=hypothesis,
+ hypothesis_transition=hypothesis_transition,
training=True)
loss = self.loss(labels, logits)
gradients = tape.gradient(loss, self._model.variables)
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
+ logits = trainer.model(
+ prem, premise_transition=prem_trans, hypothesis=hypo,
+ hypothesis_transition=hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
with tf.device(device), \
summary_writer.as_default(), \
tf.contrib.summary.always_record_summaries():
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- model = SNLIClassifier(config, embed)
- global_step = tf.train.get_or_create_global_step()
- trainer = SNLIClassifierTrainer(model, config.lr)
+ model = SNLIClassifier(config, embed)
+ global_step = tf.train.get_or_create_global_step()
+ trainer = SNLIClassifierTrainer(model, config.lr)
+ checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step)
+ checkpoint.restore(tf.train.latest_checkpoint(config.logdir))
if inference_sentence_pair:
# Inference mode.
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- prem, prem_trans = inference_sentence_pair[0]
- hypo, hypo_trans = inference_sentence_pair[1]
- hypo_trans = inference_sentence_pair[1][1]
- inference_logits = model( # pylint: disable=not-callable
- tf.constant(prem), tf.constant(prem_trans),
- tf.constant(hypo), tf.constant(hypo_trans), training=False)
- inference_logits = inference_logits[0][1:]
- max_index = tf.argmax(inference_logits)
- print("\nInference logits:")
- for i, (label, logit) in enumerate(
- zip(data.POSSIBLE_LABELS, inference_logits)):
- winner_tag = " (winner)" if max_index == i else ""
- print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
+ prem, prem_trans = inference_sentence_pair[0]
+ hypo, hypo_trans = inference_sentence_pair[1]
+ hypo_trans = inference_sentence_pair[1][1]
+ inference_logits = model(
+ tf.constant(prem),
+ premise_transition=tf.constant(prem_trans),
+ hypothesis=tf.constant(hypo),
+ hypothesis_transition=tf.constant(hypo_trans),
+ training=False)
+ inference_logits = inference_logits[0][1:]
+ max_index = tf.argmax(inference_logits)
+ print("\nInference logits:")
+ for i, (label, logit) in enumerate(
+ zip(data.POSSIBLE_LABELS, inference_logits)):
+ winner_tag = " (winner)" if max_index == i else ""
+ print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
return inference_logits
train_len = train_data.num_batches(config.batch_size)
# remain on CPU. Same in _evaluate_on_dataset().
iterations += 1
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- batch_train_loss, batch_train_logits = trainer.train_batch(
- label, prem, prem_trans, hypo, hypo_trans)
+ batch_train_loss, batch_train_logits = trainer.train_batch(
+ label, prem, prem_trans, hypo, hypo_trans)
batch_size = tf.shape(label)[0]
mean_loss(batch_train_loss.numpy(),
weights=batch_size.gpu() if use_gpu else batch_size)
accuracy(tf.argmax(batch_train_logits, axis=1), label)
if iterations % config.save_every == 0:
- all_variables = trainer.variables + [global_step]
- saver = tfe.Saver(all_variables)
- saver.save(os.path.join(config.logdir, "ckpt"),
- global_step=global_step)
+ checkpoint.save(os.path.join(config.logdir, "ckpt"))
if iterations % config.dev_every == 0:
dev_loss, dev_frac_correct = _evaluate_on_dataset(