From 315369aacd002d8c668b86a52f3cd88956a9b9a2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 12 Mar 2018 12:44:29 -0700 Subject: [PATCH] Extend TF Eager C API to allow asynchronous execution. PiperOrigin-RevId: 188763442 --- tensorflow/c/eager/BUILD | 1 + tensorflow/c/eager/c_api.cc | 824 +++++++++++++++++++++++------- tensorflow/c/eager/c_api.h | 58 ++- tensorflow/c/eager/c_api_internal.h | 206 +++++++- tensorflow/c/eager/c_api_test.cc | 380 +++++++++++--- tensorflow/c/eager/runtime.h | 3 +- tensorflow/python/eager/core_test.py | 24 +- tensorflow/python/eager/pywrap_tensor.cc | 6 +- tensorflow/python/eager/pywrap_tfe_src.cc | 9 +- tensorflow/python/lib/core/py_func.cc | 16 +- 10 files changed, 1222 insertions(+), 305 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index e55cb67..3046d90 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -58,6 +58,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index b9a47ea..56cec2d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" @@ -67,6 +68,7 @@ string DeviceName(const tensorflow::Device* d) { #ifdef TENSORFLOW_EAGER_USE_XLA std::atomic_int_fast64_t func_id_generator(0); #endif // TENSORFLOW_EAGER_USE_XLA + } // namespace TFE_ContextDevicePlacementPolicy PlacementPolicy( @@ -90,11 +92,33 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, TF_SetConfig(&options->session_options, proto, proto_len, status); } +void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, + unsigned char async) { + options->async = async; +} void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { options->policy = policy; } +TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, + unsigned char async, + TF_Status* status) { + { + tensorflow::mutex_lock l(ctx->async_map_mu); + ctx->thread_local_async[std::this_thread::get_id()] = async; + } + if (async) { + ctx->executor.EnableAsync(); + } else { + // TODO(agarwal): Currently we add a wait here to handle cases where a sync + // op has a control dependency on an async op, and the latter has not + // executed yet. This wait can be removed by storing all the control inputs + // and waiting for them when executing ops. + status->status = ctx->executor.WaitForAllPendingNodes(); + } +} + void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { @@ -113,7 +137,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { } void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { - status->status = tensorflow::Status::OK(); + status->status = ctx->executor.WaitForAllPendingNodes(); { tensorflow::mutex_lock ml(ctx->cache_mu); tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); @@ -139,6 +163,9 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( ctx->thread_local_policies[std::this_thread::get_id()] = policy; } +// Note: this function looks up a thread local policy. So it should be called in +// the appropriate client thread. In particular, in async mode, it may not be +// safe to call this function from the async TFE_Executor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { tensorflow::mutex_lock ml(ctx->policy_map_mu); @@ -150,6 +177,18 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( return ctx->policy; } +void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->executor.WaitForAllPendingNodes(); +} + +void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->executor.status(); +} + +void TFE_ContextAsyncClearError(TFE_Context* ctx) { + ctx->executor.ClearError(); +} + TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); @@ -157,56 +196,70 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { return new TFE_TensorHandle(tensor, nullptr, nullptr); } -void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; } +void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { + DCHECK(h); + h->Unref(); +} TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast(h->t.dtype()); + return static_cast(h->dtype); } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { - status->status = tensorflow::Status::OK(); - return h->t.dims(); + const tensorflow::Tensor* t = nullptr; + status->status = h->Tensor(&t); + return t == nullptr ? 0 : t->dims(); } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { - status->status = tensorflow::Status::OK(); - return h->t.dim_size(dim_index); + const tensorflow::Tensor* t = nullptr; + status->status = h->Tensor(&t); + return t == nullptr ? 0 : t->dim_size(dim_index); } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { - status->status = tensorflow::Status::OK(); - return (h->op_device == nullptr) - ? "/job:localhost/replica:0/task:0/device:CPU:0" - : h->op_device->name().c_str(); + tensorflow::Device* d = nullptr; + status->status = h->OpDevice(&d); + return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : d->name().c_str(); } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { - if (!IsCPU(h->d)) { + // TODO(agarwal): move this implementation inside TFE_TensorHandle. + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + if (!IsCPU(d)) { TF_SetStatus(status, TF_UNIMPLEMENTED, tensorflow::strings::StrCat( "TFE_TensorHandle can be resolved iff it is on CPU (this " "handle is on ", - h->d->name(), + d->name(), "). Consider using TFE_TensorHandleCopyToDevice to get a " "copy of the tensor on CPU") .c_str()); return nullptr; } - return tensorflow::TF_TensorFromTensor(h->t, status); + return tensorflow::TF_TensorFromTensor(*t, status); } +} // extern "C" -TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - const char* device_name, - TF_Status* status) { - tensorflow::Device* dstd = ctx->devices[0]; - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = ctx->device_manager->LookupDevice(device_name, &dstd); - if (!status->status.ok()) return nullptr; - } +namespace { - tensorflow::Device* srcd = h->d == nullptr ? ctx->devices[0] : h->d; +tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, + TFE_Context* ctx, + tensorflow::Device* dstd, + TFE_TensorHandle** output) { + const tensorflow::Tensor* src = nullptr; + tensorflow::Device* srcd = nullptr; + // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept + // nullptr. + tensorflow::Device* src_opd = nullptr; + TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd)); + if (srcd == nullptr) srcd = ctx->devices[0]; bool is_same_device = (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); const bool dst_cpu = IsCPU(dstd); @@ -216,18 +269,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, const bool both_on_cpu = src_cpu && dst_cpu; if (is_same_device || both_on_cpu) { dstd = dst_cpu ? nullptr : dstd; - return new TFE_TensorHandle(h->t, dstd, dstd); + *output = new TFE_TensorHandle(*src, dstd, dstd); + return tensorflow::Status::OK(); } - tensorflow::Tensor* src = &(h->t); if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat("Can't copy Tensor with type ", - tensorflow::DataTypeString(src->dtype()), - " to device ", DeviceName(dstd), ".") - .c_str()); - return nullptr; + return tensorflow::errors::InvalidArgument( + "Can't copy Tensor with type ", + tensorflow::DataTypeString(src->dtype()), " to device ", + DeviceName(dstd), "."); } tensorflow::AllocatorAttributes attr; if (src->dtype() == tensorflow::DT_VARIANT) { @@ -236,7 +286,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); if (src->shape().num_elements() == 0) { dstd = dst_cpu ? nullptr : dstd; - return new TFE_TensorHandle(dst, dstd, dstd); + *output = new TFE_TensorHandle(dst, dstd, dstd); + return tensorflow::Status::OK(); } tensorflow::DeviceContext* src_device_context = nullptr; if (!src_cpu) { @@ -253,21 +304,26 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, // With that setup, Sync()ing across all 3 streams should be sufficient // but more than necessary (since it waits for operations that might have // nothing to do with this tensor to complete). - status->status = srcd->Sync(); + TF_RETURN_IF_ERROR(srcd->Sync()); tensorflow::Notification n; + tensorflow::Status status; tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, srcd, dstd, tensorflow::AllocatorAttributes(), tensorflow::AllocatorAttributes(), src, &dst, - [status, &n](const tensorflow::Status& s) { - status->status = s; + [&status, &n](const tensorflow::Status& s) { + status = s; n.Notify(); }); n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) - ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd, - dst_cpu ? nullptr : dstd) - : nullptr; + if (status.ok()) { + dstd = dst_cpu ? nullptr : dstd; + *output = new TFE_TensorHandle(dst, dstd, dstd); + } + return status; } +} // namespace + +extern "C" { TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { @@ -311,16 +367,19 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - // Questionable heuristic ... - // - If a device was explicitly set on the op, always use that. - // - If not, place on the first non-host device seen. - if (op->device == nullptr && !IsCPU(h->d)) { - op->device = h->d; + if (op->device == nullptr) { + // Questionable heuristic ... + // - If a device was explicitly set on the op, always use that. + // - If not, place on the first non-host device seen. + tensorflow::Device* d = nullptr; + // TODO(agarwal): This call may block if h is not ready. Avoid this if + // possible. + status->status = h->Device(&d); + if (!status->status.ok()) return; + if (!IsCPU(d)) op->device = d; } - if (!status->status.ok()) return; - op->inputs.push_back(h->t); - op->input_devices.push_back(h->d); - op->input_op_devices.push_back(h->op_device); + h->Ref(); + op->inputs.push_back(h); op->attrs.NumInputs(op->inputs.size()); } @@ -482,14 +541,14 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, tensorflow::gtl::ArraySlice( funcs.get(), num_values)); } +} // extern "C" namespace { tensorflow::Status ValidateInputTypeAndPlacement( TFE_Context* ctx, tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel, - std::vector* copied_tensors) { + const tensorflow::OpKernel* kernel) { const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); if (memtypes.size() != op->inputs.size()) { return tensorflow::errors::InvalidArgument( @@ -498,14 +557,17 @@ tensorflow::Status ValidateInputTypeAndPlacement( for (int i = 0; i < op->inputs.size(); ++i) { const tensorflow::Device* expected_device = memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; + TFE_TensorHandle* handle = op->inputs[i]; + tensorflow::Device* handle_device = nullptr; + TF_RETURN_IF_ERROR(handle->Device(&handle_device)); const tensorflow::Device* actual_device = - op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; + handle_device == nullptr ? host_device : handle_device; if (expected_device != actual_device) { switch (TFE_ContextGetDevicePlacementPolicy(ctx)) { case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32: // TODO(xpan): See if we could bubble python related error up // to python level. - if (op->inputs[i].dtype() == tensorflow::DT_INT32) { + if (handle->dtype == tensorflow::DT_INT32) { // Note: enabling silent copies of int32 tensors to match behavior // of graph mode. break; @@ -536,36 +598,245 @@ tensorflow::Status ValidateInputTypeAndPlacement( } // We are only here if the policy is warn or silent copies, so we should // trigger a copy. - TFE_TensorHandle original{op->inputs[i], op->input_devices[i], - op->device}; TF_Status* s = TF_NewStatus(); TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( - &original, ctx, expected_device->name().c_str(), s); - if (!s->status.ok()) { - tensorflow::Status status = s->status; - delete s; + handle, ctx, expected_device->name().c_str(), s); + tensorflow::Status status = s->status; + TF_DeleteStatus(s); + if (!status.ok()) { + if (copied_tensor != nullptr) copied_tensor->Unref(); return tensorflow::errors::Internal( "Failed copying input tensor from ", actual_device->name(), " to ", expected_device->name(), " in order to run ", op->name, ": ", status.error_message()); } - op->inputs[i] = copied_tensor->t; - copied_tensors->push_back(copied_tensor); - op->input_devices[i] = copied_tensor->d; - delete s; + handle->Unref(); + handle = copied_tensor; + op->inputs[i] = copied_tensor; } - if (op->inputs[i].dtype() != kernel->input_type(i)) { + if (handle->dtype != kernel->input_type(i)) { return tensorflow::errors::InvalidArgument( "cannot compute ", op->name, " as input #", i, " was expected to be a ", tensorflow::DataTypeString(kernel->input_type(i)), - " tensor but is a ", - tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor"); + " tensor but is a ", tensorflow::DataTypeString(handle->dtype), + " tensor"); } } return tensorflow::Status::OK(); } +tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, + TFE_Context* ctx, TF_Status* status) { + tensorflow::DeviceSet ds; + for (tensorflow::Device* d : ctx->devices) { + ds.AddDevice(d); + } + tensorflow::DeviceTypeVector final_devices; + status->status = tensorflow::SupportedDeviceTypesForNode( + ds.PrioritizedDeviceTypeList(), ndef, &final_devices); + if (!status->status.ok()) { + return nullptr; + } + if (final_devices.empty()) { + status->status = tensorflow::errors::Internal( + "Could not find valid device for node ", ndef.DebugString()); + return nullptr; + } + for (tensorflow::Device* d : ctx->devices) { + if (d->device_type() == final_devices[0].type_string()) { + return d; + } + } + status->status = tensorflow::errors::Unknown( + "Could not find a device for node ", ndef.DebugString()); + return nullptr; +} + +tensorflow::Status Execute( + TFE_Context* ctx, tensorflow::Device* device, + const tensorflow::gtl::InlinedVector& op_inputs, + tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats, + TFE_TensorHandle** retvals, int num_retvals) { + if (!ctx->soft_placement && device == nullptr) { + // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU + device = ctx->devices[0]; + } + + if (device == nullptr) { + // TODO(apassos) debug how the assignment below might return a different + // device from the one requested above. + device = kernel->device(); + } + + std::vector outputs(1); + const tensorflow::MemoryTypeVector* output_memory_types = nullptr; + output_memory_types = &kernel->kernel()->output_memory_types(); + std::vector inputs(op_inputs.size()); + for (int i = 0; i < op_inputs.size(); ++i) { + const tensorflow::Tensor* input_tensor = nullptr; + TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor)); + inputs[i] = *input_tensor; + } + // WARNING: kernel->Run utilizes the FunctionLibraryRuntime + // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, + // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation + // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by + // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. + // This is quite subtle. Re-work things to make this better? (Would it make + // sense for FunctionLibraryRuntime to ensure thread-safe access to + // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats + // for ops which are a part of functions. + // TODO(agarwal): change Run to take vector of handles ? + TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats)); + if (maybe_stats != nullptr) { + maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - + maybe_stats->all_start_micros()); + tensorflow::mutex_lock ml(ctx->metadata_mu); + if (ctx->should_store_metadata.load()) { + auto* step_stats = ctx->run_metadata.mutable_step_stats(); + // Lazily initialize the RunMetadata with information about all devices if + // this is the first call. + while (step_stats->dev_stats_size() < ctx->devices.size()) { + step_stats->add_dev_stats(); + } + // Find the current device's index. + int device_idx = 0; + for (int i = 0; i < ctx->devices.size(); ++i) { + if (ctx->devices[i] == device) { + device_idx = i; + break; + } + } + // Populate the device stats for this device. + auto* dev_stats = step_stats->mutable_dev_stats(device_idx); + dev_stats->set_device(device->name()); + *dev_stats->add_node_stats() = *maybe_stats; + } + } + if (num_retvals != outputs.size()) { + return tensorflow::errors::InvalidArgument( + "Expecting ", num_retvals, " outputs but got ", outputs.size()); + } + tensorflow::Device* op_device = IsCPU(device) ? nullptr : device; + for (int i = 0; i < num_retvals; ++i) { + tensorflow::Device* d = op_device; + if (d != nullptr && output_memory_types != nullptr && + (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { + d = nullptr; + } + if (retvals[i] == nullptr) { + retvals[i] = new TFE_TensorHandle(outputs[i], d, op_device); + } else { + retvals[i]->SetTensorAndDevice(outputs[i], d, op_device); + } + } + return tensorflow::Status::OK(); +} + +// TODO(agarwal): move TFE_Executor and TFE_Node related code to a separate +// file. +class ExecuteNode : public TFE_Node { + public: + ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel, + tensorflow::NodeExecStats* maybe_stats, + const tensorflow::DataTypeVector& output_dtypes, + TFE_TensorHandle** retvals, int num_retvals) + : TFE_Node(op->ctx->executor.NextId()), + ctx_(op->ctx), + op_device_(op->device), + inputs_(op->inputs), + kernel_(kernel), + maybe_stats_(maybe_stats), + retvals_(num_retvals) { + for (auto handle : inputs_) { + handle->Ref(); + } + TFE_Context* ctx = op->ctx; + for (int i = 0; i < num_retvals; ++i) { + TFE_TensorHandle* h = new TFE_TensorHandle(id, output_dtypes[i], ctx); + h->Ref(); + retvals[i] = h; + retvals_[i] = h; + } + } + + ~ExecuteNode() override { + for (auto handle : inputs_) { + handle->Unref(); + } + for (auto handle : retvals_) { + handle->Unref(); + } + } + + tensorflow::Status Run() override { + const tensorflow::Status status = + Execute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(), + retvals_.begin(), retvals_.size()); + if (status.ok()) { + return status; + } else { + return tensorflow::Status( + status.code(), + tensorflow::strings::StrCat("Got error, \"", status.error_message(), + "\" while executing kernel ", + kernel_->kernel()->def().DebugString())); + } + } + + private: + TFE_Context* ctx_; + tensorflow::Device* op_device_; + tensorflow::gtl::InlinedVector inputs_; + tensorflow::KernelAndDevice* kernel_; + std::unique_ptr maybe_stats_; + tensorflow::gtl::InlinedVector retvals_; +}; + +class CopyToDeviceNode : public TFE_Node { + public: + CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd, + TFE_Context* ctx) + : TFE_Node(ctx->executor.NextId()), + src_(src), + dstd_(dstd), + ctx_(ctx), + dst_(new TFE_TensorHandle(id, src_->dtype, ctx)) { + src_->Ref(); + dst_->Ref(); + } + + ~CopyToDeviceNode() override { + src_->Unref(); + dst_->Unref(); + } + + tensorflow::Status Run() override { + TFE_TensorHandle* temp = nullptr; + TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp)); + const tensorflow::Tensor* tensor = nullptr; + tensorflow::Device* device = nullptr; + tensorflow::Device* op_device = nullptr; + tensorflow::Status status = + temp->TensorAndDevice(&tensor, &device, &op_device); + // `temp` is a ready handle. So the following call should return OK. + TF_DCHECK_OK(status) << status.error_message(); + DCHECK(tensor); + dst_->SetTensorAndDevice(*tensor, device, op_device); + temp->Unref(); + return tensorflow::Status::OK(); + } + + TFE_TensorHandle* dst() { return dst_; } + + private: + TFE_TensorHandle* src_; + tensorflow::Device* dstd_; + TFE_Context* ctx_; + TFE_TensorHandle* dst_; +}; + #ifdef TENSORFLOW_EAGER_USE_XLA // Synthesizes and returns a wrapper function over `op`, which must be a // primitive op (e.g. matmul). @@ -631,7 +902,7 @@ const tensorflow::FunctionDef* OpToFunction( (*op_input_to_func_input)[i] = const_index; func_input_arg = signature->mutable_input_arg(const_index++); const_input_types->push_back( - static_cast(op->inputs[i].dtype())); + static_cast(op->inputs[i]->dtype)); } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { VLOG(1) << "For resource input, mapping op input " << i << " to func input " << resource_index; @@ -643,11 +914,11 @@ const tensorflow::FunctionDef* OpToFunction( (*op_input_to_func_input)[i] = arg_index; func_input_arg = signature->mutable_input_arg(arg_index++); arg_input_types->push_back( - static_cast(op->inputs[i].dtype())); + static_cast(op->inputs[i]->dtype)); } func_input_arg->set_name(op_input_arg.name()); - func_input_arg->set_type(op->inputs[i].dtype()); + func_input_arg->set_type(op->inputs[i]->dtype); } VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); @@ -740,22 +1011,16 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { // Since input param reordering may have occurred between `op` and `launch_op` // via `op_input_to_func_input`, adjust the actual inputs accordingly. launch_op->inputs = op->inputs; - launch_op->input_devices = op->input_devices; - launch_op->input_op_devices = op->input_op_devices; + for (TFE_TensorHandle* h : launch_op->inputs) { + h->Ref(); + } if (!op_input_to_func_input.empty()) { DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); - if (!op->input_devices.empty()) { - DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size()); - } for (int i = 0; i < op_input_to_func_input.size(); ++i) { VLOG(1) << "mapping op input " << i << " to func input " << op_input_to_func_input[i]; launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; - if (!op->input_devices.empty()) { - launch_op->input_devices[op_input_to_func_input[i]] = - op->input_devices[i]; - } } } launch_op->attrs.NumInputs(op->inputs.size()); @@ -789,37 +1054,17 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { } #endif // TENSORFLOW_EAGER_USE_XLA -tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, - TFE_Context* ctx, TF_Status* status) { - tensorflow::DeviceSet ds; - for (tensorflow::Device* d : ctx->devices) { - ds.AddDevice(d); - } - tensorflow::DeviceTypeVector final_devices; - status->status = tensorflow::SupportedDeviceTypesForNode( - ds.PrioritizedDeviceTypeList(), ndef, &final_devices); - if (!status->status.ok()) { - return nullptr; - } - if (final_devices.empty()) { - status->status = tensorflow::errors::Internal( - "Could not find valid device for node ", ndef.DebugString()); - return nullptr; - } - for (tensorflow::Device* d : ctx->devices) { - if (d->device_type() == final_devices[0].type_string()) { - return d; - } - } - status->status = tensorflow::errors::Unknown( - "Could not find a device for node ", ndef.DebugString()); - return nullptr; -} - } // namespace +extern "C" { + void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { + TFE_Context* ctx = op->ctx; + status->status = ctx->executor.status(); + if (!status->status.ok()) { + return; + } #ifdef TENSORFLOW_EAGER_USE_XLA std::unique_ptr xla_launch_op; if (op->use_xla && op->name != "_XlaLaunch") { @@ -830,31 +1075,29 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, op = xla_launch_op.get(); } #endif // TENSORFLOW_EAGER_USE_XLA - TFE_Context* ctx = op->ctx; - tensorflow::Device* device = op->device; // Ensure all resource-touching ops run in the device the resource is, // regardless of anything else that has been specified. This is identical to // the graph mode behavior. for (int i = 0; i < op->inputs.size(); ++i) { - if (op->inputs[i].dtype() == tensorflow::DT_RESOURCE && - op->input_op_devices[i] != device) { - tensorflow::Device* d = op->input_op_devices[i] == nullptr - ? ctx->devices[0] - : op->input_op_devices[i]; + tensorflow::Device* input_op_device = nullptr; + status->status = op->inputs[i]->OpDevice(&input_op_device); + if (!status->status.ok()) return; + if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE && + input_op_device != op->device) { + tensorflow::Device* d = + input_op_device == nullptr ? ctx->devices[0] : input_op_device; VLOG(1) << "Changing device of operation " << op->name << " to " << d->name() << " because input #" << i << " is a resource in this device."; - device = d; op->device = d; } } + tensorflow::Device* device = op->device; if (!ctx->soft_placement && device == nullptr) { // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU device = ctx->devices[0]; } - std::vector outputs(1); - const tensorflow::MemoryTypeVector* output_memory_types = nullptr; tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name()); tensorflow::KernelAndDevice* kernel; @@ -879,8 +1122,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // Knowledge of the implementation of Init (and in-turn // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def // will be accessed, so grab on to the lock. - // See WARNING comment below - would be nice to rework to avoid this - // subtlety. + // See WARNING comment in Execute (before kernel->Run) - would be nice to + // rework to avoid this subtlety. tensorflow::tf_shared_lock l(ctx->functions_mu); status->status = tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); @@ -903,29 +1146,30 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } tensorflow::DataTypeVector input_dtypes; status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes, - kernel->output_dtypes()); + kernel->mutable_output_dtypes()); if (!status->status.ok()) { return; } tensorflow::mutex_lock ml(ctx->cache_mu); tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } + const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); + if (output_dtypes.size() != *num_retvals) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + tensorflow::strings::StrCat("Expecting ", output_dtypes.size(), + " outputs, but *num_retvals is ", + *num_retvals) + .c_str()); + return; + } if (device == nullptr) { // TODO(apassos) debug how the assignment below might return a different // device from the one requested above. device = kernel->device(); } - - std::vector copied_tensors; - status->status = ValidateInputTypeAndPlacement( - ctx, ctx->devices[0], device, op, kernel->kernel(), &copied_tensors); - output_memory_types = &kernel->kernel()->output_memory_types(); - if (!status->status.ok()) { - for (auto* t : copied_tensors) { - TFE_DeleteTensorHandle(t); - } - return; - } + status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device, + op, kernel->kernel()); + if (!status->status.ok()) return; std::unique_ptr maybe_stats; if (ctx->should_store_metadata.load()) { maybe_stats.reset(new tensorflow::NodeExecStats); @@ -935,53 +1179,47 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); // TODO(apassos) track referenced tensors } - // WARNING: kernel->Run utilizes the FunctionLibraryRuntime - // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, - // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation - // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by - // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. - // This is quite subtle. Re-work things to make this better? (Would it make - // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats - // for ops which are a part of functions. - status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get()); - for (auto* t : copied_tensors) { - TFE_DeleteTensorHandle(t); - } - if (!status->status.ok()) return; - if (maybe_stats != nullptr) { - maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - - maybe_stats->all_start_micros()); - tensorflow::mutex_lock ml(ctx->metadata_mu); - if (ctx->should_store_metadata.load()) { - auto* step_stats = ctx->run_metadata.mutable_step_stats(); - // Lazily initialize the RunMetadata with information about all devices if - // this is the first call. - while (step_stats->dev_stats_size() < ctx->devices.size()) { - step_stats->add_dev_stats(); - } - // Find the current device's index. - int device_idx = 0; - for (int i = 0; i < ctx->devices.size(); ++i) { - if (ctx->devices[i] == device) { - device_idx = i; - break; - } - } - // Populate the device stats for this device. - auto* dev_stats = step_stats->mutable_dev_stats(device_idx); - dev_stats->set_device(device->name()); - *dev_stats->add_node_stats() = *maybe_stats; + if (ctx->Async()) { + // Note that for async mode, execution order will make sure that all + // input handles are ready before executing them. + // TODO(agarwal): Consider executing "cheap" kernels inline for performance. + TFE_Node* node = new ExecuteNode(op, kernel, maybe_stats.release(), + output_dtypes, retvals, *num_retvals); + ctx->executor.Add(node); + } else { + // Execute checks if retvals[i] is nullptr or not to figure if it needs to + // allocate it. + for (int i = 0; i < *num_retvals; ++i) { + retvals[i] = nullptr; } + status->status = Execute(op->ctx, op->device, op->inputs, kernel, + maybe_stats.get(), retvals, *num_retvals); } - *num_retvals = std::min(*num_retvals, outputs.size()); - for (int i = 0; i < *num_retvals; ++i) { - tensorflow::Device* d = IsCPU(device) ? nullptr : device; - if (d != nullptr && output_memory_types != nullptr && - (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { - d = nullptr; - } - retvals[i] = new TFE_TensorHandle(outputs[i], d, device); +} + +TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, + TFE_Context* ctx, + const char* device_name, + TF_Status* status) { + status->status = ctx->executor.status(); + if (!status->status.ok()) { + return nullptr; + } + tensorflow::Device* dstd = ctx->devices[0]; + if (device_name != nullptr && strlen(device_name) > 0) { + status->status = ctx->device_manager->LookupDevice(device_name, &dstd); + if (!status->status.ok()) return nullptr; + } + if (ctx->Async()) { + // Note that `h` may not be currently ready. However execution order will + // make sure that `h` is ready before the copy is actually done. + CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx); + ctx->executor.Add(node); + return node->dst(); + } else { + TFE_TensorHandle* output = nullptr; + status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output); + return output; } } @@ -1004,6 +1242,16 @@ void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); } +void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { + ctx->should_store_metadata.store(true); +} + +void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->metadata_mu); + ctx->should_store_metadata.store(false); + ctx->run_metadata.Clear(); +} + } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { @@ -1012,27 +1260,24 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( TFE_TensorHandle* h, TF_Status* status) { - if (h->d != nullptr) { + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + if (d != nullptr) { status->status = tensorflow::errors::FailedPrecondition( "TFE_TensorHandle is placed in device (not host) memory. Cannot return " "a tensorflow::Tensor"); return nullptr; } - return &h->t; -} - -void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->should_store_metadata.store(true); -} - -void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - ctx->should_store_metadata.store(false); - ctx->run_metadata.Clear(); + return t; } void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { + TFE_ContextAsyncWait(ctx, status); + if (!status->status.ok()) return; tensorflow::mutex_lock ml(ctx->metadata_mu); status->status = MessageToBuffer(ctx->run_metadata, buf); ctx->run_metadata.Clear(); @@ -1108,3 +1353,208 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } } } // namespace tensorflow + +TFE_Node::TFE_Node(tensorflow::uint64 id) : id(id) {} + +TFE_Executor::~TFE_Executor() { + tensorflow::mutex_lock l(node_queue_mutex_); + thread_done_ = true; + nodes_pending_.notify_all(); +} + +tensorflow::uint64 TFE_Executor::NextId() { + tensorflow::mutex_lock l(next_id_mutex_); + return next_id_++; +} + +void TFE_Executor::EnableAsync() { + tensorflow::mutex_lock l(node_queue_mutex_); + if (thread_ == nullptr) { + thread_.reset(tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "eager_async_executor", + std::bind(&TFE_Executor::Run, this))); + } +} + +void TFE_Executor::Add(TFE_Node* node) { + tensorflow::mutex_lock l(node_queue_mutex_); + DCHECK(thread_) << "EnableAsync should have been called before Add"; + if (!status_.ok()) { + delete node; + return; + } + int qlen = node_queue_.size(); + if (qlen > 0) { + if (node_queue_.back()->id >= node->id) { + status_ = tensorflow::errors::InvalidArgument( + "Inserting TFE_Node with non-increasing ids:", node_queue_.back()->id, + " vs ", node->id); + delete node; + return; + } + node_queue_.push(node); + } else { + node_queue_.push(node); + nodes_pending_.notify_all(); + } +} + +tensorflow::Status TFE_Executor::WaitFor(tensorflow::uint64 node_id) { + return WaitImpl(false, node_id); +} + +tensorflow::Status TFE_Executor::WaitForAllPendingNodes() { + return WaitImpl(true, 0); +} + +tensorflow::Status TFE_Executor::WaitImpl(bool wait_all, + tensorflow::uint64 node_id) { + tensorflow::condition_variable cond; + tensorflow::mutex_lock l(node_queue_mutex_); + // Don't wait if an error is already set. + if (!status_.ok()) return status_; + if (node_queue_.empty()) return tensorflow::Status::OK(); + if (wait_all) { + node_id = node_queue_.back()->id; + } else if (node_id < node_queue_.front()->id) { + // Note that we are relying on the ops being dispatched sequentially from + // the queue. + return tensorflow::Status::OK(); + } + node_done_notifications_.insert(std::make_pair(node_id, &cond)); + cond.wait(l); + // Note that we could be woken up if an error occurs, even though the node has + // not actually executed. + return status_; +} + +void TFE_Executor::ClearError() { + tensorflow::mutex_lock l(node_queue_mutex_); + if (status_.ok()) return; + // If an error was set, node_done_notifications_ and node_queue_ should have + // been cleared, and no new entries should have been added since. + DCHECK(node_done_notifications_.empty()); + DCHECK(node_queue_.empty()); + status_ = tensorflow::Status::OK(); + nodes_pending_.notify_all(); +} + +tensorflow::Status TFE_Executor::status() { + tensorflow::mutex_lock l(node_queue_mutex_); + return status_; +} + +void TFE_Executor::Run() { + while (true) { + std::unique_ptr curr_node; + { + tensorflow::mutex_lock l(node_queue_mutex_); + while (node_queue_.empty() || !status_.ok()) { + if (thread_done_) return; + nodes_pending_.wait(l); + } + curr_node.reset(node_queue_.front()); + } + tensorflow::Status status = curr_node->Run(); + const bool ok = status.ok(); + tensorflow::mutex_lock l(node_queue_mutex_); + node_queue_.pop(); + if (!ok) { + status_ = status; + // TODO(agarwal): mark all affected handles as corrupted before clearing + // this queue. + // We remove any pending ops so that we don't try to execute them if + // ClearError is called. + for (int i = 0; i < node_queue_.size(); ++i) { + delete node_queue_.front(); + node_queue_.pop(); + } + } + if (!node_done_notifications_.empty()) { + tensorflow::uint64 node_id = curr_node->id; + // Note that we notify all waiting threads in case an error has occurred. + // These calling threads are responsible for checking status_ before + // proceeding. + const auto range = ok ? node_done_notifications_.equal_range(node_id) + : make_pair(node_done_notifications_.begin(), + node_done_notifications_.end()); + for (auto it = range.first; it != range.second; ++it) { + it->second->notify_all(); + } + node_done_notifications_.erase(range.first, range.second); + } + } +} + +bool TFE_Context::Async() const { + tensorflow::mutex_lock l(async_map_mu); + return tensorflow::gtl::FindWithDefault( + thread_local_async, std::this_thread::get_id(), async_default); +} + +bool TFE_TensorHandle::IsReady() { + if (node_id == 0) return true; + tensorflow::mutex_lock l(ctx_mutex_); + return ctx_ == nullptr; +} + +tensorflow::Status TFE_TensorHandle::WaitReady() { + if (node_id == 0) return tensorflow::Status::OK(); + TFE_Executor* executor = nullptr; + { + tensorflow::mutex_lock l(ctx_mutex_); + if (ctx_ == nullptr) return tensorflow::Status::OK(); + executor = &ctx_->executor; + } + return executor->WaitFor(node_id); +} + +tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *t = &tensor_; + return tensorflow::Status::OK(); +} + +tensorflow::Status TFE_TensorHandle::Device(tensorflow::Device** d) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *d = device_; + return tensorflow::Status::OK(); +} + +tensorflow::Status TFE_TensorHandle::OpDevice(tensorflow::Device** d) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *d = op_device_; + return tensorflow::Status::OK(); +} + +tensorflow::Status TFE_TensorHandle::TensorAndDevice( + const tensorflow::Tensor** tensor, tensorflow::Device** device, + tensorflow::Device** op_device) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *tensor = &tensor_; + *device = device_; + *op_device = op_device_; + return tensorflow::Status::OK(); +} + +void TFE_TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor, + tensorflow::Device* device, + tensorflow::Device* op_device) { + tensorflow::mutex_lock l(ctx_mutex_); + DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called " + << "on non-ready handles."; + ctx_ = nullptr; + tensor_ = tensor; + device_ = device; + op_device_ = op_device; +} + +TFE_Op::~TFE_Op() { + for (TFE_TensorHandle* h : inputs) { + h->Unref(); + } +} diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 9610ca1..316006b 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -75,6 +75,11 @@ typedef enum TFE_ContextDevicePlacementPolicy { TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; +// Sets the default execution mode (sync/async). Note that this can be +// overridden per thread using TFE_ContextSetAsyncForThread. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, + unsigned char async); + TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); @@ -110,6 +115,30 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy( TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(TFE_Context*); +// Overrides the execution mode (sync/async) for the current thread. +TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, + unsigned char async, + TF_Status* status); + +// Causes the calling thread to block till all ops dispatched in async mode +// have been executed. Note that "execution" here refers to kernel execution / +// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee +// that lower level device queues (like GPU streams) have been flushed. +// +// This call may not block for execution of ops enqueued concurrently with this +// call. +TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*, + TF_Status* status); + +// When an error happens, any pending operations are discarded and newly issued +// ops return an error. This call clears the error state and re-enables +// execution of newly issued ops. +// +// Note that outputs of discarded ops remain in a corrupt state and should not +// be used for future calls. +// TODO(agarwal): mark the affected handles and raise errors if they are used. +TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*); + // A handle to a tensor on a device. // // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, @@ -119,15 +148,21 @@ typedef struct TFE_TensorHandle TFE_TensorHandle; TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status); +// Indicates that the caller will not be using `h` any more. TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); + +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status); @@ -137,6 +172,9 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, // that shares the underlying buffer. Otherwise, it currently requires at least // one of the source or destination devices to be CPU (i.e., for the source or // destination tensor to be placed in host memory). +// If async execution is enabled, the copy may be enqueued and the call will +// return "non-ready" handle. Else, this function returns after the copy has +// been done. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status); @@ -157,6 +195,7 @@ typedef struct TFE_Op TFE_Op; TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status); + TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, @@ -242,13 +281,20 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, int num_values); // Execute the operation defined by 'op' and return handles to computed -// tensors in 'retvals'. +// tensors in `retvals`. +// +// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and +// '*num_retvals' should be set to the size of this array. It is an error if +// the number of outputs is different from *num_retvals. // -// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* -// and '*num_retvals' should be set to the size of this array. +// If async execution is enabled, the call may simply enqueue the execution +// and return "non-ready" handles in `retvals`. Note that any handles contained +// in 'op' should not be mutated till the kernel execution actually finishes. // -// On return, 'num_retvals' will be set to the actual number of outputs -// returned by the operation. +// For sync execution, if any of the inputs to `op` are not ready, this call +// will block till they become ready and then return when the kernel execution +// is done. +// TODO(agarwal): change num_retvals to int from int*. TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status); @@ -274,6 +320,8 @@ TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx); // Populates the passed-in buffer with a serialized RunMetadata protocol buffer // containing any run metadata information accumulated so far and clears this // information. +// If async mode is enabled, this call blocks till all currently pending ops are +// done. TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 49b9434..8dba12f 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include #include +#include #include #include #include @@ -31,14 +33,113 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" +// A unit of execution for the TFE_Executor class below. Example subclasses +// encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one +// device to another. +class TFE_Node { + public: + explicit TFE_Node(tensorflow::uint64 id); + + virtual ~TFE_Node() {} + + // Runs the computation corresponding to this node and blocks till the + // execution is done. + virtual tensorflow::Status Run() = 0; + + // An id unique to the TFE_Context under which this node is created. Allocated + // monotonically. + const tensorflow::uint64 id; +}; + +// A class for handling async execution (see TFE_ContextSetAsync). +// Note that this class is thread-safe. +// TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the +// device of the input handle. Fix that. +// TODO(agarwal): On error, mark all affected handles as corrupted. +// TODO(agarwal): Implement support for control dependencies. +// TODO(agarwal): Support out-of-order execution and dispatching multiple +// TFE_Node in parallel. +// TODO(agarwal): Implement optimizations over TFE_Node traces. +class TFE_Executor { + public: + ~TFE_Executor(); + + // This is called whenever async mode is enabled. Note that it may be called + // multiple times as different calling threads may switch async mode on or off + // independently. + void EnableAsync(); + + // Helper function to create monotonically increasing ids unique to this + // object. + tensorflow::uint64 NextId(); + + // Schedules `node` for execution. + // Note that Add must be called in monotonically increasing order of node->id. + void Add(TFE_Node* node); + + // Causes the caller to block till node with id `node_id` has finished + // execution. + tensorflow::Status WaitFor(tensorflow::uint64 node_id); + + // Blocks till all currently pending ops are done. + tensorflow::Status WaitForAllPendingNodes(); + + // Clears all currently set errors which re-enables async execution. + void ClearError(); + + // Returns Status based on any errors that occurred during async execution. + tensorflow::Status status(); + + private: + // Starts execution of pending TFE_Nodes. This function loops till + // thread_done_ is set to true. If any errors are encontered, these are set + // inside `status_`. The loop blocks anytime there are no pending nodes, or if + // `status_` is not ok. + void Run(); + + tensorflow::Status WaitImpl(bool wait_all, tensorflow::uint64 node_id); + + tensorflow::mutex node_queue_mutex_; + + // Used to signal that some TFE_Nodes are pending execution. + tensorflow::condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_); + + // Queue of pending TFE_Nodes. + std::queue node_queue_ GUARDED_BY(node_queue_mutex_); + + // `status_` is set based on any errors raised during execution of a TFE_Node. + // It remains set until ClearError is called. + tensorflow::Status status_ GUARDED_BY(node_queue_mutex_); + + // Map from id of a TFE_Node to condition_variables (not owned by the map). + // These condition_variables are notified and removed when that TFE_Node is + // done executing, or if an error is found in execution of any TFE_Node. + std::multimap + node_done_notifications_ GUARDED_BY(node_queue_mutex_); + + // Thread object that calls the `Run` method. Currently we use only one thread + // for executing the TFE_Nodes one-by-one. + std::unique_ptr thread_ GUARDED_BY(node_queue_mutex_); + + // Indicates that `thread_` should stop as soon as it is done executing the + // current TFE_Node. + bool thread_done_ GUARDED_BY(node_queue_mutex_) = false; + + tensorflow::mutex next_id_mutex_; + tensorflow::uint64 next_id_ GUARDED_BY(next_id_mutex_) = 1; +}; + struct TFE_ContextOptions { TF_SessionOptions session_options; + // true if async execution is enabled. + bool async = false; TFE_ContextDevicePlacementPolicy policy{ TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; }; @@ -60,7 +161,10 @@ struct TFE_Context { device_manager.get(), opts.session_options.options.env, TF_GRAPH_DEF_VERSION, &func_lib_def, {})), log_device_placement( - opts.session_options.options.config.log_device_placement()) {} + opts.session_options.options.config.log_device_placement()), + async_default(opts.async) { + if (async_default) executor.EnableAsync(); + } const bool soft_placement; const TFE_ContextDevicePlacementPolicy policy; @@ -98,29 +202,99 @@ struct TFE_Context { std::atomic should_store_metadata{false}; tensorflow::mutex metadata_mu; tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); - const bool log_device_placement; + // TFE_Executor for async execution. + TFE_Executor executor; + + // True if running in asynchronous mode. + bool Async() const; + + // True if the default value for execution mode is async. Note that this value + // can be overridden per thread based on `thread_local_async` overrides. + const bool async_default; + mutable tensorflow::mutex async_map_mu; + std::unordered_map thread_local_async + GUARDED_BY(async_map_mu); }; -struct TFE_TensorHandle { +struct TFE_TensorHandle : public tensorflow::core::RefCounted { + public: TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d, tensorflow::Device* op_device) - : t(t), d(d), op_device(op_device) {} + : dtype(t.dtype()), + node_id(0), + tensor_(t), + device_(d), + op_device_(op_device), + ctx_(nullptr) {} + + TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, + TFE_Context* ctx) + : dtype(dtype), + node_id(node_id), + tensor_(dtype), + device_(nullptr), + op_device_(nullptr), + ctx_(ctx) { + DCHECK_GT(node_id, 0); + } + + ~TFE_TensorHandle() override {} + + tensorflow::Status Tensor(const tensorflow::Tensor** t); + + tensorflow::Status Device(tensorflow::Device** d); - tensorflow::Tensor t; - // TODO(ashankar): d == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('d' should always be a - // valid pointer?) + tensorflow::Status OpDevice(tensorflow::Device** d); + + tensorflow::Status TensorAndDevice(const tensorflow::Tensor** tensor, + tensorflow::Device** device, + tensorflow::Device** op_device); + + // Note that this can be called at most once, and only on non-ready handles, + // and makes them ready. + void SetTensorAndDevice(const tensorflow::Tensor& tensor, + tensorflow::Device* device, + tensorflow::Device* op_device); + + // dtype for the handle. It must be the same as t.dtype() once the handle is + // ready. + const tensorflow::DataType dtype; + + private: + // If the contents of the Tensor pointed to by this handle is yet to be + // computed by a TFE_Node, this function will block till that compuatation is + // done and the handle is "ready". + tensorflow::Status WaitReady(); + + bool IsReady(); + + // Id for the TFE_Node that will compute the value pointed to by this handle. + // If the value is 0, the handle is already ready, but not vice-versa. + const tensorflow::uint64 node_id; + + tensorflow::Tensor tensor_; + + // TODO(ashankar): device_ == nullptr iff local CPU + // This was expedient, but perhaps worth revisiting ('device_' should always + // be a valid pointer?) // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are // provided with the appropriate TFE_Context. // - // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a + // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* d; + tensorflow::Device* device_; + + // Device in which the op producing this tensor was executed. Equals to + // device_ for constant tensors. + tensorflow::Device* op_device_; - // Device in which the op producing this tensor was executed. Equals to d for - // constant tensors. - tensorflow::Device* op_device; + tensorflow::mutex ctx_mutex_; + + // `ctx` is only guaranteed to be set if the handle is not "ready". This is + // typically true when the handle was produced during async execution. + // `ctx` object is not owned and should outlive this handle. + TFE_Context* ctx_ GUARDED_BY(ctx_mutex_); }; struct TFE_Op { @@ -129,15 +303,15 @@ struct TFE_Op { TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} + ~TFE_Op(); + bool const is_function() const { return attr_types == nullptr; } TFE_Context* ctx; // Must outlive the TFE_Op. const tensorflow::string name; tensorflow::AttrBuilder attrs; const tensorflow::AttrTypeMap* attr_types; - std::vector inputs; - std::vector input_devices; - std::vector input_op_devices; + tensorflow::gtl::InlinedVector inputs; tensorflow::Device* device; bool use_xla = false; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 00fb7e6..927d119 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -29,6 +29,20 @@ using tensorflow::string; namespace { +TFE_TensorHandle* DoubleTestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* TestMatrixTensorHandle() { int64_t dims[] = {2, 2}; float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -43,6 +57,20 @@ TFE_TensorHandle* TestMatrixTensorHandle() { return th; } +TFE_TensorHandle* TestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TF_Status* status = TF_NewStatus(); @@ -139,10 +167,12 @@ void BM_InitOp(int iters) { } BENCHMARK(BM_InitOp); -void BM_Execute(int iters) { +void BM_Execute(int iters, int async) { tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? "ExecuteAsync" : "Execute"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -156,6 +186,9 @@ void BM_Execute(int iters) { TFE_Execute(matmul, &retvals[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } + if (async) { + TFE_ContextAsyncWait(ctx, status); + } tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); @@ -163,7 +196,7 @@ void BM_Execute(int iters) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } -BENCHMARK(BM_Execute); +BENCHMARK(BM_Execute)->Arg(0)->Arg(1); TEST(CAPI, Context) { TF_Status* status = TF_NewStatus(); @@ -205,10 +238,11 @@ TEST(CAPI, TensorHandle) { TFE_DeleteTensorHandle(h); } -TEST(CAPI, TensorHandleCopyBetweenDevices) { +void TensorHandleCopyBetweenDevices(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -274,10 +308,56 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { +TEST(CAPI, TensorHandleCopyBetweenDevices) { + TensorHandleCopyBetweenDevices(false); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesAsync) { + TensorHandleCopyBetweenDevices(true); +} + +void TensorHandleCopyBetweenDevicesError(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + const char* kErrorDevice = "NoSuchDevice:0"; + TFE_TensorHandle* hdevice = + TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get()); + EXPECT_NE(TF_OK, TF_GetCode(status.get())); + const char* msg = "NoSuchDevice:0 unknown device"; + EXPECT_TRUE(strstr(TF_Message(status.get()), msg) != nullptr) + << TF_Message(status.get()); + TF_SetStatus(status.get(), TF_OK, ""); + const char* kCPUDevice = "CPU:0"; + TFE_TensorHandle* hcopy = + TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())); + TFE_DeleteTensorHandle(hcopy); + TFE_DeleteTensorHandle(hcpu); + if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); + TFE_DeleteContext(ctx, status.get()); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesError) { + TensorHandleCopyBetweenDevicesError(false); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesErrorAsync) { + TensorHandleCopyBetweenDevicesError(true); +} + +void TensorHandleCopyBetweenTwoGPUDevices(bool async) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -332,11 +412,20 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopy) { +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { + TensorHandleCopyBetweenTwoGPUDevices(false); +} + +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) { + TensorHandleCopyBetweenTwoGPUDevices(true); +} + +void TensorHandleSilentCopy(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -366,14 +455,20 @@ TEST(CAPI, TensorHandleSilentCopy) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopyLocal) { +TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); } +TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); } + +void TensorHandleSilentCopyLocal(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status.get()); @@ -407,11 +502,17 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } +TEST(CAPI, TensorHandleSilentCopyLocalAsync) { + TensorHandleSilentCopyLocal(true); +} -TEST(CAPI, SetAndGetOpDevices) { +void SetAndGetOpDevices(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -442,27 +543,27 @@ TEST(CAPI, SetAndGetOpDevices) { TF_DeleteStatus(status); } -TEST(CAPI, Execute_MatMul_CPU) { +void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -474,7 +575,101 @@ TEST(CAPI, Execute_MatMul_CPU) { EXPECT_EQ(22, product[3]); TF_DeleteStatus(status); } +TEST(CAPI, Execute_MatMul_CPU) { Execute_MatMul_CPU(false); } +TEST(CAPI, Execute_MatMul_CPUAsync) { Execute_MatMul_CPU(true); } + +void Execute_MatMul_CPU_Runtime_Error(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m1 = TestMatrixTensorHandle(); + TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2(); + TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_Op* matmul2 = MatMulOp(ctx, m1, m1); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + TFE_DeleteOp(matmul); + if (!async) { + EXPECT_NE(TF_OK, TF_GetCode(status)); + } else { + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + EXPECT_EQ(nullptr, t); + const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]"; + EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr) + << TF_Message(status); + // Since error is not cleared, the following copy with correct device will + // still fail. + TF_SetStatus(status, TF_OK, ""); + TFE_DeleteTensorHandle(retvals[0]); + retvals[0] = nullptr; + TFE_Execute(matmul2, &retvals[0], &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TFE_ContextAsyncClearError(ctx); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + } + // Following works in async mode since TFE_ContextAsyncClearError was called. + TF_SetStatus(status, TF_OK, ""); + if (retvals[0] != nullptr) { + TFE_DeleteTensorHandle(retvals[0]); + } + retvals[0] = nullptr; + TFE_Execute(matmul2, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteTensor(t); + TFE_DeleteOp(matmul2); + TFE_DeleteTensorHandle(m1); + TFE_DeleteTensorHandle(m2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); + TF_DeleteStatus(status); +} +TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) { + Execute_MatMul_CPU_Runtime_Error(false); +} +TEST(CAPI, Execute_MatMul_CPU_Runtime_ErrorAsync) { + Execute_MatMul_CPU_Runtime_Error(true); +} + +void Execute_MatMul_CPU_Type_Error(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m1 = TestMatrixTensorHandle(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m1); + TFE_DeleteTensorHandle(m2); + if (retvals[0] != nullptr) { + TFE_DeleteTensorHandle(retvals[0]); + } + TFE_DeleteContext(ctx, status); + TF_DeleteStatus(status); +} +TEST(CAPI, Execute_MatMul_CPU_Type_Error) { + Execute_MatMul_CPU_Type_Error(false); +} +TEST(CAPI, Execute_MatMul_CPU_Type_ErrorAsync) { + Execute_MatMul_CPU_Type_Error(true); +} TEST(CAPI, Execute_Min_CPU) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -485,8 +680,8 @@ TEST(CAPI, Execute_Min_CPU) { TFE_TensorHandle* input = TestMatrixTensorHandle(); TFE_TensorHandle* axis = TestAxisTensorHandle(); TFE_Op* minOp = MinOp(ctx, input, axis); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(minOp, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(minOp); @@ -509,9 +704,10 @@ TEST(CAPI, Execute_Min_CPU) { } #ifdef TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, Execute_MatMul_XLA_CPU) { +void Execute_MatMul_XLA_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -521,15 +717,14 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) { TFE_OpSetXLACompilation(matmul, true); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); // Running a primitive TF operator via XLA is not yet supported. ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(1, num_retvals); @@ -545,13 +740,16 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - + TFE_DeleteContext(ctx, status); TF_DeleteStatus(status); } +TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } +TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); } -TEST(CAPI, Execute_Min_XLA_CPU) { +void Execute_Min_XLA_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -562,14 +760,13 @@ TEST(CAPI, Execute_Min_XLA_CPU) { TFE_OpSetXLACompilation(minOp, true); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(minOp, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(minOp); TFE_DeleteTensorHandle(input); TFE_DeleteTensorHandle(axis); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); @@ -582,13 +779,17 @@ TEST(CAPI, Execute_Min_XLA_CPU) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); + TFE_DeleteContext(ctx, status); TF_DeleteStatus(status); } +TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } +TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); } #endif // TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, ExecuteWithTracing) { +void ExecuteWithTracing(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); TFE_ContextEnableRunMetadata(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -596,8 +797,8 @@ TEST(CAPI, ExecuteWithTracing) { TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); @@ -609,12 +810,12 @@ TEST(CAPI, ExecuteWithTracing) { EXPECT_TRUE( rm.ParseFromString({reinterpret_cast(b->data), b->length})); TF_DeleteBuffer(b); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -626,6 +827,8 @@ TEST(CAPI, ExecuteWithTracing) { EXPECT_EQ(22, product[3]); TF_DeleteStatus(status); } +TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); } +TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); } TEST(CAPI, Function_ident_CPU) { // First create a simple identity function. @@ -657,32 +860,37 @@ TEST(CAPI, Function_ident_CPU) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteFunction(fn); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); + for (bool async : {false, true, false}) { + TFE_ContextSetAsyncForThread(ctx, static_cast(async), + status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } TFE_DeleteContext(ctx, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -719,35 +927,40 @@ TEST(CAPI, Function_ident_XLA_CPU) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteFunction(fn); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); + for (bool async : {false, true, false}) { + TFE_ContextSetAsyncForThread(ctx, static_cast(async), + status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - // Now run it via XLA. - TFE_OpSetXLACompilation(op, true); + // Now run it via XLA. + TFE_OpSetXLACompilation(op, true); - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } TFE_DeleteContext(ctx, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -788,9 +1001,10 @@ string MatMulFunction() { return def.SerializeAsString(); } -TEST(CAPI, FunctionDefAndExecute) { +void FunctionDefAndExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -827,11 +1041,16 @@ TEST(CAPI, FunctionDefAndExecute) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } +TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); } +TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); } -void BM_ExecuteFunction(int iters) { +void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync" + : "ExecuteFunction"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -853,6 +1072,9 @@ void BM_ExecuteFunction(int iters) { TFE_Execute(matmul, &retval[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } + if (async) { + TFE_ContextAsyncWait(ctx, status); + } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); @@ -860,7 +1082,7 @@ void BM_ExecuteFunction(int iters) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } -BENCHMARK(BM_ExecuteFunction); +BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1); TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, TF_Status* status) { diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index 985ed96..ad16f65 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -185,7 +185,8 @@ class KernelAndDevice { Device* device() const { return device_; } - DataTypeVector* output_dtypes() { return &output_dtypes_; } + DataTypeVector* mutable_output_dtypes() { return &output_dtypes_; } + const DataTypeVector& output_dtypes() { return output_dtypes_; } private: std::unique_ptr kernel_; diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index d504ca0..012c68f 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -250,13 +250,23 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteTooManyNumOutputs(self): # num_outputs provided is 50, but only one output is produced. - # That should be okay. - product = execute( - b'Mul', - num_outputs=50, - inputs=[constant_op.constant(3), constant_op.constant(5)], - attrs=('T', dtypes.int32.as_datatype_enum))[0] - self.assertAllEqual(15, product) + with self.assertRaises(errors.InvalidArgumentError): + _ = execute( + b'Mul', + num_outputs=50, + inputs=[constant_op.constant(3), + constant_op.constant(5)], + attrs=('T', dtypes.int32.as_datatype_enum))[0] + + def testExecuteTooFewNumOutputs(self): + # num_outputs provided is 50, but only one output is produced. + with self.assertRaises(errors.InvalidArgumentError): + _ = execute( + b'Mul', + num_outputs=0, + inputs=[constant_op.constant(3), + constant_op.constant(5)], + attrs=('T', dtypes.int32.as_datatype_enum))[0] def testMatMulGPU(self): if not context.context().num_gpus(): diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 8338bc4..105c09e 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -340,8 +340,10 @@ void EagerTensor_dealloc(EagerTensor* self) { Py_DECREF(self->handle_data); Py_DECREF(self->keras_mask); Py_DECREF(self->tensor_shape); - TFE_DeleteTensorHandle(self->handle); - self->handle = nullptr; + if (self->handle != nullptr) { + TFE_DeleteTensorHandle(self->handle); + self->handle = nullptr; + } // We have the global interpreter lock, so use this chance to perform delayed // refcount decrements. tensorflow::ClearDecrefCache(); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index fcb0452..fe9785d 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1012,7 +1012,14 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::int64 id = EagerTensor_id(tensor); - return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()}; + const tensorflow::Tensor* tensor = nullptr; + const tensorflow::Status status = t->Tensor(&tensor); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { + return tensorflow::eager::TapeTensor{id, t->dtype, + tensorflow::TensorShape({})}; + } else { + return tensorflow::eager::TapeTensor{id, t->dtype, tensor->shape()}; + } } tensorflow::int64 id = FastTensorId(tensor); if (PyErr_Occurred()) { diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 343415b..02eafd4 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -164,9 +164,9 @@ bool IsSingleNone(PyObject* obj) { } // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. -void ExtractTensorFromEagerTensor(const PyObject* eager_tensor, - Tensor* output_tensor) { - *output_tensor = EagerTensor_Handle(eager_tensor)->t; +tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, + const Tensor** output_tensor) { + return EagerTensor_Handle(eager_tensor)->Tensor(output_tensor); } // Calls the registered py function through the trampoline. @@ -220,7 +220,9 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { if (call->eager) { const PyObject* item = PyList_GetItem(result, i); if (EagerTensor_CheckExact(item)) { - ExtractTensorFromEagerTensor(item, &t); + const Tensor* tensor = nullptr; + s = ExtractTensorFromEagerTensor(item, &tensor); + if (s.ok()) t = *tensor; } else { s = errors::FailedPrecondition( "Expected EagerTensor, found PyObject of type: ", @@ -238,10 +240,10 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } else if (EagerTensor_CheckExact(result) || result == Py_None) { // result is an `EagerTensor` or `None`. DCHECK(call->eager); - Tensor t; if (result != Py_None) { - ExtractTensorFromEagerTensor(result, &t); - call->out.push_back(t); + const Tensor* t = nullptr; + s = ExtractTensorFromEagerTensor(result, &t); + if (s.ok()) call->out.push_back(*t); } } else if (PyArray_Check(result)) { // `result` is a NumPy array. -- 2.7.4