# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""An example tf.keras model that is trained using MirroredStrategy."""
+"""An example of training tf.keras Model using MirroredStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sys import argv
+
+import sys
+
import numpy as np
import tensorflow as tf
def main(args):
if len(args) < 2:
- print('You must specify model_dir for checkpoints such as'
- ' /tmp/tfkeras_example./')
+ print('You must specify model_dir for checkpoints such as'
+ ' /tmp/tfkeras_example/.')
return
- print('Using %s to store checkpoints.' % args[1])
-
- strategy = tf.contrib.distribute.MirroredStrategy(
- ['/device:GPU:0', '/device:GPU:1'])
- config = tf.estimator.RunConfig(train_distribute=strategy)
- optimizer = tf.train.GradientDescentOptimizer(0.2)
+ model_dir = args[1]
+ print('Using %s to store checkpoints.' % model_dir)
+ # Define tf.keras Model.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
+ # Compile tf.keras Model.
+ optimizer = tf.train.GradientDescentOptimizer(0.2)
model.compile(loss='binary_crossentropy', optimizer=optimizer)
model.summary()
tf.keras.backend.set_learning_phase(True)
+
+ # Define a DistributionStrategy and convert the tf.keras Model to a
+ # tf.Estimator that utilizes the DistributionStrategy.
+ strategy = tf.contrib.distribute.MirroredStrategy(
+ ['/device:GPU:0', '/device:GPU:1'])
+ config = tf.estimator.RunConfig(train_distribute=strategy)
keras_estimator = tf.keras.estimator.model_to_estimator(
- keras_model=model, config=config, model_dir=args[1])
+ keras_model=model, config=config, model_dir=model_dir)
+ # Train and evaluate the tf.Estimator.
keras_estimator.train(input_fn=input_fn, steps=10)
eval_result = keras_estimator.evaluate(input_fn=input_fn)
print('Eval result: {}'.format(eval_result))
+
if __name__ == '__main__':
- tf.app.run(argv=argv)
+ tf.app.run(argv=sys.argv)
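For reference, the hunks above call `input_fn` without showing its definition (it lives elsewhere in the example file). A minimal sketch of a compatible input function, assuming synthetic data shaped to match the model's `(10,)` input and binary label (sizes and batch size here are illustrative, not taken from the file), could look like:

```
import numpy as np
import tensorflow as tf


def input_fn():
  # Synthetic features (10 per example) and binary labels; sizes are
  # illustrative only.
  x = np.random.random((1024, 10)).astype(np.float32)
  y = np.random.randint(2, size=(1024, 1))
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  # Repeat and batch so the Estimator can train for multiple steps.
  dataset = dataset.repeat(100).batch(32)
  return dataset
```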
on different slices of the input data. This is in contrast to
_model parallelism_, where we divide up a single copy of a model
across multiple devices.
- Note: for now we only support data parallelism at this time, but
+ Note: we only support data parallelism for now, but
hope to add support for model parallelism in the future.
* A _tower_ is one copy of the model, running on one slice of the
input data.
- * _Synchronous_, or more commonly _sync_, training is when the
+ * _Synchronous_, or more commonly _sync_, training is where the
updates from each tower are aggregated together before updating
the model variables. This is in contrast to _asynchronous_, or
- _async_ training where each tower updates the model variables
+ _async_ training, where each tower updates the model variables
independently.
* Furthermore, you might run your computation on multiple devices
on one machine (or "host"), or on multiple machines/hosts.
* Reductions and Allreduce: A _reduction_ is some method of
aggregating multiple values into one value, like "sum" or
"mean". If doing sync training, we will perform a reduction on the
- gradients to a parameter from each tower before applying the
+ gradients to a parameter from all towers before applying the
update. Allreduce is an algorithm for performing a reduction on
values from multiple devices and making the result available on
all of those devices (a short conceptual sketch follows this list).
- * In the future we will have support for TensorFlows' partitioned
+ * In the future we will have support for TensorFlow's partitioned
variables, where a single variable is split across multiple
devices.
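To make the sync-training reduction concrete, here is a small self-contained sketch in plain TensorFlow ops. It is only a conceptual illustration, not the library's implementation (which typically performs an allreduce, e.g. via NCCL, across devices): the per-tower gradients for a parameter are averaged into one value, and a single update is applied to the shared variable.

```
import tensorflow as tf

# Gradients computed independently on two towers for the same parameter
# (scalars here purely for illustration).
tower_grads = [tf.constant(1.0), tf.constant(3.0)]

# "Mean" reduction: aggregate the per-tower gradients into one value.
reduced_grad = tf.add_n(tower_grads) / len(tower_grads)

# Apply a single update to the shared variable.
w = tf.Variable(0.5)
update_op = tf.assign_sub(w, 0.2 * reduced_grad)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op)
  print(sess.run(w))  # 0.5 - 0.2 * mean([1.0, 3.0]) = 0.1
```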
`tower_fn` can use the `get_tower_context()` API to get enhanced
behavior in this case.
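As a rough illustration (not taken from the file above), a `tower_fn` might query the tower context as in the sketch below, here accessed via `tf.contrib.distribute` to match the strategy import used earlier. The attributes available on the returned context depend on the TensorFlow version, so treat `tower_id` as an assumed attribute name and check the `TowerContext` API of your release.

```
import tensorflow as tf


def tower_fn(features, labels):
  # Each tower runs this function on its own slice of the input batch.
  ctx = tf.contrib.distribute.get_tower_context()
  # NOTE: 'tower_id' is assumed here for illustration; verify the
  # TowerContext API of your TensorFlow version.
  tf.logging.info('Building graph for tower %s', ctx.tower_id)
  logits = tf.layers.dense(features, 1)
  loss = tf.losses.sigmoid_cross_entropy(labels, logits)
  return loss
```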
- You can also create an initializable iterator instead of one shot iterator.
- In that case, you will need to ensure that you initialize the iterator
- before calling get_next.
+ You can also create an initializable iterator instead of a one-shot
+ iterator. In that case, you will need to ensure that you initialize the
+ iterator before calling get_next.
```
iterator = my_distribution.distribute_dataset(
    dataset).make_initializable_iterator()