*dev_stats->add_node_stats() = *maybe_stats;
}
}
- if (num_retvals != outputs.size()) {
- return tensorflow::errors::InvalidArgument(
- "Expecting ", num_retvals, " outputs but got ", outputs.size());
- }
+ DCHECK_EQ(num_retvals, outputs.size());
tensorflow::Device* op_device = IsCPU(device) ? nullptr : device;
for (int i = 0; i < num_retvals; ++i) {
tensorflow::Device* d = op_device;
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
- if (output_dtypes.size() != *num_retvals) {
+ const int output_dtypes_size = output_dtypes.size();
+ if (output_dtypes_size > *num_retvals) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
" outputs, but *num_retvals is ",
.c_str());
return;
}
+ *num_retvals = output_dtypes_size;
if (device == nullptr) {
// TODO(apassos) debug how the assignment below might return a different
// device from the one requested above.
//
// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and
// '*num_retvals' should be set to the size of this array. It is an error if
-// the number of outputs is different from *num_retvals.
+// the size of 'retvals' is less than the number of outputs. This call sets
+// *num_retvals to the number of outputs.
//
// If async execution is enabled, the call may simply enqueue the execution
// and return "non-ready" handles in `retvals`. Note that any handles contained
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
- TFE_TensorHandle* retvals[1] = {nullptr};
- int num_retvals = 1;
+ TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
+ int num_retvals = 2;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
def testExecuteTooManyNumOutputs(self):
# num_outputs provided is 50, but only one output is produced.
- with self.assertRaises(errors.InvalidArgumentError):
- _ = execute(
- b'Mul',
- num_outputs=50,
- inputs=[constant_op.constant(3),
- constant_op.constant(5)],
- attrs=('T', dtypes.int32.as_datatype_enum))[0]
+ product = execute(
+ b'Mul',
+ num_outputs=50,
+ inputs=[constant_op.constant(3),
+ constant_op.constant(5)],
+ attrs=('T', dtypes.int32.as_datatype_enum))[0]
+ self.assertAllEqual(15, product)
def testExecuteTooFewNumOutputs(self):
- # num_outputs provided is 50, but only one output is produced.
+ # num_outputs provided is 0, but one output is produced.
with self.assertRaises(errors.InvalidArgumentError):
_ = execute(
b'Mul',