From df847112f0a8805dab02cc5581870a8460032ef3 Mon Sep 17 00:00:00 2001
From: Anjali Sridhar
Date: Thu, 29 Mar 2018 19:57:27 -0700
Subject: [PATCH] Internal Change.

PiperOrigin-RevId: 191024708
---
 .../contrib/distribute/python/examples/BUILD       | 11 ++++
 .../python/examples/simple_tfkeras_example.py      | 62 ++++++++++++++++++++++
 tensorflow/python/keras/_impl/keras/optimizers.py  | 29 +++++++---
 3 files changed, 96 insertions(+), 6 deletions(-)
 create mode 100644 tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py

diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD
index 27eb3c0..cbfd178 100644
--- a/tensorflow/contrib/distribute/python/examples/BUILD
+++ b/tensorflow/contrib/distribute/python/examples/BUILD
@@ -17,3 +17,14 @@ py_binary(
         "//tensorflow:tensorflow_py",
     ],
 )
+
+py_binary(
+    name = "simple_tfkeras_example",
+    srcs = [
+        "simple_tfkeras_example.py",
+    ],
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
new file mode 100644
index 0000000..e714255
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
@@ -0,0 +1,62 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An example tf.keras model that is trained using MirroredStrategy."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from sys import argv
+import numpy as np
+import tensorflow as tf
+
+
+def input_fn():
+  x = np.random.random((1024, 10))
+  y = np.random.randint(2, size=(1024, 1))
+  x = tf.cast(x, tf.float32)
+  dataset = tf.data.Dataset.from_tensor_slices((x, y))
+  dataset = dataset.repeat(10)
+  dataset = dataset.batch(32)
+  return dataset
+
+
+def main(args):
+  if len(args) < 2:
+    print('You must specify model_dir for checkpoints such as'
+          ' /tmp/tfkeras_example/.')
+    return
+
+  print('Using %s to store checkpoints.' % args[1])
+
+  strategy = tf.contrib.distribute.MirroredStrategy(
+      ['/device:GPU:0', '/device:GPU:1'])
+  config = tf.estimator.RunConfig(distribute=strategy)
+  optimizer = tf.train.GradientDescentOptimizer(0.2)
+
+  model = tf.keras.Sequential()
+  model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)))
+  model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
+
+  model.compile(loss='binary_crossentropy', optimizer=optimizer)
+  model.summary()
+  tf.keras.backend.set_learning_phase(True)
+  keras_estimator = tf.keras.estimator.model_to_estimator(
+      keras_model=model, config=config, model_dir=args[1])
+
+  keras_estimator.train(input_fn=input_fn, steps=10)
+  eval_result = keras_estimator.evaluate(input_fn=input_fn)
+  print('Eval result: {}'.format(eval_result))
+
+if __name__ == '__main__':
+  tf.app.run(argv=argv)
diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
index b715d72..acbb909 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -31,7 +31,9 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_
 from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import optimizer as tf_optimizer_module
+from tensorflow.python.training import training_util
 from tensorflow.python.util.tf_export import tf_export


@@ -728,12 +730,27 @@ class TFOptimizer(Optimizer):
     return self.optimizer.compute_gradients(loss, params)

   def get_updates(self, loss, params):
-    self.updates = [K.update_add(self.iterations, 1)]
-    if not params:
-      return self.updates
-    grads = self.optimizer.compute_gradients(loss, params)
-    opt_update = self.optimizer.apply_gradients(
-        grads, global_step=self.iterations)
+    if distribute_lib.has_distribution_strategy():
+      self.updates = []
+
+      if not params:
+        # After the model variables have been created, get_updates is called a
+        # second time with an empty params list; in that case we call
+        # compute_gradients without passing params (i.e. params=None).
+        grads = self.optimizer.compute_gradients(loss)
+      else:
+        grads = self.optimizer.compute_gradients(loss, params)
+      global_step = training_util.get_global_step()
+      opt_update = self.optimizer.apply_gradients(grads, global_step)
+    else:
+      self.updates = [K.update_add(self.iterations, 1)]
+      if not params:
+        return self.updates
+
+      grads = self.optimizer.compute_gradients(loss, params)
+      opt_update = self.optimizer.apply_gradients(
+          grads, global_step=self.iterations)
+
     self.updates.append(opt_update)
     return self.updates

-- 
2.7.4
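
For reference, the substantive change in this patch is the new branch in
TFOptimizer.get_updates: when a DistributionStrategy is active, gradients are
applied against the graph's global step and the Keras `iterations` counter is
left untouched; otherwise the existing Keras path is used. The sketch below is
a plain-Python illustration of that control flow only. FakeOptimizer and
get_updates_sketch are hypothetical stand-ins, not TensorFlow APIs.

# A runnable, simplified sketch of the branching added to TFOptimizer.get_updates.
# FakeOptimizer stands in for a tf.train.Optimizer and strings stand in for ops;
# only the control flow mirrors the patch above.


class FakeOptimizer(object):
  """Records how gradients were computed and applied."""

  def compute_gradients(self, loss, var_list=None):
    # Mirrors tf.train.Optimizer.compute_gradients: with no var_list, gradients
    # are computed for all trainable variables.
    return [('grad_of_%s' % v, v) for v in (var_list or ['all_trainable_vars'])]

  def apply_gradients(self, grads_and_vars, global_step=None):
    return 'apply(%d grads, step=%s)' % (len(grads_and_vars), global_step)


def get_updates_sketch(optimizer, loss, params, under_distribution_strategy):
  updates = []
  if under_distribution_strategy:
    # Distribution-strategy path: do not touch the Keras `iterations` counter;
    # let apply_gradients advance the graph-level global step instead.
    if not params:
      grads = optimizer.compute_gradients(loss)
    else:
      grads = optimizer.compute_gradients(loss, params)
    updates.append(optimizer.apply_gradients(grads, global_step='global_step'))
  else:
    # Plain Keras path: first queue an increment of `iterations`, then apply
    # the gradients against that same counter.
    updates.append('iterations += 1')
    if not params:
      return updates
    grads = optimizer.compute_gradients(loss, params)
    updates.append(optimizer.apply_gradients(grads, global_step='iterations'))
  return updates


if __name__ == '__main__':
  opt = FakeOptimizer()
  print(get_updates_sketch(opt, loss=0.0, params=['w', 'b'],
                           under_distribution_strategy=True))
  print(get_updates_sketch(opt, loss=0.0, params=['w', 'b'],
                           under_distribution_strategy=False))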