from __future__ import division
from __future__ import print_function
+import os
from six.moves.urllib.request import Request
from six.moves.urllib.request import urlopen
_GOOGLE_API_CLIENT_INSTALLED = False
+_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+
+
class TPUClusterResolver(ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
return False
return True
+  def _inGke(self):
+    """Reports whether this process appears to be running inside GKE.
+
+    When running in GKE, the `KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS` environment
+    variable will be set.
+
+    Returns:
+      True if that variable is present in `os.environ`, False otherwise.
+    """
+    return os.environ.get(_GKE_ENV_VARIABLE) is not None
+
+  def _gkeMaster(self):
+    """Returns the first endpoint listed in the GKE TPU environment variable.
+
+    The variable's value is treated as a comma-separated list; only the text
+    before the first comma is returned.
+
+    Raises:
+      KeyError: If the environment variable is not set.
+    """
+    endpoints = os.environ[_GKE_ENV_VARIABLE]
+    first_endpoint, _, _ = endpoints.partition(',')
+    return first_endpoint
+
def __init__(self,
- tpu,
+ tpu=None,
zone=None,
project=None,
job_name='worker',
raise NotImplementedError(
'Using multiple TPUs in a single session is not yet implemented')
tpu = tpu[0]
+
+ # When using GKE with Cloud TPUs, the env variable will be set.
+ if tpu is None and self._inGke():
+ tpu = self._gkeMaster()
+
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
self._job_name = job_name
self._credentials = credentials
from __future__ import division
from __future__ import print_function
+import os
+
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
"""
self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
self.assertProtoEquals(
- expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
- self.assertProtoEquals(
expected_proto,
- server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
- self.assertProtoEquals(
- expected_proto,
- server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
+ server_lib.ClusterSpec(cluster_spec).as_cluster_def())
+ self.assertProtoEquals(expected_proto,
+ server_lib.ClusterSpec(
+ cluster_spec.as_cluster_def()).as_cluster_def())
+ self.assertProtoEquals(expected_proto,
+ server_lib.ClusterSpec(
+ cluster_spec.as_dict()).as_cluster_def())
- def mock_service_client(
- self,
- tpu_map=None):
+ def mock_service_client(self, tpu_map=None):
if tpu_map is None:
tpu_map = {}
return mock_client
- @mock.patch.object(TPUClusterResolver,
- '_requestComputeMetadata',
+ @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadata(self):
tpu_map = {
def testNoCallComputeMetadata(self):
tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar')
- self.assertEqual(compat.as_bytes('/bns/foo/bar'),
- tpu_cluster_resolver.master())
+ self.assertEqual(
+ compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master())
self.assertEqual(
server_lib.ClusterSpec({}), tpu_cluster_resolver.cluster_spec())
+  def testGkeEnvironment(self):
+    """Verifies the resolver reads its master from the GKE env variable."""
+    endpoint = 'grpc://10.120.27.5:8470'
+    os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = endpoint
+    try:
+      self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
+      tpu_cluster_resolver = TPUClusterResolver()
+      self.assertTrue(tpu_cluster_resolver._inGke())
+      # _gkeMaster() hands back the raw environment string, so compare it
+      # with a native str: bytes and str never compare equal on Python 3.
+      self.assertEqual(endpoint, tpu_cluster_resolver._gkeMaster())
+      self.assertEqual(
+          compat.as_bytes(endpoint), tpu_cluster_resolver.get_master())
+    finally:
+      # Always remove the variable so that a failing assertion cannot leak
+      # the GKE setting into other tests running in the same process.
+      os.environ.pop('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', None)
+
if __name__ == '__main__':
test.main()