logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1',
- 'predict/head1', 'head2', 'classification/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
+ 'head1/predict', 'head2', 'head2/classification', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.
self.assertAllClose(
expected_probabilities['head1'],
sess.run(
- spec.export_outputs['predict/head1'].outputs['probabilities']))
+ spec.export_outputs['head1/predict'].outputs['probabilities']))
self.assertAllClose(
expected_probabilities['head2'],
sess.run(
- spec.export_outputs['predict/head2'].outputs['probabilities']))
+ spec.export_outputs['head2/predict'].outputs['probabilities']))
def test_predict_two_heads_logits_tensor(self):
"""Tests predict with logits as Tensor."""
logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1',
- 'predict/head1', 'head2', 'classification/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
+ 'head1/predict', 'head2', 'head2/classification', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.
logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'regression/head1',
- 'predict/head1', 'head2', 'regression/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression',
+ 'head1/predict', 'head2', 'head2/regression', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.