// The instantiated and transformed function is encoded as a Graph
// object, and an executor is created for the graph.
- struct Item : public core::RefCounted {
- bool invalidated = false;
+ struct Item {  // Per-local-handle state; stored in items_ through std::unique_ptr.
+ uint64 instantiation_counter = 0;  // Set to 1 on creation, incremented on each repeated instantiation, decremented on release; the entry is erased from items_ when it reaches 0.
const Graph* graph = nullptr; // Owned by exec.
const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
- ~Item() override {
+ ~Item() {  // Deletes func_graph and exec; graph and overlay_lib are not deleted here.
delete this->func_graph;
delete this->exec;
}
};
- std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_);
+ std::unordered_map<Handle, std::unique_ptr<Item>> items_ GUARDED_BY(mu_);
ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned.
}
}
-FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
- for (auto p : items_) p.second->Unref();
-}
+FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {}  // Empty on purpose: items_ holds std::unique_ptr<Item>, so no manual Unref loop is needed.
// An asynchronous op kernel which executes an instantiated function
// defined in a library.
options_copy.target = device_name_;
const string key = Canonicalize(function_name, attrs, options_copy);
- Handle found_handle = kInvalidHandle;
{
mutex_lock l(mu_);
- found_handle = parent_->GetHandle(key);
- if (found_handle != kInvalidHandle) {
+ *handle = parent_->GetHandle(key);
+ if (*handle != kInvalidHandle) {
FunctionLibraryRuntime::LocalHandle handle_on_device =
- parent_->GetHandleOnDevice(device_name_, found_handle);
+ parent_->GetHandleOnDevice(device_name_, *handle);
if (handle_on_device == kInvalidLocalHandle) {
return errors::Internal("LocalHandle not found for handle ", *handle,
".");
}
- auto iter = items_.find(handle_on_device);
- if (iter == items_.end()) {
+ auto item_handle = items_.find(handle_on_device);
+ if (item_handle == items_.end()) {
return errors::Internal("LocalHandle ", handle_on_device,
- " for handle ", found_handle,
+ " for handle ", *handle,
" not found in items.");
}
- Item* item = iter->second;
- if (!item->invalidated) {
- *handle = found_handle;
- return Status::OK();
- }
- // *item is invalidated. Fall through and instantiate the given
- // function_name/attrs/option again.
+ ++item_handle->second->instantiation_counter;
+ return Status::OK();
}
}
{
mutex_lock l(mu_);
- Handle found_handle_again = parent_->GetHandle(key);
- if (found_handle_again != found_handle) {
+ *handle = parent_->GetHandle(key);
+ if (*handle != kInvalidHandle) {
delete fbody;
- *handle = found_handle_again;
+ ++items_[parent_->GetHandleOnDevice(device_name_, *handle)]
+ ->instantiation_counter;
} else {
*handle = parent_->AddHandle(key, device_name_, next_handle_);
Item* item = new Item;
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
- items_.insert({next_handle_, item});
+ item->instantiation_counter = 1;
+ items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
}
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
return parent_->ReleaseHandle(handle);
}
+
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
CHECK_NE(h, kInvalidLocalHandle);
mutex_lock l(mu_);
CHECK_EQ(1, items_.count(h));
- Item* item = items_[h];
- item->invalidated = true; // Reinstantiate later.
+ std::unique_ptr<Item>& item = items_[h];
+ --item->instantiation_counter;
+ if (item->instantiation_counter == 0) {
+ items_.erase(h);
+ TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
+ }
return Status::OK();
}
return errors::NotFound("Function handle ", handle,
" is not valid. Likely an internal error.");
}
- *item = items_[local_handle];
+ *item = items_[local_handle].get();
if ((*item)->exec != nullptr) {
return Status::OK();
}
// computation is done and stored in *rets, we send the return values back
// to the source_device (caller) so that the ProcFLR can receive them later.
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
- item->Ref();
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
source_device, target_device, "arg_", src_incarnation, args.size(),
device_context, {}, rendezvous, remote_args,
s = frame->SetArgs(*remote_args);
}
if (!s.ok()) {
- item->Unref();
delete frame;
delete remote_args;
delete exec_args;
return;
}
item->exec->RunAsync(
- *exec_args, [item, frame, rets, done, source_device, target_device,
+ *exec_args, [frame, rets, done, source_device, target_device,
target_incarnation, rendezvous, device_context,
remote_args, exec_args](const Status& status) {
- core::ScopedUnref unref(item);
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
return;
}
- item->Ref();
item->exec->RunAsync(
// Executor args
*exec_args,
// Done callback.
- [item, frame, rets, done, exec_args](const Status& status) {
- core::ScopedUnref unref(item);
+ [frame, rets, done, exec_args](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
exec_args->runner = *run_opts.runner;
exec_args->call_frame = frame;
- item->Ref();
item->exec->RunAsync(
// Executor args
*exec_args,
[item, frame, exec_args](DoneCallback done,
// Start unbound arguments.
const Status& status) {
- core::ScopedUnref unref(item);
delete exec_args;
done(status);
},
return status;
}
FunctionLibraryRuntime::Options opts;
- TF_RETURN_IF_ERROR(Run(flr, handle, opts, args, rets, add_runner));
- return flr->ReleaseHandle(handle);
+ status = Run(flr, handle, opts, args, rets, add_runner);
+ if (!status.ok()) return status;
+
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ str_util::StrContains(status2.error_message(), "remote execution."));
+
+ return status;
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
*rets[i] = retvals[i];
}
- // Release the handle.
- return flr->ReleaseHandle(handle);
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ str_util::StrContains(status2.error_message(), "remote execution."));
+
+ return status;
}
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
FunctionLibraryRuntime::LocalHandle local_handle) {
mutex_lock l(mu_);
auto h = next_handle_;
- FunctionData* fd = new FunctionData(device_name, local_handle);
- function_data_[h] = std::unique_ptr<FunctionData>(fd);
+ function_data_[h] = MakeUnique<FunctionData>(
+ device_name, local_handle, function_key);
table_[function_key] = h;
next_handle_++;
return h;
gtl::FindWithDefault(table_, function_key, kInvalidHandle);
if (h == kInvalidHandle || function_data_.count(h) == 0) {
h = next_handle_;
- FunctionData* fd = new FunctionData(options.target, kInvalidHandle);
- function_data_[h] = std::unique_ptr<FunctionData>(fd);
+ function_data_[h] = MakeUnique<FunctionData>(
+ options.target, kInvalidHandle, function_key);
table_[function_key] = h;
next_handle_++;
}
return Status::OK();
}
+// Removes `handle` from the state owned by this object: drops the
+// function_key -> handle entry from table_ and the handle's FunctionData
+// from function_data_. Takes mu_ for the duration of the call.
+//
+// Returns NotFound if `handle` has no entry in function_data_. The lookup
+// deliberately uses find() rather than operator[]: operator[] would insert a
+// null std::unique_ptr for an unknown handle and then dereference it.
+Status ProcessFunctionLibraryRuntime::RemoveHandle(
+    FunctionLibraryRuntime::Handle handle) {
+  mutex_lock l(mu_);
+  auto data_it = function_data_.find(handle);
+  if (data_it == function_data_.end() || data_it->second == nullptr) {
+    return errors::NotFound("Function handle ", handle,
+                            " is not registered; nothing to remove.");
+  }
+  // Erase the key mapping first: the key string lives inside the
+  // FunctionData that the second erase destroys.
+  table_.erase(data_it->second->function_key());
+  function_data_.erase(data_it);
+  return Status::OK();
+}
+
Status ProcessFunctionLibraryRuntime::ReleaseHandle(
FunctionLibraryRuntime::Handle handle) {
FunctionLibraryRuntime* flr = nullptr;
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
+ // Removes handle from the state owned by this object.
+ Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
+
Status Clone(Env* env, int graph_def_version,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
class FunctionData {
public:
FunctionData(const string& target_device,
- FunctionLibraryRuntime::LocalHandle local_handle)
- : target_device_(target_device), local_handle_(local_handle) {}
+ FunctionLibraryRuntime::LocalHandle local_handle,
+ const string& function_key)
+ : target_device_(target_device),
+ local_handle_(local_handle),
+ function_key_(function_key) {}
string target_device() { return target_device_; }
+ const string& function_key() { return function_key_; }
FunctionLibraryRuntime::LocalHandle local_handle() {
mutex_lock l(mu_);
const string target_device_;
FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_);
+ const string function_key_;
bool init_started_ GUARDED_BY(mu_) = false;
Status init_result_ GUARDED_BY(mu_);
Notification init_done_;