From 218647db25d1e754ad85fd1fa8a0960b82ae83bf Mon Sep 17 00:00:00 2001
From: Brennan Saeta
Date: Fri, 6 Apr 2018 09:26:08 -0700
Subject: [PATCH] [TPUClusterResolver] Start a TFServer when running in GKE

This change allows advanced input pipelines (e.g. StreamingFilesDataset, or
split-pipelines that use py_func's) to run in GKE- and GKE-like environments.

PiperOrigin-RevId: 191897639
---
 .../python/training/tpu_cluster_resolver.py      | 75 ++++++++++++++++++++--------
 .../python/training/tpu_cluster_resolver_test.py |  8 +--
 2 files changed, 51 insertions(+), 32 deletions(-)

diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 300b197..a520a06 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -61,11 +61,13 @@ class TPUClusterResolver(ClusterResolver):
       return False
     return True
 
-  def _inGke(self):
+  @staticmethod
+  def _inGke():
     """When running in GKE, the environment variable will be set."""
     return _GKE_ENV_VARIABLE in os.environ
 
-  def _gkeMaster(self):
+  @staticmethod
+  def _gkeMaster():
     return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
 
   def __init__(self,
@@ -119,8 +121,9 @@
           'Using multiple TPUs in a single session is not yet implemented')
       tpu = tpu[0]
 
+    in_gke = self._inGke()
     # When using GKE with Cloud TPUs, the env variable will be set.
-    if tpu is None and self._inGke():
+    if tpu is None and in_gke:
       tpu = self._gkeMaster()
 
     self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
@@ -158,7 +161,8 @@
     self._service = service
 
     self._coordinator_name = coordinator_name
-    if coordinator_name and not coordinator_address and should_resolve:
+    if coordinator_name and not coordinator_address and (should_resolve or
+                                                         in_gke):
       self._start_local_server()
     else:
       self._coordinator_address = coordinator_address
@@ -204,31 +208,50 @@
     Raises:
       RuntimeError: If the provided TPU is not healthy.
     """
-    if not self._shouldResolve():
-      return server_lib.ClusterSpec({})
-
-    full_name = 'projects/%s/locations/%s/nodes/%s' % (
-        self._project, self._zone, compat.as_text(self._tpu))
-    request = self._service.projects().locations().nodes().get(name=full_name)
-    response = request.execute()
-
-    if 'health' in response and response['health'] != 'HEALTHY':
-      raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
-                                                          response['health']))
-
-    if 'networkEndpoints' in response:
-      worker_list = [
-          '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
-          for endpoint in response['networkEndpoints']
-      ]
+    ############################################################################
+    # There are 5 potential cases this code must handle:
+    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
+    #      a. Create a ClusterSpec that includes the coordinator job
+    #      b. Create a ClusterSpec without the coordinator job.
+    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
+    #     tasks and
+    #      a. Create a ClusterSpec with the coordinator
+    #      b. Create a ClusterSpec without the coordinator
+    #  3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
+    ############################################################################
+
+    if self._shouldResolve():
+      # Case 1.
+      full_name = 'projects/%s/locations/%s/nodes/%s' % (
+          self._project, self._zone, compat.as_text(self._tpu))
+      request = self._service.projects().locations().nodes().get(name=full_name)
+      response = request.execute()
+
+      if 'health' in response and response['health'] != 'HEALTHY':
+        raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
+                                                            response['health']))
+
+      if 'networkEndpoints' in response:
+        worker_list = [
+            '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
+            for endpoint in response['networkEndpoints']
+        ]
+      else:
+        # Fall back to the deprecated response format
+        instance_url = '%s:%s' % (response['ipAddress'], response['port'])
+        worker_list = [instance_url]
+
+      cluster_spec = {self._job_name: worker_list}
     else:
-      # Fall back to the deprecated response format
-      instance_url = '%s:%s' % (response['ipAddress'], response['port'])
-      worker_list = [instance_url]
-
-    cluster_spec = {self._job_name: worker_list}
+      if not self._tpu.startswith(compat.as_bytes('grpc://')):
+        # Case 3.
+        return server_lib.ClusterSpec({})
+      # Case 2.
+      cluster_spec = {self._job_name: [self._tpu[len(
+          compat.as_bytes('grpc://')):]]}
 
     if self._coordinator_address:
+      # {1, 2}.a
       cluster_spec[self._coordinator_name] = [self._coordinator_address]
 
     return server_lib.ClusterSpec(cluster_spec)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 48c3f6b..cfddca1 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -358,14 +358,10 @@ class TPUClusterResolverTest(test.TestCase):
   def testGkeEnvironment(self):
     os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
     self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ)
-    tpu_cluster_resolver = TPUClusterResolver()
-    self.assertTrue(tpu_cluster_resolver._inGke())
+    self.assertTrue(TPUClusterResolver._inGke())
     self.assertEqual(
         compat.as_bytes('grpc://10.120.27.5:8470'),
-        compat.as_bytes(tpu_cluster_resolver._gkeMaster()))
-    self.assertEqual(
-        compat.as_bytes('grpc://10.120.27.5:8470'),
-        compat.as_bytes(tpu_cluster_resolver.get_master()))
+        compat.as_bytes(TPUClusterResolver._gkeMaster()))
     del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
 
 
-- 
2.7.4
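
For context, a minimal usage sketch of the behavior this patch enables in a GKE-like
environment. It mirrors the updated testGkeEnvironment test; the endpoint address and
the coordinator job name are illustrative, and the import assumes the contrib package
path touched by this change.

import os

from tensorflow.contrib.cluster_resolver import TPUClusterResolver

# GKE sets this variable to a comma-separated list of Cloud TPU gRPC
# endpoints; the address below is illustrative.
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'

# No explicit `tpu` argument is needed: the resolver reads the endpoint from
# the environment variable. Because a coordinator_name is given and we are
# running "in GKE", the resolver now also starts a local TF server, so split
# input pipelines have a coordinator task to run against.
resolver = TPUClusterResolver(coordinator_name='coordinator')

# The master is taken directly from the first endpoint in the variable.
print(resolver.get_master())

# cluster_spec() takes the "GKE / No API Access" path (case 2a above): the
# worker address is the endpoint with its grpc:// prefix stripped, plus the
# locally started coordinator task.
print(resolver.cluster_spec().as_dict())

del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']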