From: Akshay Agrawal Date: Mon, 26 Feb 2018 21:54:02 +0000 (-0800) Subject: Update eager uniform replay buffer microbenchmarks to compare against graph functions... X-Git-Tag: upstream/v1.7.0~105 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6c99456856973d7cfee31aeeabef8d79014a097f;p=platform%2Fupstream%2Ftensorflow.git Update eager uniform replay buffer microbenchmarks to compare against graph functions when possible. PiperOrigin-RevId: 187075418 --- diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index 3c5c55e..ab603cc 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -143,7 +143,7 @@ class CriticalSection(object): def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name """Initialize the CriticalSection from constructor arguments.""" with ops.name_scope(name, "CriticalSection", []) as name: - with ops.control_dependencies(None): + with ops.init_scope(): # pylint: disable=protected-access container = ops.get_default_graph()._container # pylint: enable=protected-access @@ -226,7 +226,9 @@ class CriticalSection(object): # mode. This is generally ok; since eager mode (as of # writing) executes sequentially anyway. for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): - if sg.handle.name == self._handle.name: + sg_handle_name = ops.convert_to_tensor(sg.handle).name + self_handle_name = ops.convert_to_tensor(self._handle).name + if sg_handle_name == self_handle_name: # Other executions in the same critical section are allowed. continue if not (exclusive_resource_access or sg.exclusive_resource_access): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 5a14ea4..b0d2704 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -4805,7 +4805,14 @@ def container(container_name): @tf_export("colocate_with") def colocate_with(op, ignore_existing=False): if context.in_graph_mode(): - return get_default_graph().colocate_with(op, ignore_existing) + default_graph = get_default_graph() + if isinstance(op, EagerTensor): + if default_graph.building_function: + op = internal_convert_to_tensor(op) + else: + raise ValueError("Encountered an Eager-defined Tensor during graph " + "construction, but a function was not being built.") + return default_graph.colocate_with(op, ignore_existing) else: if op is not None: return device(op.device)