Change back TFE_Execute logic to set '*num_retvals' on return.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 13 Mar 2018 18:36:28 +0000 (11:36 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 18:46:31 +0000 (11:46 -0700)
PiperOrigin-RevId: 188903892

tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api.h
tensorflow/c/eager/c_api_test.cc
tensorflow/python/eager/core_test.py

index 56cec2d..0811bd3 100644 (file)
@@ -714,10 +714,7 @@ tensorflow::Status Execute(
       *dev_stats->add_node_stats() = *maybe_stats;
     }
   }
-  if (num_retvals != outputs.size()) {
-    return tensorflow::errors::InvalidArgument(
-        "Expecting ", num_retvals, " outputs but got ", outputs.size());
-  }
+  DCHECK_EQ(num_retvals, outputs.size());
   tensorflow::Device* op_device = IsCPU(device) ? nullptr : device;
   for (int i = 0; i < num_retvals; ++i) {
     tensorflow::Device* d = op_device;
@@ -1154,7 +1151,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
   const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
-  if (output_dtypes.size() != *num_retvals) {
+  const int output_dtypes_size = output_dtypes.size();
+  if (output_dtypes_size > *num_retvals) {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
                  tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
                                              " outputs, but *num_retvals is ",
@@ -1162,6 +1160,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                      .c_str());
     return;
   }
+  *num_retvals = output_dtypes_size;
   if (device == nullptr) {
     // TODO(apassos) debug how the assignment below might return a different
     // device from the one requested above.
index 316006b..a5029bf 100644 (file)
@@ -285,7 +285,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
 //
 // 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and
 // '*num_retvals' should be set to the size of this array. It is an error if
-// the number of outputs is different from *num_retvals.
+// the size of 'retvals' is less than the number of outputs. This call sets
+// *num_retvals to the number of outputs.
 //
 // If async execution is enabled, the call may simply enqueue the execution
 // and return "non-ready" handles in `retvals`. Note that any handles contained
index 927d119..2268aba 100644 (file)
@@ -553,9 +553,10 @@ void Execute_MatMul_CPU(bool async) {
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
-  TFE_TensorHandle* retvals[1] = {nullptr};
-  int num_retvals = 1;
+  TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
+  int num_retvals = 2;
   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(1, num_retvals);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteOp(matmul);
   TFE_DeleteTensorHandle(m);
index 012c68f..61c5526 100644 (file)
@@ -250,16 +250,16 @@ class TFETest(test_util.TensorFlowTestCase):
 
   def testExecuteTooManyNumOutputs(self):
     # num_outputs provided is 50, but only one output is produced.
-    with self.assertRaises(errors.InvalidArgumentError):
-      _ = execute(
-          b'Mul',
-          num_outputs=50,
-          inputs=[constant_op.constant(3),
-                  constant_op.constant(5)],
-          attrs=('T', dtypes.int32.as_datatype_enum))[0]
+    product = execute(
+        b'Mul',
+        num_outputs=50,
+        inputs=[constant_op.constant(3),
+                constant_op.constant(5)],
+        attrs=('T', dtypes.int32.as_datatype_enum))[0]
+    self.assertAllEqual(15, product)
 
   def testExecuteTooFewNumOutputs(self):
-    # num_outputs provided is 50, but only one output is produced.
+    # num_outputs provided is 0, but one output is produced.
     with self.assertRaises(errors.InvalidArgumentError):
       _ = execute(
           b'Mul',