CapturedFunction::~CapturedFunction() {}
-Status CapturedFunction::set_lib(FunctionLibraryRuntime* lib) {
- mutex_lock l(mu_);
- if (lib_ == nullptr) {
- lib_ = lib;
- return Status::OK();
- }
- if (lib != lib_) {
- return errors::Internal(
- "Captured function was called with a different "
- "FunctionLibraryRuntime*, which is not permitted.");
- }
- return Status::OK();
-}
-
namespace {
class CallFrameBase : public CallFrameInterface {
public:
} // namespace
Status CapturedFunction::MaybeInstantiate(
- FunctionLibraryRuntime* lib,
- FunctionLibraryRuntime::InstantiateOptions inst_opts) {
- TF_RETURN_IF_ERROR(set_lib(lib));
- inst_opts.state_handle = std::to_string(random::New64());
+ IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) {
mutex_lock l(mu_);
- if (f_handle_ == kInvalidHandle) {
+ if (lib_ == nullptr) {
+ // The context's runtime will be used for all subsequent calls.
+ lib_ = ctx->lib();
+ DCHECK(f_handle_ == kInvalidHandle);
+ FunctionLibraryRuntime::InstantiateOptions inst_opts;
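+ // Overlay the iterator's own function library so that functions defined
+ // for this dataset can be resolved when instantiating `func_`.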
+ inst_opts.overlay_lib = ctx->function_library().get();
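+ // A unique state handle gives this captured function a private
+ // instantiation, so any per-op state in `func_` is not shared with other
+ // callers of the same function.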
+ inst_opts.state_handle = std::to_string(random::New64());
TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
inst_opts, &f_handle_));
+ const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
+ if (fbody == nullptr) {
+ return errors::Internal("Failed to instantiate function body.");
+ }
+ ret_types_ = fbody->ret_types;
+ } else {
+ // TODO(mrry): Consider moving this under a shared lock, as it is
+ // the common case.
+ if (ctx->lib() != lib_) {
+ return errors::Internal(
+ "Captured function was called with a different "
+ "FunctionLibraryRuntime*, which is not permitted.");
+ }
}
- const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
- if (fbody == nullptr) {
- return errors::Internal("Failed to instantiate function body.");
- }
- ret_types_ = fbody->ret_types;
+ *out_handle = f_handle_;
return Status::OK();
}
Status CapturedFunction::Run(IteratorContext* ctx,
- FunctionLibraryRuntime::Options f_opts,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets) {
- FunctionLibraryRuntime::InstantiateOptions inst_opts;
- inst_opts.overlay_lib = ctx->function_library().get();
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx->lib(), inst_opts));
+ FunctionLibraryRuntime::Handle handle;
+ TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+
+ FunctionLibraryRuntime::Options f_opts;
+ f_opts.step_id = CapturedFunction::generate_step_id();
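+ // Scope resources created by the function to this step; they are cleaned
+ // up from the device's ResourceMgr when `step_container` goes out of
+ // scope at the end of this call.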
+ ScopedStepContainer step_container(f_opts.step_id, [ctx](const string& name) {
+ ctx->lib()->device()->resource_manager()->Cleanup(name).IgnoreError();
+ });
+ f_opts.step_container = &step_container;
+ f_opts.runner = ctx->runner();
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- auto c_mgr = new CancellationManager;
- auto frame =
- new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
- f_opts.cancellation_manager = c_mgr;
+ CancellationManager c_mgr;
+ f_opts.cancellation_manager = &c_mgr;
+
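+ // The frame takes ownership of `args`, so the input tensors can be
+ // deallocated as soon as the function has consumed them.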
+ OwnedArgsCallFrame frame(std::move(args), &captured_inputs_, ret_types_);
Notification n;
Status s;
- mutex_lock l(mu_);
- lib_->Run(f_opts, f_handle_, frame,
- [rets, c_mgr, frame, &n, &s](Status func_status) {
- delete c_mgr;
- s.Update(func_status);
- if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
- }
- delete frame;
- n.Notify();
- });
+ ctx->lib()->Run(f_opts, handle, &frame, [&n, &s](Status func_status) {
+ s.Update(func_status);
+ n.Notify();
+ });
n.WaitForNotification();
- return s;
+ TF_RETURN_IF_ERROR(s);
+ return frame.ConsumeRetvals(rets);
}
-Status CapturedFunction::RunWithBorrowedArgs(
- IteratorContext* ctx, FunctionLibraryRuntime::Options f_opts,
- const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
- FunctionLibraryRuntime::InstantiateOptions inst_opts;
- inst_opts.overlay_lib = ctx->function_library().get();
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx->lib(), inst_opts));
+Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ std::vector<Tensor>* rets) {
+ FunctionLibraryRuntime::Handle handle;
+ TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+
+ FunctionLibraryRuntime::Options f_opts;
+ f_opts.step_id = CapturedFunction::generate_step_id();
+ ScopedStepContainer step_container(f_opts.step_id, [ctx](const string& name) {
+ ctx->lib()->device()->resource_manager()->Cleanup(name).IgnoreError();
+ });
+ f_opts.step_container = &step_container;
+ f_opts.runner = ctx->runner();
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- auto c_mgr = new CancellationManager;
+ CancellationManager c_mgr;
+ f_opts.cancellation_manager = &c_mgr;
+
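+ // Unlike `Run()`, the frame only borrows `args`; the caller retains
+ // ownership of the input tensors.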
BorrowedArgsCallFrame frame(args, &captured_inputs_, ret_types_);
- f_opts.cancellation_manager = c_mgr;
Notification n;
Status s;
- mutex_lock l(mu_);
- lib_->Run(f_opts, f_handle_, &frame,
- [rets, c_mgr, &frame, &n, &s](Status func_status) {
- delete c_mgr;
- s.Update(func_status);
- if (s.ok()) {
- s = frame.ConsumeRetvals(rets);
- }
- n.Notify();
- });
+ ctx->lib()->Run(f_opts, handle, &frame, [&n, &s](Status func_status) {
+ s.Update(func_status);
+ n.Notify();
+ });
n.WaitForNotification();
- return s;
+ TF_RETURN_IF_ERROR(s);
+ return frame.ConsumeRetvals(rets);
}
-void CapturedFunction::RunAsync(
- FunctionLibraryRuntime* lib,
- FunctionLibraryRuntime::InstantiateOptions inst_opts,
- FunctionLibraryRuntime::Options f_opts, std::vector<Tensor>&& args,
- std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
- Status s = MaybeInstantiate(lib, inst_opts);
+void CapturedFunction::RunAsync(IteratorContext* ctx,
+ std::vector<Tensor>&& args,
+ std::vector<Tensor>* rets,
+ FunctionLibraryRuntime::DoneCallback done) {
+ // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
+ // be deleted before `done` is called. Take care not to capture `ctx` in any
+ // code that may execute asynchronously in this function.
+ FunctionLibraryRuntime::Handle handle;
+ Status s = MaybeInstantiate(ctx, &handle);
if (!s.ok()) {
done(s);
return;
}
+ auto frame =
+ new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
+
+ FunctionLibraryRuntime::Options f_opts;
+ f_opts.step_id = CapturedFunction::generate_step_id();
+ ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
+ auto step_container = new ScopedStepContainer(
+ f_opts.step_id, [resource_mgr](const string& name) {
+ resource_mgr->Cleanup(name).IgnoreError();
+ });
+ f_opts.step_container = step_container;
+ f_opts.runner = ctx->runner();
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
 // cancellation manager here is created so that we can run kernels
 // (such as queue kernels) that depend on the non-nullness of
 // `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
auto c_mgr = new CancellationManager;
- auto frame =
- new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
f_opts.cancellation_manager = c_mgr;
- mutex_lock l(mu_);
- lib_->Run(f_opts, f_handle_, frame,
- std::bind(
- [rets, c_mgr, frame](FunctionLibraryRuntime::DoneCallback done,
- // Begin unbound arguments.
- Status s) {
- delete c_mgr;
- if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
- }
- delete frame;
- done(s);
- },
- std::move(done), std::placeholders::_1));
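+ // `frame`, `step_container`, and `c_mgr` are heap-allocated because they
+ // must outlive this call; the completion callback below deletes them.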
+ tf_shared_lock l(mu_);
+ ctx->lib()->Run(f_opts, handle, frame,
+ std::bind(
+ [rets, step_container, c_mgr, frame](
+ FunctionLibraryRuntime::DoneCallback done,
+ // Begin unbound arguments.
+ Status s) {
+ delete step_container;
+ delete c_mgr;
+ if (s.ok()) {
+ s = frame->ConsumeRetvals(rets);
+ }
+ delete frame;
+ done(s);
+ },
+ std::move(done), std::placeholders::_1));
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
// tensors in `args`, in order to be able to deallocate them as early as
// possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
// ownership of the `args`.
- Status Run(IteratorContext* ctx, FunctionLibraryRuntime::Options f_opts,
- std::vector<Tensor>&& args, std::vector<Tensor>* rets);
+ Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
+ std::vector<Tensor>* rets);
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible.
Status RunWithBorrowedArgs(IteratorContext* ctx,
- FunctionLibraryRuntime::Options f_opts,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets);
 // Asynchronously runs the captured function on the given `args`, stores
 // the results in `*rets`, and calls the given `done` callback when the
// function returns. This method takes ownership of the tensors in `args`,
// in order to be able to deallocate them as early as possible.
- void RunAsync(FunctionLibraryRuntime* lib,
- FunctionLibraryRuntime::InstantiateOptions inst_opts,
- FunctionLibraryRuntime::Options f_opts,
- std::vector<Tensor>&& args, std::vector<Tensor>* rets,
+ void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
+ std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done);
 // Returns the additional captured inputs that will be passed to the function
CapturedFunction(const NameAttrList& func,
std::vector<Tensor> captured_inputs);
- Status set_lib(FunctionLibraryRuntime* lib);
-
- Status MaybeInstantiate(FunctionLibraryRuntime* lib,
- FunctionLibraryRuntime::InstantiateOptions inst_opts);
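+ // Instantiates `func_` in `ctx->lib()` on the first call and caches the
+ // resulting handle and return types; all subsequent calls must use the
+ // same FunctionLibraryRuntime and receive the cached handle in
+ // `*out_handle`.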
+ Status MaybeInstantiate(IteratorContext* ctx,
+ FunctionLibraryRuntime::Handle* out_handle);
mutex mu_;
const NameAttrList func_;
==============================================================================*/
#include "tensorflow/core/kernels/data/dataset_utils.h"
-#include "tensorflow/core/common_runtime/device.h"
namespace tensorflow {
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator) {
- FunctionLibraryRuntime::Options opts;
- opts.runner = ctx->runner();
- // Choose a step ID that is guaranteed not to clash with any
- // Session-generated step ID. DirectSession only generates
- // non-negative step IDs (contiguous, starting from 0), and
- // MasterSession generates 56-bit random step IDs whose MSB
- // is always 0, so a negative random step ID should suffice.
- opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
- ctx->lib()->device()->resource_manager()->Cleanup(name).IgnoreError();
- });
- opts.step_container = &step_container;
std::vector<Tensor> return_values;
- TF_RETURN_IF_ERROR(captured_func->RunWithBorrowedArgs(
- ctx, opts, input_element, &return_values));
+ TF_RETURN_IF_ERROR(
+ captured_func->RunWithBorrowedArgs(ctx, input_element, &return_values));
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
TensorShapeUtils::IsScalar(return_values[0].shape()))) {
return Status::OK();
}
- FunctionLibraryRuntime::Options opts;
- opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer step_container(opts.step_id,
- [ctx](const string& name) {
- ctx->lib()
- ->device()
- ->resource_manager()
- ->Cleanup(name)
- .IgnoreError();
- });
- opts.step_container = &step_container;
- opts.runner = ctx->runner();
// TODO(mrry): Avoid blocking a threadpool thread. We will need to
// stack-rip the iterators and use async kernels.
- Notification n;
- Status ret;
std::vector<Tensor> result;
- ret = dataset()->captured_func_->RunWithBorrowedArgs(
- ctx, opts, *out_tensors, &result);
+ TF_RETURN_IF_ERROR(dataset()->captured_func_->RunWithBorrowedArgs(
+ ctx, *out_tensors, &result));
- if (!ret.ok()) {
- return ret;
- } else if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
- result[0].NumElements() != 1) {
+ if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
+ result[0].NumElements() != 1) {
return errors::InvalidArgument(
"Filter predicate `f` must return a scalar bool.");
}
input_impl_->GetNext(ctx, &next_input_element, &end_of_input_));
if (!end_of_input_) {
- FunctionLibraryRuntime::Options opts;
- opts.step_id = CapturedFunction::generate_step_id();
- opts.runner = ctx->runner();
- ScopedStepContainer step_container(opts.step_id,
- [ctx](const string& name) {
- ctx->lib()
- ->device()
- ->resource_manager()
- ->Cleanup(name)
- .IgnoreError();
- });
- opts.step_container = &step_container;
-
// Run the key function on the input element to identify its
// group.
std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(
dataset()->captured_key_func_->RunWithBorrowedArgs(
- ctx, opts, next_input_element, &key_func_output));
+ ctx, next_input_element, &key_func_output));
if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||
const int64 key = key_func_output[0].scalar<int64>()();
if (window_sizes_.find(key) == window_sizes_.end()) {
- // Run window_size function
- FunctionLibraryRuntime::Options opts2;
- opts2.step_id = CapturedFunction::generate_step_id();
- opts2.runner = ctx->runner();
- ScopedStepContainer step_container2(opts2.step_id,
- [ctx](const string& name) {
- ctx->lib()
- ->device()
- ->resource_manager()
- ->Cleanup(name)
- .IgnoreError();
- });
- opts2.step_container = &step_container2;
-
// Run the window size function on the key to identify its
// window size.
std::vector<Tensor> window_size_func_output;
TF_RETURN_IF_ERROR(dataset()->captured_window_size_func_->Run(
- ctx, opts2, std::move(key_func_output),
- &window_size_func_output));
+ ctx, std::move(key_func_output), &window_size_func_output));
if (window_size_func_output.size() != 1 ||
window_size_func_output[0].dtype() != DT_INT64 ||
Status StartFlushingGroup(IteratorContext* ctx, int64 key)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- FunctionLibraryRuntime::Options opts;
- opts.step_id = CapturedFunction::generate_step_id();
- opts.runner = ctx->runner();
- ScopedStepContainer step_container(opts.step_id, [ctx](const string&
- name) {
- ctx->lib()->device()->resource_manager()->Cleanup(name).IgnoreError();
- });
- opts.step_container = &step_container;
-
DatasetBase* group_dataset;
TF_RETURN_IF_ERROR(NewWindowDataset(
groups_[key], dataset()->input_->output_dtypes(),
{std::move(key_arg), std::move(group_dataset_arg)});
std::vector<Tensor> return_values;
TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Run(
- ctx, opts, std::move(args), &return_values));
+ ctx, std::move(args), &return_values));
if (!(return_values.size() == 1 &&
return_values[0].dtype() == DT_VARIANT &&
// Call `captured_func_(input_element)`, store the result in
// `result->return_values`, and notify `batch_result->counter`
// to unblock a consumer.
- FunctionLibraryRuntime::Options opts;
- opts.step_id = CapturedFunction::generate_step_id();
- ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
- ScopedStepContainer* step_container = new ScopedStepContainer(
- opts.step_id, [resource_mgr](const string& name) {
- resource_mgr->Cleanup(name).IgnoreError();
- });
- opts.step_container = step_container;
- std::function<void(std::function<void()>)>* runner =
- new std::function<void(std::function<void()>)>(*ctx->runner());
- opts.runner = runner;
- FunctionLibraryRuntime* lib = ctx->lib();
- FunctionLibraryRuntime::InstantiateOptions inst_opts;
- inst_opts.overlay_lib = ctx->function_library().get();
-
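+ // The caller's `ctx` may be destroyed before the closure runs, so pass an
+ // owned copy of the IteratorContext into the closure and delete it in the
+ // completion callback once `RunAsync()` finishes.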
(*ctx->runner())(std::bind(
- [this, lib, inst_opts, opts, result, step_container, runner,
- batch_result, offset](std::vector<Tensor> input_element) {
+ [this, result, batch_result, offset](
+ IteratorContext* ctx, std::vector<Tensor> input_element) {
dataset()->captured_func_->RunAsync(
- lib, inst_opts, opts, std::move(input_element),
- &result->return_values,
- [this, step_container, runner, result, batch_result,
- offset](Status ret_status) {
- delete step_container;
- delete runner;
+ ctx, std::move(input_element), &result->return_values,
+ [this, ctx, result, batch_result, offset](Status ret_status) {
+ delete ctx;
result->status.Update(ret_status);
if (ret_status.ok()) {
EnsureOutputAllocated(batch_result,
batch_result->counter->DecrementCount();
});
},
- std::move(input_element)));
+ new IteratorContext(*ctx), std::move(input_element)));
}
void StartInvocationBatch(IteratorContext* ctx, int64 batch_index)
return Status::OK();
}
- FunctionLibraryRuntime::Options opts;
- opts.step_id = CapturedFunction::generate_step_id();
-
- ScopedStepContainer step_container(opts.step_id, [ctx](const string&
- name) {
- ctx->lib()->device()->resource_manager()->Cleanup(name).IgnoreError();
- });
- opts.step_container = &step_container;
- opts.runner = ctx->runner();
// TODO(mrry): Avoid blocking a threadpool thread. We will need to
// stack-rip the iterators and use async kernels.
- Status s = dataset()->captured_func_->Run(ctx, opts, std::move(args),
- out_tensors);
+ Status s =
+ dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
if (errors::IsOutOfRange(s)) {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
 // Call `captured_func_(input_element)`, store the result in
 // `result->return_values`, and notify `result->notification`
// to unblock a consumer.
result->notification.reset(new Notification);
-
- FunctionLibraryRuntime::Options opts;
- opts.step_id = CapturedFunction::generate_step_id();
- ResourceMgr* resource_manager =
- ctx->lib()->device()->resource_manager();
- ScopedStepContainer* step_container = new ScopedStepContainer(
- opts.step_id, [resource_manager](const string& name) {
- resource_manager->Cleanup(name).IgnoreError();
- });
- opts.step_container = step_container;
- opts.runner = ctx->runner();
- FunctionLibraryRuntime::InstantiateOptions inst_opts;
- inst_opts.overlay_lib = ctx->function_library().get();
-
dataset()->captured_func_->RunAsync(
- ctx->lib(), inst_opts, opts, std::move(input_element),
- &result->return_values,
- [result, step_container, result_index](Status ret_status) {
- delete step_container;
+ ctx, std::move(input_element), &result->return_values,
+ [result, result_index](Status ret_status) {
result->status.Update(ret_status);
result->notification->Notify();
});
std::copy(next_element.begin(), next_element.end(),
std::back_inserter(args));
- FunctionLibraryRuntime::Options opts;
- opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer step_container(opts.step_id, [ctx](const string&
- name) {
- ctx->lib()->device()->resource_manager()->Cleanup(name).IgnoreError();
- });
- opts.step_container = &step_container;
- opts.runner = ctx->runner();
std::vector<Tensor> state_and_output;
state_and_output.reserve(dataset()->state_types_.size() +
output_dtypes().size());
- Status s = dataset()->captured_func_->Run(ctx, opts, std::move(args),
+ Status s = dataset()->captured_func_->Run(ctx, std::move(args),
&state_and_output);
if (s.ok()) {
state_.clear();