_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
class TPUClusterResolver(ClusterResolver):
def _gkeMaster():
return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
+ @staticmethod
+ def _envVarFallback():
+ if _DEFAULT_ENV_VARIABLE in os.environ:
+ return os.environ[_DEFAULT_ENV_VARIABLE]
+ return None
+
def __init__(self,
tpu=None,
zone=None,
in_gke = self._inGke()
# When using GKE with Cloud TPUs, the env variable will be set.
- if tpu is None and in_gke:
- tpu = self._gkeMaster()
+ if tpu is None:
+ if in_gke:
+ tpu = self._gkeMaster()
+ else:
+ tpu = self._envVarFallback()
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
self._job_name = job_name