Tpu cluster resolver only returns TF server addresses for 'HEALTHY' tpu nodes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 3 Feb 2018 00:18:59 +0000 (16:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 3 Feb 2018 00:23:11 +0000 (16:23 -0800)
PiperOrigin-RevId: 184350480

tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py

index 2e75ac226ea74e879edda5e03dff3d53c8a76569..a6a6e642e4e4c721b94821a70d55d6fe931347d6 100644 (file)
@@ -143,7 +143,8 @@ class TPUClusterResolver(ClusterResolver):
       request = self._service.projects().locations().nodes().get(name=full_name)
       response = request.execute()
 
-      instance_url = '%s:%s' % (response['ipAddress'], response['port'])
-      worker_list.append(instance_url)
+      if 'health' in response and response['health'] == 'HEALTHY':
+        instance_url = '%s:%s' % (response['ipAddress'], response['port'])
+        worker_list.append(instance_url)
 
     return ClusterSpec({self._job_name: worker_list})
index 0c4730613af4ad9ca87deb6200ab4bb93d3f6a53..4fd34629cf74f90869c77b8cb098d3c585a49404 100644 (file)
@@ -105,7 +105,8 @@ class TPUClusterResolverTest(test.TestCase):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
             'ipAddress': '10.1.2.3',
-            'port': '8470'
+            'port': '8470',
+            'health': 'HEALTHY'
         }
     }
 
@@ -126,7 +127,8 @@ class TPUClusterResolverTest(test.TestCase):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
             'ipAddress': '10.1.2.3',
-            'port': '8470'
+            'port': '8470',
+            'health': 'HEALTHY'
         }
     }
 
@@ -147,11 +149,13 @@ class TPUClusterResolverTest(test.TestCase):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
             'ipAddress': '10.1.2.3',
-            'port': '8470'
+            'port': '8470',
+            'health': 'HEALTHY'
         },
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
             'ipAddress': '10.4.5.6',
-            'port': '8470'
+            'port': '8470',
+            'health': 'HEALTHY'
         }
     }
 
@@ -169,15 +173,54 @@ class TPUClusterResolverTest(test.TestCase):
     """
     self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
 
+  def testHealthyTpuNodeRetrieval(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470',
+            'health': 'HEALTHY'
+        },
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
+            'ipAddress': '10.4.5.6',
+            'port': '8470',
+        },
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': {
+            'ipAddress': '10.7.8.9',
+            'port': '8470',
+            'health': 'UNHEALTHY'
+        }
+    }
+
+    tpu_cluster_resolver = TPUClusterResolver(
+        project='test-project',
+        zone='us-central1-c',
+        tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'],
+        credentials=None,
+        service=self.mock_service_client(tpu_map=tpu_map))
+
+    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job {
+      name: 'tpu_worker'
+      tasks {
+        key: 0
+        value: '10.1.2.3:8470'
+      }
+    }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
   def testGetMasterMultipleEntries(self):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
             'ipAddress': '10.1.2.3',
-            'port': '8470'
+            'port': '8470',
+            'health': 'HEALTHY'
         },
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
             'ipAddress': '10.4.5.6',
-            'port': '8470'
+            'port': '8470',
+            'health': 'HEALTHY'
         }
     }