Update eager uniform replay buffer microbenchmarks to compare against graph functions...
authorAkshay Agrawal <akshayka@google.com>
Mon, 26 Feb 2018 21:54:02 +0000 (13:54 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187075418

tensorflow/contrib/framework/python/ops/critical_section_ops.py
tensorflow/python/framework/ops.py

index 3c5c55e..ab603cc 100644 (file)
@@ -143,7 +143,7 @@ class CriticalSection(object):
   def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
     """Initialize the CriticalSection from constructor arguments."""
     with ops.name_scope(name, "CriticalSection", []) as name:
-      with ops.control_dependencies(None):
+      with ops.init_scope():
         # pylint: disable=protected-access
         container = ops.get_default_graph()._container
         # pylint: enable=protected-access
@@ -226,7 +226,9 @@ class CriticalSection(object):
         # mode.  This is generally ok; since eager mode (as of
         # writing) executes sequentially anyway.
         for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
-          if sg.handle.name == self._handle.name:
+          sg_handle_name = ops.convert_to_tensor(sg.handle).name
+          self_handle_name = ops.convert_to_tensor(self._handle).name
+          if sg_handle_name == self_handle_name:
             # Other executions in the same critical section are allowed.
             continue
           if not (exclusive_resource_access or sg.exclusive_resource_access):
index 5a14ea4..b0d2704 100644 (file)
@@ -4805,7 +4805,14 @@ def container(container_name):
 @tf_export("colocate_with")
 def colocate_with(op, ignore_existing=False):
   if context.in_graph_mode():
-    return get_default_graph().colocate_with(op, ignore_existing)
+    default_graph = get_default_graph()
+    if isinstance(op, EagerTensor):
+      if default_graph.building_function:
+        op = internal_convert_to_tensor(op)
+      else:
+        raise ValueError("Encountered an Eager-defined Tensor during graph "
+                         "construction, but a function was not being built.")
+    return default_graph.colocate_with(op, ignore_existing)
   else:
     if op is not None:
       return device(op.device)