[TPU]: If the $TPU_NAME env var is set, fallback to that.
authorBrennan Saeta <saeta@google.com>
Fri, 11 May 2018 00:32:40 +0000 (17:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 00:35:29 +0000 (17:35 -0700)
PiperOrigin-RevId: 196196939

tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py

index 1403483..8ede286 100644 (file)
@@ -36,6 +36,7 @@ except ImportError:
 
 
 _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
 
 
 class TPUClusterResolver(ClusterResolver):
@@ -70,6 +71,12 @@ class TPUClusterResolver(ClusterResolver):
   def _gkeMaster():
     return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
 
+  @staticmethod
+  def _envVarFallback():
+    if _DEFAULT_ENV_VARIABLE in os.environ:
+      return os.environ[_DEFAULT_ENV_VARIABLE]
+    return None
+
   def __init__(self,
                tpu=None,
                zone=None,
@@ -123,8 +130,11 @@ class TPUClusterResolver(ClusterResolver):
 
     in_gke = self._inGke()
     # When using GKE with Cloud TPUs, the env variable will be set.
-    if tpu is None and in_gke:
-      tpu = self._gkeMaster()
+    if tpu is None:
+      if in_gke:
+        tpu = self._gkeMaster()
+      else:
+        tpu = self._envVarFallback()
 
     self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
     self._job_name = job_name