string DebugString(Handle h) override;
+ Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+ FunctionLibraryRuntime** out_flr) override;
+
private:
typedef FunctionLibraryRuntimeImpl ME;
}
}
+// Clones this function library runtime by cloning its owning
+// ProcessFunctionLibraryRuntime and returning the per-device FLR that the
+// clone registered for this runtime's device. On success, *out_flr points
+// into *out_pflr and is not separately owned by the caller.
+Status FunctionLibraryRuntimeImpl::Clone(
+ std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+ FunctionLibraryRuntime** out_flr) {
+ TF_RETURN_IF_ERROR(
+ parent_->Clone(env_, graph_def_version_, optimizer_.options(),
+ custom_kernel_creator_, out_lib_def, out_pflr));
+ // Look up the FLR for this device inside the freshly cloned runtime.
+ *out_flr = (*out_pflr)->GetFLR(device_->name());
+ // Fix: check the looked-up FLR (*out_flr), not the out-parameter pointer
+ // (out_flr), which is never null here — the old check made a failed
+ // GetFLR lookup silently return OK with a null *out_flr.
+ if (*out_flr != nullptr) {
+ return Status::OK();
+ } else {
+ return errors::Internal("Cloning FunctionLibraryRuntime failed.");
+ }
+}
+
namespace {
struct CustomCreatorSingleton {
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn = nullptr);
+ const OptimizerOptions& options() { return opts_; }
+
private:
OptimizerOptions opts_;
done(errors::Internal("Could not find device"));
}
+// Clones this ProcessFunctionLibraryRuntime into *out_pflr. The clone gets
+// its own deep copy of the function library definition (returned via
+// *out_lib_def, which must outlive *out_pflr), so functions registered in
+// one copy after the clone are not visible in the other. The device manager
+// (device_mgr_) and distributed parent (parent_) are shared with the clone
+// rather than copied — presumably both outlive it; verify against callers.
+Status ProcessFunctionLibraryRuntime::Clone(
+ Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
+ CustomKernelCreator custom_kernel_creator,
+ std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) {
+ // Deep-copy the library so the clone is independent of this instance.
+ out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
+ out_pflr->reset(new ProcessFunctionLibraryRuntime(
+ device_mgr_, env, graph_def_version, out_lib_def->get(),
+ optimizer_options, std::move(custom_kernel_creator), parent_));
+ return Status::OK();
+}
+
} // namespace tensorflow
// Removes handle from the state owned by this object.
Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
+ Status Clone(Env* env, int graph_def_version,
+ const OptimizerOptions& optimizer_options,
+ CustomKernelCreator custom_kernel_creator,
+ std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr);
+
friend class FunctionLibraryRuntimeImpl;
mutable mutex mu_;
class CancellationManager;
class GraphDef;
class OpKernel;
+class ProcessFunctionLibraryRuntime;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
virtual int graph_def_version() = 0;
typedef uint64 LocalHandle;
+
+ virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+ FunctionLibraryRuntime** out_flr) = 0;
};
// Returns a canonicalized string for the instantiation of the
{
mutex_lock l(mu_);
if (resource_ == nullptr) {
- FunctionLibraryRuntime* lib = context->function_library();
+ FunctionLibraryRuntime* lib;
std::unique_ptr<DeviceMgr> device_mgr(nullptr);
std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
// is sufficient demand, but it will require a significant refactoring.
if (!name_.empty()) {
lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
+ } else {
+ OP_REQUIRES_OK(context, context->function_library()->Clone(
+ &flib_def, &pflr, &lib));
}
ResourceMgr* mgr = context->resource_manager();
// Wrap the existing device in order to see any captured resources
// in its resource manager. The existing device will outlive the
// IteratorResource, because we are storing the IteratorResource
- // in that device's resourc manager.
+ // in that device's resource manager.
Device* wrapped_device = RenamedDevice::NewRenamedDevice(
ctx->device()->name(), down_cast<Device*>(ctx->device()),
false /* owns_underlying */, false /* isolate_session_state */);
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testImplicitDisposeParallelMapDataset(self):
+ # Tests that a parallel map dataset is cleaned up correctly when the
+ # pipeline is abandoned before being run to exhaustion (the repeat(None)
+ # below makes it infinite, so exhaustion is impossible by construction).
+ # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
+ # RepeatDataset(None) -> PrefetchDataset(10000).
+ worker, _ = test_util.create_local_cluster(1, 1)
+
+ components = (np.arange(1000),
+ np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
+ np.array(37.0) * np.arange(1000))
+
+ def _map_fn(x, y, z):
+ # Element-wise square of each of the three component tensors.
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(None).prefetch(10000))
+
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with session.Session(worker[0].target) as sess:
+ sess.run(init_op)
+ # Deliberately pull only a few elements and then drop the session,
+ # leaving the prefetch/map machinery mid-flight; the test passes if
+ # teardown does not hang or crash.
+ for _ in range(3):
+ sess.run(get_next)
+
+
if __name__ == "__main__":
test.main()