self.device_spec = pydev.DeviceSpec.from_string("")
self.device_name = self.device_spec.to_string()
self.mode = _default_mode
+ self.is_eager = _default_mode == EAGER_MODE
self.scope_name = ""
self.recording_summaries = False
self.summary_writer_resource = None
@tf_contextlib.contextmanager
def _mode(self, mode):
+ """A context manager to allow setting the mode to EAGER/GRAPH."""
ctx = self._eager_context
old_mode = ctx.mode
+ old_is_eager = ctx.is_eager
ctx.mode = mode
+ ctx.is_eager = mode == EAGER_MODE
if mode == EAGER_MODE:
# Entering graph mode does not provide us with sufficient information to
# record a context switch; graph-based context switches are only logged
try:
yield
finally:
+ ctx.is_eager = old_is_eager
ctx.mode = old_mode
if mode == EAGER_MODE:
self.context_switches.pop()
def executing_eagerly(self):
"""Returns True if current thread has eager executing enabled."""
- return self._eager_context.mode == EAGER_MODE
+ return self._eager_context.is_eager
def scalar_cache(self):
"""Per-device cache for scalars."""
To retrieve the accumulated metadata call context.export_run_metadata()
and to stop tracing call context.disable_run_metadata().
"""
- if not self._context_handle:
- self._initialize_handle_and_devices()
- pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)
+ pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)
@tf_contextlib.contextmanager
def device_policy(self, policy):
- if not self._context_handle:
- self._initialize_handle_and_devices()
- old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
- self._context_handle)
+ handle = self._handle
+ old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(handle)
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
- self._handle, policy)
+ handle, policy)
try:
yield
finally:
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
- self._handle, old)
+ handle, old)
def disable_run_metadata(self):
"""Disables tracing of op execution via RunMetadata."""
return _context
+def context_safe():
+ return _context
+
+
# TODO(agarwal): remove this.
def get_default_context():
"""Same as context."""
// Handle graph-mode case
strings::StrAppend(&result_,
" _ctx = _context.context()\n"
- " if not _ctx.executing_eagerly():\n",
+ " if not _ctx._eager_context.is_eager:\n",
function_setup,
" _, _, _op = _op_def_lib._apply_op_helper(\n");
AddBodyNoReturn(" ");
}
void GenEagerPythonOp::AddEagerFastPathExecute() {
- string fastpath_execute_params =
- strings::StrCat("_ctx._handle, _ctx.device_name, \"", op_def_.name(),
- "\", ", "name, _ctx._post_execution_callbacks");
+ string fastpath_execute_params = strings::StrCat(
+ "_ctx._context_handle, _ctx._eager_context.device_name, \"",
+ op_def_.name(), "\", ", "name, _ctx._post_execution_callbacks");
string fallback_params;
for (int i = 0; i < api_def_.in_arg_size(); i++) {
op_exec_info.ctx = reinterpret_cast<TFE_Context*>(
PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
+
+ if (op_exec_info.ctx == nullptr) {
+ // The context hasn't been initialized. It will be in the slow path.
+ RaiseFallbackException(
+ "This function does not handle the case of the path where "
+ "all inputs are not already EagerTensors.");
+ return nullptr;
+ }
+
op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
op_exec_info.op_def = GetOpDef(op_exec_info.op_name);