[TPU Cluster Resolver]: Integrate with GKE
author Brennan Saeta <saeta@google.com>
Tue, 6 Mar 2018 03:15:13 +0000 (19:15 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 03:19:51 +0000 (19:19 -0800)
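This change integrates the TPUClusterResolver with GKE's support for Cloud TPUs. When no `tpu` argument is given and the KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS environment variable is set, the resolver uses the first comma-separated endpoint in that variable as the TPU.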

PiperOrigin-RevId: 187961802

tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py

index aeccf4c..83d26a1 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 
 from six.moves.urllib.request import Request
 from six.moves.urllib.request import urlopen
@@ -34,6 +35,9 @@ except ImportError:
   _GOOGLE_API_CLIENT_INSTALLED = False
 
 
+_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+
+
 class TPUClusterResolver(ClusterResolver):
   """Cluster Resolver for Google Cloud TPUs.
 
@@ -57,8 +61,15 @@ class TPUClusterResolver(ClusterResolver):
       return False
     return True
 
+  def _inGke(self):
+    """When running in GKE, the environment variable will be set."""
+    return _GKE_ENV_VARIABLE in os.environ
+
+  def _gkeMaster(self):
+    return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
+
   def __init__(self,
-               tpu,
+               tpu=None,
                zone=None,
                project=None,
                job_name='worker',
@@ -107,6 +118,11 @@ class TPUClusterResolver(ClusterResolver):
         raise NotImplementedError(
             'Using multiple TPUs in a single session is not yet implemented')
       tpu = tpu[0]
+
+    # When using GKE with Cloud TPUs, the env variable will be set.
+    if tpu is None and self._inGke():
+      tpu = self._gkeMaster()
+
     self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
     self._job_name = job_name
     self._credentials = credentials
index 6b4a155..b7d56fc 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
 from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
@@ -74,17 +76,16 @@ class TPUClusterResolverTest(test.TestCase):
     """
     self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
     self.assertProtoEquals(
-        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
-    self.assertProtoEquals(
         expected_proto,
-        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
-    self.assertProtoEquals(
-        expected_proto,
-        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
+        server_lib.ClusterSpec(cluster_spec).as_cluster_def())
+    self.assertProtoEquals(expected_proto,
+                           server_lib.ClusterSpec(
+                               cluster_spec.as_cluster_def()).as_cluster_def())
+    self.assertProtoEquals(expected_proto,
+                           server_lib.ClusterSpec(
+                               cluster_spec.as_dict()).as_cluster_def())
 
-  def mock_service_client(
-      self,
-      tpu_map=None):
+  def mock_service_client(self, tpu_map=None):
 
     if tpu_map is None:
       tpu_map = {}
@@ -100,8 +101,7 @@ class TPUClusterResolverTest(test.TestCase):
 
     return mock_client
 
-  @mock.patch.object(TPUClusterResolver,
-                     '_requestComputeMetadata',
+  @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
                      mock_request_compute_metadata)
   def testRetrieveProjectAndZoneFromMetadata(self):
     tpu_map = {
@@ -350,11 +350,24 @@ class TPUClusterResolverTest(test.TestCase):
 
   def testNoCallComputeMetadata(self):
     tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar')
-    self.assertEqual(compat.as_bytes('/bns/foo/bar'),
-                     tpu_cluster_resolver.master())
+    self.assertEqual(
+        compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master())
     self.assertEqual(
         server_lib.ClusterSpec({}), tpu_cluster_resolver.cluster_spec())
 
+  def testGkeEnvironment(self):
+    os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
+    self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ)
+    tpu_cluster_resolver = TPUClusterResolver()
+    self.assertTrue(tpu_cluster_resolver._inGke())
+    self.assertEqual(
+        compat.as_bytes('grpc://10.120.27.5:8470'),
+        tpu_cluster_resolver._gkeMaster())
+    self.assertEqual(
+        compat.as_bytes('grpc://10.120.27.5:8470'),
+        tpu_cluster_resolver.get_master())
+    del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
+
 
 if __name__ == '__main__':
   test.main()