Reduce overhead for eager ops
authorAkshay Modi <nareshmodi@google.com>
Mon, 2 Apr 2018 20:17:14 +0000 (13:17 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 2 Apr 2018 20:19:59 +0000 (13:19 -0700)
- Call _context_handle in the fastpath. Fall back to slow path if it is not
  initialized.
  A better fix would be to not initialize handle and devices lazily (and not
  have to pay that function call in the slow path either), but that
  seems to break all GPU/TPU tests. I'm not as yet really familiar with how
  devices are recognized, but I'd be happy to hear any ideas you may have to
  fix this.
- context.context() is monkey patched to remove the "is None" check once we
  know the context is correctly initialized. Ideally we would be able to remove
  this function call as well.
- Maintain is_eager instead of doing the comparison every time. Also, in the
  fastpath, inline the check directly instead of paying the function call cost.
- Inline _eager_context.device_name instead of get the device_name property to
  not pay the function call cost

gen_array_ops.identity Old: 216706.923837 examples/sec (4.61452722549)
gen_array_ops.identity New: 290819.129714 examples/sec (3.43856334686)

PiperOrigin-RevId: 191336857

tensorflow/python/eager/context.py
tensorflow/python/eager/python_eager_op_gen.cc
tensorflow/python/eager/pywrap_tfe_src.cc
tensorflow/python/framework/ops.py

index 6ad9e0d..99ec895 100644 (file)
@@ -86,6 +86,7 @@ class _EagerContext(threading.local):
     self.device_spec = pydev.DeviceSpec.from_string("")
     self.device_name = self.device_spec.to_string()
     self.mode = _default_mode
+    self.is_eager = _default_mode == EAGER_MODE
     self.scope_name = ""
     self.recording_summaries = False
     self.summary_writer_resource = None
@@ -283,9 +284,12 @@ class Context(object):
 
   @tf_contextlib.contextmanager
   def _mode(self, mode):
+    """A context manager to allow setting the mode to EAGER/GRAPH."""
     ctx = self._eager_context
     old_mode = ctx.mode
+    old_is_eager = ctx.is_eager
     ctx.mode = mode
+    ctx.is_eager = mode == EAGER_MODE
     if mode == EAGER_MODE:
       # Entering graph mode does not provide us with sufficient information to
       # record a context switch; graph-based context switches are only logged
@@ -294,13 +298,14 @@ class Context(object):
     try:
       yield
     finally:
+      ctx.is_eager = old_is_eager
       ctx.mode = old_mode
       if mode == EAGER_MODE:
         self.context_switches.pop()
 
   def executing_eagerly(self):
     """Returns True if current thread has eager executing enabled."""
-    return self._eager_context.mode == EAGER_MODE
+    return self._eager_context.is_eager
 
   def scalar_cache(self):
     """Per-device cache for scalars."""
@@ -508,23 +513,19 @@ class Context(object):
     To retrieve the accumulated metadata call context.export_run_metadata()
     and to stop tracing call context.disable_run_metadata().
     """
-    if not self._context_handle:
-      self._initialize_handle_and_devices()
-    pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)
+    pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)
 
   @tf_contextlib.contextmanager
   def device_policy(self, policy):
-    if not self._context_handle:
-      self._initialize_handle_and_devices()
-    old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
-        self._context_handle)
+    handle = self._handle
+    old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(handle)
     pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
-        self._handle, policy)
+        handle, policy)
     try:
       yield
     finally:
       pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
-          self._handle, old)
+          handle, old)
 
   def disable_run_metadata(self):
     """Disables tracing of op execution via RunMetadata."""
@@ -575,6 +576,10 @@ def context():
   return _context
 
 
+def context_safe():
+  return _context
+
+
 # TODO(agarwal): remove this.
 def get_default_context():
   """Same as context."""
index c2ce8ef..0618590 100644 (file)
@@ -367,7 +367,7 @@ void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
   // Handle graph-mode case
   strings::StrAppend(&result_,
                      "  _ctx = _context.context()\n"
-                     "  if not _ctx.executing_eagerly():\n",
+                     "  if not _ctx._eager_context.is_eager:\n",
                      function_setup,
                      "    _, _, _op = _op_def_lib._apply_op_helper(\n");
   AddBodyNoReturn("        ");
@@ -712,9 +712,9 @@ bool GenEagerPythonOp::AddEagerFallbackCode(
 }
 
 void GenEagerPythonOp::AddEagerFastPathExecute() {
-  string fastpath_execute_params =
-      strings::StrCat("_ctx._handle, _ctx.device_name, \"", op_def_.name(),
-                      "\", ", "name, _ctx._post_execution_callbacks");
+  string fastpath_execute_params = strings::StrCat(
+      "_ctx._context_handle, _ctx._eager_context.device_name, \"",
+      op_def_.name(), "\", ", "name, _ctx._post_execution_callbacks");
   string fallback_params;
 
   for (int i = 0; i < api_def_.in_arg_size(); i++) {
index 8a398f6..d99bd0b 100644 (file)
@@ -1844,6 +1844,15 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
 
   op_exec_info.ctx = reinterpret_cast<TFE_Context*>(
       PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
+
+  if (op_exec_info.ctx == nullptr) {
+    // The context hasn't been initialized. It will be in the slow path.
+    RaiseFallbackException(
+        "This function does not handle the case of the path where "
+        "all inputs are not already EagerTensors.");
+    return nullptr;
+  }
+
   op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
   op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
   op_exec_info.op_def = GetOpDef(op_exec_info.op_name);
index 22b621e..c0baeb9 100644 (file)
@@ -5343,6 +5343,10 @@ def enable_eager_execution(config=None, device_policy=None,
     raise ValueError(
         "tf.enable_eager_execution must be called at program startup.")
 
+  # Monkey patch to get rid of an unnecessary conditional since the context is
+  # now initialized.
+  context.context = context.context_safe
+
 
 def eager_run(main=None, argv=None):
   """Runs the program with an optional main function and argv list.