Renames exported signature names in MultiHead so head_name comes first.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Apr 2018 19:37:33 +0000 (12:37 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 19:39:32 +0000 (12:39 -0700)
PiperOrigin-RevId: 192168628

tensorflow/contrib/estimator/python/estimator/multi_head.py
tensorflow/contrib/estimator/python/estimator/multi_head_test.py

index bbbc19c..ce75899 100644 (file)
@@ -345,7 +345,7 @@ class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
         if k == _DEFAULT_SERVING_KEY:
           key = head_name
         else:
-          key = '%s/%s' % (k, head_name)
+          key = '%s/%s' % (head_name, k)
         export_outputs[key] = v
         if (k == head_lib._PREDICT_SERVING_KEY and  # pylint:disable=protected-access
             isinstance(v, export_output_lib.PredictOutput)):
index d9e5aca..3d6fccb 100644 (file)
@@ -127,8 +127,8 @@ class MultiHeadTest(test.TestCase):
         logits=logits)
 
     self.assertItemsEqual(
-        (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1',
-         'predict/head1', 'head2', 'classification/head2', 'predict/head2'),
+        (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
+         'head1/predict', 'head2', 'head2/classification', 'head2/predict'),
         spec.export_outputs.keys())
 
     # Assert predictions and export_outputs.
@@ -169,11 +169,11 @@ class MultiHeadTest(test.TestCase):
       self.assertAllClose(
           expected_probabilities['head1'],
           sess.run(
-              spec.export_outputs['predict/head1'].outputs['probabilities']))
+              spec.export_outputs['head1/predict'].outputs['probabilities']))
       self.assertAllClose(
           expected_probabilities['head2'],
           sess.run(
-              spec.export_outputs['predict/head2'].outputs['probabilities']))
+              spec.export_outputs['head2/predict'].outputs['probabilities']))
 
   def test_predict_two_heads_logits_tensor(self):
     """Tests predict with logits as Tensor."""
@@ -197,8 +197,8 @@ class MultiHeadTest(test.TestCase):
         logits=logits)
 
     self.assertItemsEqual(
-        (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1',
-         'predict/head1', 'head2', 'classification/head2', 'predict/head2'),
+        (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
+         'head1/predict', 'head2', 'head2/classification', 'head2/predict'),
         spec.export_outputs.keys())
 
     # Assert predictions and export_outputs.
@@ -254,8 +254,8 @@ class MultiHeadTest(test.TestCase):
         logits=logits)
 
     self.assertItemsEqual(
-        (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'regression/head1',
-         'predict/head1', 'head2', 'regression/head2', 'predict/head2'),
+        (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression',
+         'head1/predict', 'head2', 'head2/regression', 'head2/predict'),
         spec.export_outputs.keys())
 
     # Assert predictions and export_outputs.