Check to ensure the Cloud TPU is ready before resolving.
authorBrennan Saeta <saeta@google.com>
Wed, 13 Jun 2018 01:09:35 +0000 (18:09 -0700)
committerBrennan Saeta <saeta@google.com>
Wed, 13 Jun 2018 01:09:35 +0000 (18:09 -0700)
Cherry picking this into the TF 1.9 release.

PiperOrigin-RevId: 200095692

Previous commit: 32c8013f0ab3feb139648ae759e2d0168fb5dc95

tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py

index 880fca4..935ad5f 100644 (file)
@@ -255,6 +255,9 @@ class TPUClusterResolver(ClusterResolver):
       request = self._service.projects().locations().nodes().get(name=full_name)
       response = request.execute()
 
+      if 'state' in response and response['state'] != 'READY':
+        raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
+                           (self._tpu, response['state']))
       if 'health' in response and response['health'] != 'HEALTHY':
         raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
                                                             response['health']))
index 5fac55f..7e002cc 100644 (file)
@@ -157,6 +157,50 @@ class TPUClusterResolverTest(test.TestCase):
     job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
     """
     self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+    
+  @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
+                     mock_request_compute_metadata)
+  def testUnhealthyCloudTpu(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470',
+            'health': 'UNHEALTHY'
+        }
+    }
+
+    tpu_cluster_resolver = TPUClusterResolver(
+        project=None,
+        zone=None,
+        tpu='test-tpu-1',
+        coordinator_name=None,
+        credentials=None,
+        service=self.mock_service_client(tpu_map=tpu_map))
+
+    with self.assertRaises(RuntimeError):
+      tpu_cluster_resolver.cluster_spec()
+
+  @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
+                     mock_request_compute_metadata)
+  def testNotReadyCloudTpu(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470',
+            'state': 'CREATING'
+        }
+    }
+
+    tpu_cluster_resolver = TPUClusterResolver(
+        project=None,
+        zone=None,
+        tpu='test-tpu-1',
+        coordinator_name=None,
+        credentials=None,
+        service=self.mock_service_client(tpu_map=tpu_map))
+
+    with self.assertRaises(RuntimeError):
+      tpu_cluster_resolver.cluster_spec()
 
   def testSimpleSuccessfulRetrieval(self):
     tpu_map = {