From 5574d6300c5e05dceb92d6d765313a99dd2c417d Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Mon, 5 Mar 2018 19:15:13 -0800 Subject: [PATCH] [TPU Cluster Resolver]: Integrate with GKE This change integrates the TPUClusterResolver with GKE's support for Cloud TPUs PiperOrigin-RevId: 187961802 --- .../python/training/tpu_cluster_resolver.py | 18 ++++++++- .../training/tpu_cluster_resolver_test.py | 39 ++++++++++++------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index aeccf4c06b..83d26a17a8 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from six.moves.urllib.request import Request from six.moves.urllib.request import urlopen @@ -34,6 +35,9 @@ except ImportError: _GOOGLE_API_CLIENT_INSTALLED = False +_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' + + class TPUClusterResolver(ClusterResolver): """Cluster Resolver for Google Cloud TPUs. @@ -57,8 +61,15 @@ class TPUClusterResolver(ClusterResolver): return False return True + def _inGke(self): + """When running in GKE, the environment variable will be set.""" + return _GKE_ENV_VARIABLE in os.environ + + def _gkeMaster(self): + return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + def __init__(self, - tpu, + tpu=None, zone=None, project=None, job_name='worker', @@ -107,6 +118,11 @@ class TPUClusterResolver(ClusterResolver): raise NotImplementedError( 'Using multiple TPUs in a single session is not yet implemented') tpu = tpu[0] + + # When using GKE with Cloud TPUs, the env variable will be set. + if tpu is None and self._inGke(): + tpu = self._gkeMaster() + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name self._credentials = credentials diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 6b4a155152..b7d56fc122 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -73,18 +75,17 @@ class TPUClusterResolverTest(test.TestCase): expected_proto: Expected protobuf """ self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) - self.assertProtoEquals( - expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def()) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) - self.assertProtoEquals( - expected_proto, - server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + self.assertProtoEquals(expected_proto, + server_lib.ClusterSpec( + cluster_spec.as_cluster_def()).as_cluster_def()) + self.assertProtoEquals(expected_proto, + server_lib.ClusterSpec( + cluster_spec.as_dict()).as_cluster_def()) - def mock_service_client( - self, - tpu_map=None): + def mock_service_client(self, tpu_map=None): if tpu_map is None: tpu_map = {} @@ -100,8 +101,7 @@ class TPUClusterResolverTest(test.TestCase): return mock_client - @mock.patch.object(TPUClusterResolver, - '_requestComputeMetadata', + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', mock_request_compute_metadata) def testRetrieveProjectAndZoneFromMetadata(self): tpu_map = { @@ -350,11 +350,24 @@ class TPUClusterResolverTest(test.TestCase): def testNoCallComputeMetadata(self): tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar') - self.assertEqual(compat.as_bytes('/bns/foo/bar'), - tpu_cluster_resolver.master()) + self.assertEqual( + compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) self.assertEqual( server_lib.ClusterSpec({}), tpu_cluster_resolver.cluster_spec()) + def testGkeEnvironment(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' + self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ) + tpu_cluster_resolver = TPUClusterResolver() + self.assertTrue(tpu_cluster_resolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + tpu_cluster_resolver._gkeMaster()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + tpu_cluster_resolver.get_master()) + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + if __name__ == '__main__': test.main() -- 2.34.1