Fixed the bug that predict input_fn requires the labels.
authorJianwei Xie <xiejw@google.com>
Tue, 6 Mar 2018 18:31:07 +0000 (10:31 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 18:38:51 +0000 (10:38 -0800)
PiperOrigin-RevId: 188042708

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 1b2eda1..a7991eb 100644 (file)
@@ -2308,6 +2308,11 @@ class _InputsWithStoppingSignals(_Inputs):
     """
 
     def _map_fn(*args):
+      """The map fn to insert signals."""
+      if len(args) == 1:
+        # Unpack the single Tensor/dict argument as features. This is required
+        # for the input_fn returns no labels.
+        args = args[0]
       features, labels = _Inputs._parse_inputs(args)
       new_input_dict = {}
       new_input_dict['features'] = features