Moves TFE_Context to common_runtime
authorAlexandre Passos <apassos@google.com>
Thu, 22 Mar 2018 01:22:36 +0000 (18:22 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 01:25:12 +0000 (18:25 -0700)
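Moves the state previously kept inline in TFE_Context (devices, function
library, kernel cache, placement policies, RunMetadata, and the
EagerExecutor) into a new tensorflow::EagerContext class under
tensorflow/core/common_runtime/eager, leaving the C API struct as a thin
wrapper around it.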
PiperOrigin-RevId: 190001737

tensorflow/c/eager/BUILD
tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api_internal.h
tensorflow/core/common_runtime/eager/BUILD
tensorflow/core/common_runtime/eager/context.cc [new file with mode: 0644]
tensorflow/core/common_runtime/eager/context.h [new file with mode: 0644]

diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 841ff48..bea5a12 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -28,6 +28,7 @@ tf_cuda_library(
             "//tensorflow/c:c_api",
             "//tensorflow/c:c_api_internal",
             "//tensorflow/core:core_cpu",
+            "//tensorflow/core/common_runtime/eager:context",
             "//tensorflow/core/common_runtime/eager:eager_executor",
             "//tensorflow/core/common_runtime/eager:kernel_and_device",
             "//tensorflow/core:core_cpu_internal",
@@ -64,6 +65,7 @@ tf_cuda_library(
         "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/common_runtime/eager:context",
         "//tensorflow/core/common_runtime/eager:eager_executor",
         "//tensorflow/core/common_runtime/eager:kernel_and_device",
     ],
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index a23015c..5d66884 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -71,18 +71,6 @@ std::atomic_int_fast64_t func_id_generator(0);
 
 }  // namespace
 
-TFE_ContextDevicePlacementPolicy PlacementPolicy(
-    bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy) {
-  if (!soft_placement) {
-    return original_policy;
-  }
-  if (original_policy == TFE_DEVICE_PLACEMENT_EXPLICIT ||
-      original_policy == TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) {
-    return TFE_DEVICE_PLACEMENT_SILENT;
-  }
-  return original_policy;
-}
-
 extern "C" {
 
 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
@@ -104,19 +92,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
 TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
                                                         unsigned char async,
                                                         TF_Status* status) {
-  {
-    tensorflow::mutex_lock l(ctx->async_map_mu);
-    ctx->thread_local_async[std::this_thread::get_id()] = async;
-  }
-  if (async) {
-    ctx->executor.EnableAsync();
-  } else {
-    // TODO(agarwal): Currently we add a wait here to handle cases where a sync
-    // op has a control dependency on an async op, and the latter has not
-    // executed yet. This wait can be removed by storing all the control inputs
-    // and waiting for them when executing ops.
-    status->status = ctx->executor.WaitForAllPendingNodes();
-  }
+  status->status = ctx->context.SetAsyncForThread(async);
 }
 
 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
@@ -133,34 +109,26 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
       new tensorflow::DeviceMgr(devices));
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
-  return new TFE_Context(*opts, std::move(device_mgr), r);
+  return new TFE_Context(opts->session_options.options, opts->policy,
+                         opts->async, std::move(device_mgr), r);
 }
 
 void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
-  status->status = ctx->executor.WaitForAllPendingNodes();
-  {
-    tensorflow::mutex_lock ml(ctx->cache_mu);
-    tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
-  }
-  ctx->rendezvous->Unref();
   delete ctx;
 }
 
 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
   TF_DeviceList* list = new TF_DeviceList;
-  ctx->device_manager->ListDeviceAttributes(&list->response);
+  ctx->context.device_mgr()->ListDeviceAttributes(&list->response);
   return list;
 }
 
-void TFE_ContextClearCaches(TFE_Context* ctx) {
-  tensorflow::mutex_lock ml(ctx->cache_mu);
-  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
-}
+void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
 
 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
     TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
-  tensorflow::mutex_lock ml(ctx->policy_map_mu);
-  ctx->thread_local_policies[std::this_thread::get_id()] = policy;
+  ctx->context.SetThreadLocalDevicePlacementPolicy(
+      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
 }
 
 // Note: this function looks up a thread local policy. So it should be called in
@@ -168,25 +136,20 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
 // safe to call this function from the async EagerExecutor threads.
 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
     TFE_Context* ctx) {
-  tensorflow::mutex_lock ml(ctx->policy_map_mu);
-  auto policy_map_it =
-      ctx->thread_local_policies.find(std::this_thread::get_id());
-  if (policy_map_it != ctx->thread_local_policies.end()) {
-    return policy_map_it->second;
-  }
-  return ctx->policy;
+  return static_cast<TFE_ContextDevicePlacementPolicy>(
+      ctx->context.GetDevicePlacementPolicy());
 }
 
 void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
-  status->status = ctx->executor.WaitForAllPendingNodes();
+  status->status = ctx->context.AsyncWait();
 }
 
 void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
-  status->status = ctx->executor.status();
+  status->status = ctx->context.GetStatus();
 }
 
 void TFE_ContextAsyncClearError(TFE_Context* ctx) {
-  ctx->executor.ClearError();
+  ctx->context.ClearAsyncError();
 }
 
 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
@@ -259,7 +222,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
   // nullptr.
   tensorflow::Device* src_opd = nullptr;
   TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd));
-  if (srcd == nullptr) srcd = ctx->devices[0];
+  if (srcd == nullptr) srcd = ctx->context.HostCPU();
   bool is_same_device =
       (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
   const bool dst_cpu = IsCPU(dstd);
@@ -332,8 +295,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
   status->status = tensorflow::AttrTypeMapForOp(name, &types);
   if (status->status.ok()) return new TFE_Op(ctx, name, types);
   if (TF_GetCode(status) == TF_NOT_FOUND) {
-    tensorflow::mutex_lock l(ctx->functions_mu);
-    if (ctx->func_lib_def.Find(name) != nullptr) {
+    if (ctx->context.FindFunctionByName(name)) {
       status->status = tensorflow::Status::OK();
       return new TFE_Op(ctx, name, nullptr);
     }
@@ -346,20 +308,14 @@ void TFE_DeleteOp(TFE_Op* op) { delete op; }
 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
   tensorflow::Device* d = nullptr;
   if (device_name != nullptr && strlen(device_name) > 0) {
-    auto it = op->ctx->devices_map.find(device_name);
-    if (it == op->ctx->devices_map.end()) {
-      status->status =
-          tensorflow::errors::InvalidArgument(device_name, " unknown device.");
-      return;
-    }
-    d = it->second;
+    status->status = op->ctx->context.FindDeviceByName(device_name, &d);
   }
   op->device = d;
 }
 
 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
   tensorflow::Device* device =
-      (op->device == nullptr) ? op->ctx->devices[0] : op->device;
+      (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device;
   return device->name().c_str();
 }
 
@@ -634,7 +590,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
 tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
                                  TFE_Context* ctx, TF_Status* status) {
   tensorflow::DeviceSet ds;
-  for (tensorflow::Device* d : ctx->devices) {
+  for (tensorflow::Device* d : *ctx->context.devices()) {
     ds.AddDevice(d);
   }
   tensorflow::DeviceTypeVector final_devices;
@@ -648,7 +604,7 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
         "Could not find valid device for node ", ndef.DebugString());
     return nullptr;
   }
-  for (tensorflow::Device* d : ctx->devices) {
+  for (tensorflow::Device* d : *ctx->context.devices()) {
     if (d->device_type() == final_devices[0].type_string()) {
       return d;
     }
@@ -663,9 +619,8 @@ tensorflow::Status Execute(
     const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs,
     tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
     TFE_TensorHandle** retvals, int num_retvals) {
-  if (!ctx->soft_placement && device == nullptr) {
-    // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
-    device = ctx->devices[0];
+  if (!ctx->context.SoftPlacement() && device == nullptr) {
+    device = ctx->context.HostCPU();
   }
 
   if (device == nullptr) {
@@ -697,18 +652,18 @@ tensorflow::Status Execute(
   if (maybe_stats != nullptr) {
     maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
                                        maybe_stats->all_start_micros());
-    tensorflow::mutex_lock ml(ctx->metadata_mu);
-    if (ctx->should_store_metadata.load()) {
-      auto* step_stats = ctx->run_metadata.mutable_step_stats();
+    tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
+    if (ctx->context.ShouldStoreMetadata()) {
+      auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats();
       // Lazily initialize the RunMetadata with information about all devices if
       // this is the first call.
-      while (step_stats->dev_stats_size() < ctx->devices.size()) {
+      while (step_stats->dev_stats_size() < ctx->context.devices()->size()) {
         step_stats->add_dev_stats();
       }
       // Find the current device's index.
       int device_idx = 0;
-      for (int i = 0; i < ctx->devices.size(); ++i) {
-        if (ctx->devices[i] == device) {
+      for (int i = 0; i < ctx->context.devices()->size(); ++i) {
+        if (ctx->context.devices()->at(i) == device) {
           device_idx = i;
           break;
         }
@@ -744,7 +699,7 @@ class ExecuteNode : public tensorflow::EagerNode {
               tensorflow::NodeExecStats* maybe_stats,
               const tensorflow::DataTypeVector& output_dtypes,
               TFE_TensorHandle** retvals, int num_retvals)
-      : tensorflow::EagerNode(op->ctx->executor.NextId()),
+      : tensorflow::EagerNode(op->ctx->context.NextId()),
         ctx_(op->ctx),
         op_device_(op->device),
         inputs_(op->inputs),
@@ -800,7 +755,7 @@ class CopyToDeviceNode : public tensorflow::EagerNode {
  public:
   CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd,
                    TFE_Context* ctx)
-      : tensorflow::EagerNode(ctx->executor.NextId()),
+      : tensorflow::EagerNode(ctx->context.NextId()),
         src_(src),
         dstd_(dstd),
         ctx_(ctx),
@@ -1063,7 +1018,7 @@ extern "C" {
 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                  TF_Status* status) {
   TFE_Context* ctx = op->ctx;
-  status->status = ctx->executor.status();
+  status->status = ctx->context.GetStatus();
   if (!status->status.ok()) {
     return;
   }
@@ -1087,7 +1042,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
         input_op_device != op->device) {
       tensorflow::Device* d =
-          input_op_device == nullptr ? ctx->devices[0] : input_op_device;
+          input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device;
       VLOG(1) << "Changing device of operation " << op->name << " to "
               << d->name() << " because input #" << i
               << " is a resource in this device.";
@@ -1095,40 +1050,35 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     }
   }
   tensorflow::Device* device = op->device;
-  if (!ctx->soft_placement && device == nullptr) {
-    // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
-    device = ctx->devices[0];
+  if (!ctx->context.SoftPlacement() && device == nullptr) {
+    device = ctx->context.HostCPU();
   }
 
   tensorflow::Fprint128 cache_key =
       op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
-  tensorflow::KernelAndDevice* kernel;
-  {
-    tensorflow::tf_shared_lock l(ctx->cache_mu);
-    kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
-  }
+  tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key);
   if (kernel == nullptr) {
     const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
-    if (ctx->soft_placement && device == nullptr) {
+    if (ctx->context.SoftPlacement() && device == nullptr) {
       device = SelectDevice(ndef, ctx, status);
       if (!status->status.ok()) {
         return;
       }
     }
     CHECK(device != nullptr);
-    if (ctx->log_device_placement) {
+    if (ctx->context.LogDevicePlacement()) {
       LOG(INFO) << "Executing op " << ndef.op() << " in device "
                 << device->name();
     }
-    kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
+    kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous());
     // Knowledge of the implementation of Init (and in-turn
     // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
     // will be accessed, so grab on to the lock.
     // See WARNING comment in Execute (before kernel->Run) - would be nice to
     // rework to avoid this subtlety.
-    tensorflow::tf_shared_lock l(ctx->functions_mu);
-    status->status =
-        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
+    tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu());
+    status->status = tensorflow::KernelAndDevice::Init(
+        ndef, ctx->context.func_lib(device), kernel);
     if (!status->status.ok()) {
       delete kernel;
       return;
@@ -1136,7 +1086,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     // Update output_dtypes inside `kernel`.
     const tensorflow::OpDef* op_def = nullptr;
     const tensorflow::FunctionDef* function_def =
-        ctx->func_lib_def.Find(ndef.op());
+        ctx->context.FuncLibDef()->Find(ndef.op());
     if (function_def != nullptr) {
       op_def = &(function_def->signature());
     }
@@ -1152,8 +1102,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     if (!status->status.ok()) {
       return;
     }
-    tensorflow::mutex_lock ml(ctx->cache_mu);
-    tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
+    ctx->context.AddKernelToCache(cache_key, kernel);
   }
   const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
   const int output_dtypes_size = output_dtypes.size();
@@ -1171,11 +1120,11 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     // device from the one requested above.
     device = kernel->device();
   }
-  status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device,
-                                                 op, kernel->kernel());
+  status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(),
+                                                 device, op, kernel->kernel());
   if (!status->status.ok()) return;
   std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
-  if (ctx->should_store_metadata.load()) {
+  if (ctx->context.ShouldStoreMetadata()) {
     maybe_stats.reset(new tensorflow::NodeExecStats);
     maybe_stats->set_node_name(op->name);
     maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
@@ -1183,14 +1132,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
     // TODO(apassos) track referenced tensors
   }
-  if (ctx->Async()) {
+  if (ctx->context.Async()) {
     // Note that for async mode, execution order will make sure that all
     // input handles are ready before executing them.
     // TODO(agarwal): Consider executing "cheap" kernels inline for performance.
     tensorflow::EagerNode* node =
         new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes,
                         retvals, *num_retvals);
-    ctx->executor.Add(node);
+    ctx->context.ExecutorAdd(node);
   } else {
     // Execute checks if retvals[i] is nullptr or not to figure if it needs to
     // allocate it.
@@ -1206,23 +1155,24 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                                TFE_Context* ctx,
                                                const char* device_name,
                                                TF_Status* status) {
-  status->status = ctx->executor.status();
+  status->status = ctx->context.GetStatus();
   if (!status->status.ok()) {
     return nullptr;
   }
-  tensorflow::Device* dstd = ctx->devices[0];
+  tensorflow::Device* dstd = ctx->context.HostCPU();
   if (device_name != nullptr && strlen(device_name) > 0) {
-    status->status = ctx->device_manager->LookupDevice(device_name, &dstd);
+    status->status =
+        ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
     if (!status->status.ok()) return nullptr;
   }
-  if (ctx->Async()) {
+  if (ctx->context.Async()) {
     // Note that `h` may not be currently ready. However execution order will
     // make sure that `h` is ready before the copy is actually done.
     CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
     TFE_TensorHandle* output = node->dst();
     // Note that calling Add makes `node` accessible by the EagerExecutor
     // thread. So further accesses need to be thread-safe.
-    ctx->executor.Add(node);
+    ctx->context.ExecutorAdd(node);
     return output;
   } else {
     TFE_TensorHandle* output = nullptr;
@@ -1240,24 +1190,20 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
     return;
   }
-  tensorflow::mutex_lock l(ctx->functions_mu);
-  status->status = ctx->func_lib_def.AddFunctionDef(function_def);
+  status->status = ctx->context.AddFunctionDef(function_def);
 }
 
 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                             TF_Status* status) {
-  tensorflow::mutex_lock l(ctx->functions_mu);
-  status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
+  status->status = ctx->context.AddFunctionDef(function->fdef);
 }
 
 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
-  ctx->should_store_metadata.store(true);
+  ctx->context.SetShouldStoreMetadata(true);
 }
 
 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
-  tensorflow::mutex_lock ml(ctx->metadata_mu);
-  ctx->should_store_metadata.store(false);
-  ctx->run_metadata.Clear();
+  ctx->context.SetShouldStoreMetadata(false);
 }
 
 }  // extern "C"
@@ -1286,9 +1232,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                   TF_Status* status) {
   TFE_ContextAsyncWait(ctx, status);
   if (!status->status.ok()) return;
-  tensorflow::mutex_lock ml(ctx->metadata_mu);
-  status->status = MessageToBuffer(ctx->run_metadata, buf);
-  ctx->run_metadata.Clear();
+  tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
+  status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
+  ctx->context.RunMetadataProto()->Clear();
 }
 
 namespace {
@@ -1363,11 +1309,6 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
 }  // namespace tensorflow
 
 
-bool TFE_Context::Async() const {
-  tensorflow::mutex_lock l(async_map_mu);
-  return tensorflow::gtl::FindWithDefault(
-      thread_local_async, std::this_thread::get_id(), async_default);
-}
 
 bool TFE_TensorHandle::IsReady() {
   if (node_id == 0) return true;
@@ -1381,7 +1322,7 @@ tensorflow::Status TFE_TensorHandle::WaitReady() {
   {
     tensorflow::mutex_lock l(ctx_mutex_);
     if (ctx_ == nullptr) return tensorflow::Status::OK();
-    executor = &ctx_->executor;
+    executor = ctx_->context.Executor();
   }
   return executor->WaitFor(node_id);
 }
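
For orientation, here is a minimal sketch (not part of the patch) of a C-API
caller driving the entry points rewired above. Every function used appears in
this diff; the helper name and the elided error handling are illustrative
only:

    #include "tensorflow/c/eager/c_api.h"

    void Sketch() {  // hypothetical helper
      TF_Status* status = TF_NewStatus();
      TFE_ContextOptions* opts = TFE_NewContextOptions();
      TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                                 TFE_DEVICE_PLACEMENT_SILENT);
      TFE_Context* ctx = TFE_NewContext(opts, status);
      TFE_DeleteContextOptions(opts);
      // Now forwarded to tensorflow::EagerContext::SetAsyncForThread.
      TFE_ContextSetAsyncForThread(ctx, /*async=*/1, status);
      // ... create ops with TFE_NewOp and run them with TFE_Execute ...
      TFE_ContextAsyncWait(ctx, status);  // EagerContext::AsyncWait
      TFE_DeleteContext(ctx, status);     // ~EagerContext does the cleanup
      TF_DeleteStatus(status);
    }
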
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index a79f8dd..5b29120 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/runtime.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -52,85 +53,18 @@ struct TFE_ContextOptions {
       TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
 };
 
-TFE_ContextDevicePlacementPolicy PlacementPolicy(
-    bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy);
-
 struct TFE_Context {
-  explicit TFE_Context(const TFE_ContextOptions& opts,
+  explicit TFE_Context(const tensorflow::SessionOptions& opts,
+                       TFE_ContextDevicePlacementPolicy default_policy,
+                       bool async,
                        std::unique_ptr<tensorflow::DeviceMgr> device_mgr,
                        tensorflow::Rendezvous* rendezvous)
-      : soft_placement(
-            opts.session_options.options.config.allow_soft_placement()),
-        policy(PlacementPolicy(soft_placement, opts.policy)),
-        device_manager(std::move(device_mgr)),
-        devices(device_manager->ListDevices()),
-        rendezvous(rendezvous),
-        pflr(new tensorflow::ProcessFunctionLibraryRuntime(
-            device_manager.get(), opts.session_options.options.env,
-            TF_GRAPH_DEF_VERSION, &func_lib_def, {})),
-        log_device_placement(
-            opts.session_options.options.config.log_device_placement()),
-        async_default(opts.async) {
-    if (async_default) executor.EnableAsync();
-
-    for (auto* device : devices) {
-      devices_map[tensorflow::StringPiece(device->name())] = device;
-    }
-  }
-
-  const bool soft_placement;
-  const TFE_ContextDevicePlacementPolicy policy;
-
-  // Note: we cannot use C++11 thread_local here as there is no concept of a
-  // thread-local-object-local variable in C++11.
-  tensorflow::mutex policy_map_mu;
-  std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy>
-      thread_local_policies GUARDED_BY(policy_map_mu);
-
-  std::unique_ptr<tensorflow::DeviceMgr> device_manager;
-  // Devices owned by device_manager
-  std::vector<tensorflow::Device*> devices;
-  // All devices are not owned.
-  tensorflow::gtl::FlatMap<tensorflow::StringPiece, tensorflow::Device*,
-                           tensorflow::StringPieceHasher>
-      devices_map;
-  tensorflow::Rendezvous* const rendezvous;
-
-  tensorflow::mutex functions_mu;
-  tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
-      tensorflow::OpRegistry::Global(), {}};
-
-  // One FunctionLibraryRuntime per device.
-  // func_libs[i] is the FunctionLibraryRuntime corresponding to
-  // session->devices[i].
-  const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
-
-  tensorflow::mutex cache_mu;
-  std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
-                     tensorflow::Fprint128Hasher>
-      kernel_cache GUARDED_BY(cache_mu);
-
-  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const {
-    return pflr->GetFLR(d->name());
-  }
+      : context(opts,
+                static_cast<tensorflow::ContextDevicePlacementPolicy>(
+                    default_policy),
+                async, std::move(device_mgr), rendezvous) {}
 
-  // Whether we should compute RunMetadata.
-  std::atomic<bool> should_store_metadata{false};
-  tensorflow::mutex metadata_mu;
-  tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu);
-  const bool log_device_placement;
-  // EagerExecutor for async execution.
-  tensorflow::EagerExecutor executor;
-
-  // True if running in asynchronous mode.
-  bool Async() const;
-
-  // True if the default value for execution mode is async. Note that this value
-  // can be overridden per thread based on `thread_local_async` overrides.
-  const bool async_default;
-  mutable tensorflow::mutex async_map_mu;
-  std::unordered_map<std::thread::id, bool> thread_local_async
-      GUARDED_BY(async_map_mu);
+  tensorflow::EagerContext context;
 };
 
 struct TFE_TensorHandle : public tensorflow::core::RefCounted {
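
The wrapper above static_casts between TFE_ContextDevicePlacementPolicy and
tensorflow::ContextDevicePlacementPolicy, which is only sound while the two
enums agree value for value. A sketch of pinning that invariant down with
compile-time checks; this is an assumption of this note, not code in the
patch:

    static_assert(static_cast<int>(tensorflow::DEVICE_PLACEMENT_EXPLICIT) ==
                      static_cast<int>(TFE_DEVICE_PLACEMENT_EXPLICIT),
                  "placement policy enums out of sync");
    static_assert(static_cast<int>(tensorflow::DEVICE_PLACEMENT_WARN) ==
                      static_cast<int>(TFE_DEVICE_PLACEMENT_WARN),
                  "placement policy enums out of sync");
    static_assert(static_cast<int>(tensorflow::DEVICE_PLACEMENT_SILENT) ==
                      static_cast<int>(TFE_DEVICE_PLACEMENT_SILENT),
                  "placement policy enums out of sync");
    static_assert(
        static_cast<int>(tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32) ==
            static_cast<int>(TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32),
        "placement policy enums out of sync");
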
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 8ba560b..de10b10 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -33,6 +33,28 @@ tf_cuda_library(
 )
 
 tf_cuda_library(
+    name = "context",
+    srcs = [
+        "context.cc",
+    ],
+    hdrs = [
+        "context.h",
+    ],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":eager_executor",
+        ":kernel_and_device",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:session_options",
+    ],
+)
+
+tf_cuda_library(
     name = "kernel_and_device",
     srcs = [
         "kernel_and_device.cc",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
new file mode 100644
index 0000000..5e8d083
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+
+namespace tensorflow {
+
+ContextDevicePlacementPolicy PlacementPolicy(
+    bool soft_placement, ContextDevicePlacementPolicy original_policy) {
+  if (!soft_placement) {
+    return original_policy;
+  }
+  if (original_policy == DEVICE_PLACEMENT_EXPLICIT ||
+      original_policy == DEVICE_PLACEMENT_SILENT_FOR_INT32) {
+    return DEVICE_PLACEMENT_SILENT;
+  }
+  return original_policy;
+}
+
+EagerContext::EagerContext(const SessionOptions& opts,
+                           ContextDevicePlacementPolicy default_policy,
+                           bool async, std::unique_ptr<DeviceMgr> device_mgr,
+                           Rendezvous* rendezvous)
+    : soft_placement_(opts.config.allow_soft_placement()),
+      policy_(PlacementPolicy(soft_placement_, default_policy)),
+      device_manager_(std::move(device_mgr)),
+      devices_(device_manager_->ListDevices()),
+      rendezvous_(rendezvous),
+      pflr_(new ProcessFunctionLibraryRuntime(device_manager_.get(), opts.env,
+                                              TF_GRAPH_DEF_VERSION,
+                                              &func_lib_def_, {})),
+      log_device_placement_(opts.config.log_device_placement()),
+      async_default_(async) {
+  if (async_default_) {
+    executor_.EnableAsync();
+  }
+
+  for (auto* device : devices_) {
+    devices_map_[device->name()] = device;
+  }
+}
+
+bool EagerContext::Async() const {
+  mutex_lock l(async_map_mu_);
+  return gtl::FindWithDefault(thread_local_async_, std::this_thread::get_id(),
+                              async_default_);
+}
+
+Status EagerContext::SetAsyncForThread(bool async) {
+  {
+    tensorflow::mutex_lock l(async_map_mu_);
+    thread_local_async_[std::this_thread::get_id()] = async;
+  }
+  if (async) {
+    executor_.EnableAsync();
+  } else {
+    // TODO(agarwal): Currently we add a wait here to handle cases where a
+    // sync op has a control dependency on an async op, and the latter has not
+    // executed yet. This wait can be removed by storing all the control
+    // inputs and waiting for them when executing ops.
+    return executor_.WaitForAllPendingNodes();
+  }
+  return Status::OK();
+}
+
+void EagerContext::ClearCaches() {
+  mutex_lock ml(cache_mu_);
+  gtl::STLDeleteValues(&kernel_cache_);
+}
+
+void EagerContext::SetThreadLocalDevicePlacementPolicy(
+    ContextDevicePlacementPolicy policy) {
+  mutex_lock ml(policy_map_mu_);
+  thread_local_policies_[std::this_thread::get_id()] = policy;
+}
+
+ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
+  mutex_lock ml(policy_map_mu_);
+  auto policy_map_it = thread_local_policies_.find(std::this_thread::get_id());
+  if (policy_map_it != thread_local_policies_.end()) {
+    return policy_map_it->second;
+  }
+  return policy_;
+}
+
+EagerContext::~EagerContext() {
+  executor_.WaitForAllPendingNodes().IgnoreError();
+  ClearCaches();
+  rendezvous_->Unref();
+}
+
+bool EagerContext::FindFunctionByName(const string& name) {
+  mutex_lock l(functions_mu_);
+  return func_lib_def_.Find(name) != nullptr;
+}
+
+Status EagerContext::FindDeviceByName(const string& name, Device** result) {
+  auto it = devices_map_.find(name);
+  if (it == devices_map_.end()) {
+    return errors::InvalidArgument(name, " unknown device.");
+  }
+  *result = it->second;
+  return Status::OK();
+}
+
+Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
+  mutex_lock l(functions_mu_);
+  return func_lib_def_.AddFunctionDef(fdef);
+}
+
+KernelAndDevice* EagerContext::GetCachedKernel(Fprint128 cache_key) {
+  tf_shared_lock l(cache_mu_);
+  return gtl::FindPtrOrNull(kernel_cache_, cache_key);
+}
+
+void EagerContext::AddKernelToCache(Fprint128 cache_key,
+                                    KernelAndDevice* kernel) {
+  mutex_lock ml(cache_mu_);
+  gtl::InsertOrUpdate(&kernel_cache_, cache_key, kernel);
+}
+
+void EagerContext::SetShouldStoreMetadata(bool value) {
+  should_store_metadata_.store(value);
+  if (!value) {
+    mutex_lock ml(metadata_mu_);
+    run_metadata_.Clear();
+  }
+}
+
+}  // namespace tensorflow
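
A minimal sketch of constructing the new class directly, mirroring the
TFE_NewContext body in c_api.cc above; the function name is hypothetical and
the include list is abbreviated:

    #include "tensorflow/core/common_runtime/device_factory.h"
    #include "tensorflow/core/common_runtime/eager/context.h"
    #include "tensorflow/core/common_runtime/rendezvous_mgr.h"

    void EagerContextSketch() {  // hypothetical
      tensorflow::SessionOptions options;
      std::vector<tensorflow::Device*> devices;
      TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
          options, "/job:localhost/replica:0/task:0", &devices));
      std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
          new tensorflow::DeviceMgr(devices));
      tensorflow::Rendezvous* r =
          new tensorflow::IntraProcessRendezvous(device_mgr.get());
      // The context takes ownership of device_mgr; its destructor waits for
      // pending nodes, clears the kernel cache, and unrefs the rendezvous.
      tensorflow::EagerContext ctx(options,
                                   tensorflow::DEVICE_PLACEMENT_SILENT,
                                   /*async=*/false, std::move(device_mgr), r);
      CHECK(ctx.HostCPU() != nullptr);  // devices_[0]; assumed to be CPU
    }
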
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
new file mode 100644
index 0000000..d525d44
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -0,0 +1,193 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+// Note: there is a copy of this enum in eager/c_api.h; keep the two in sync.
+enum ContextDevicePlacementPolicy {
+  // Running operations with input tensors on the wrong device will fail. When
+  // soft placement is enabled, this acts like DEVICE_PLACEMENT_SILENT.
+  DEVICE_PLACEMENT_EXPLICIT = 0,
+  // Copy the tensor to the right device but log a warning.
+  DEVICE_PLACEMENT_WARN = 1,
+  // Silently copy the tensor, which has a performance cost since the
+  // operation will be blocked till the copy completes.
+  DEVICE_PLACEMENT_SILENT = 2,
+  // Default placement policy, which silently copies int32 tensors but not
+  // other dtypes. When soft placement is enabled, this acts like
+  // DEVICE_PLACEMENT_SILENT.
+  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
+};
+
+ContextDevicePlacementPolicy PlacementPolicy(
+    bool soft_placement, ContextDevicePlacementPolicy original_policy);
+
+class EagerContext {
+ public:
+  explicit EagerContext(const SessionOptions& opts,
+                        ContextDevicePlacementPolicy default_policy, bool async,
+                        std::unique_ptr<DeviceMgr> device_mgr,
+                        Rendezvous* rendezvous);
+
+  ~EagerContext();
+
+  // Returns the function library runtime for the given device.
+  FunctionLibraryRuntime* func_lib(Device* d) const {
+    return pflr_->GetFLR(d->name());
+  }
+
+  // True if running in asynchronous mode.
+  bool Async() const;
+
+  EagerExecutor* Executor() { return &executor_; }
+
+  // Sets whether this thread should run in synchronous or asynchronous mode.
+  Status SetAsyncForThread(bool async);
+
+  // TODO(apassos) make this return a constant reference
+  gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
+    return &devices_map_;
+  }
+
+  // TODO(apassos) make this return a constant reference
+  std::vector<Device*>* devices() { return &devices_; }
+
+  // Clears the kernel caches.
+  void ClearCaches();
+
+  // Sets the device placement policy for the current thread.
+  void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);
+
+  // Returns the device placement policy for the current thread.
+  ContextDevicePlacementPolicy GetDevicePlacementPolicy();
+
+  Status AsyncWait() { return executor_.WaitForAllPendingNodes(); }
+
+  Status GetStatus() { return executor_.status(); }
+
+  void ClearAsyncError() { executor_.ClearError(); }
+
+  bool FindFunctionByName(const string& name);
+
+  Status FindDeviceByName(const string& name, Device** result);
+
+  Device* HostCPU() { return devices_[0]; }
+
+  bool SoftPlacement() { return soft_placement_; }
+
+  uint64 NextId() { return executor_.NextId(); }
+
+  void ExecutorAdd(EagerNode* node) { executor_.Add(node); }
+
+  Status AddFunctionDef(const FunctionDef& fdef);
+
+  KernelAndDevice* GetCachedKernel(Fprint128 cache_key);
+
+  void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
+
+  bool LogDevicePlacement() { return log_device_placement_; }
+
+  Rendezvous* GetRendezvous() { return rendezvous_; }
+
+  mutex* FunctionsMu() { return &functions_mu_; }
+
+  tensorflow::DeviceMgr* device_mgr() { return device_manager_.get(); }
+
+  // TODO(apassos) remove the need for this
+  void ReleaseDeviceMgr() { device_manager_.release(); }
+
+  // TODO(apassos) clean up RunMetadata storage.
+  mutex* MetadataMu() { return &metadata_mu_; }
+  bool ShouldStoreMetadata() { return should_store_metadata_.load(); }
+  void SetShouldStoreMetadata(bool value);
+  RunMetadata* RunMetadataProto() { return &run_metadata_; }
+
+  FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
+
+ private:
+  const bool soft_placement_;
+  const ContextDevicePlacementPolicy policy_;
+
+  // Note: we cannot use C++11 thread_local here as there is no concept of a
+  // thread-local-object-local variable in C++11.
+  mutex policy_map_mu_;
+  std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
+      thread_local_policies_ GUARDED_BY(policy_map_mu_);
+
+  std::unique_ptr<DeviceMgr> device_manager_;
+  // Devices owned by device_manager_.
+  std::vector<Device*> devices_;
+  // Maps device names to devices; the pointers are not owned.
+  gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
+  Rendezvous* const rendezvous_;
+
+  mutex functions_mu_;
+  FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
+      OpRegistry::Global(), {}};
+
+  // One FunctionLibraryRuntime per device.
+  // pflr_->GetFLR(device_name) returns the FunctionLibraryRuntime for the
+  // corresponding device; see func_lib() above.
+  const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+
+  mutex cache_mu_;
+  std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_
+      GUARDED_BY(cache_mu_);
+
+  // Whether we should compute RunMetadata.
+  std::atomic<bool> should_store_metadata_{false};
+  mutex metadata_mu_;
+  RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
+  const bool log_device_placement_;
+  // EagerExecutor for async execution.
+  EagerExecutor executor_;
+
+  // True if the default execution mode is async. Note that this value can
+  // be overridden per thread via `thread_local_async_`.
+  const bool async_default_;
+  mutable mutex async_map_mu_;
+  std::unordered_map<std::thread::id, bool> thread_local_async_
+      GUARDED_BY(async_map_mu_);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
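
Illustrative only (assumes a context built with async=false, as in the sketch
after context.cc above): the async override is recorded per thread in
thread_local_async_, with async_default_ as the fallback, so toggling one
worker thread does not affect any other:

    #include <thread>

    void PerThreadAsyncSketch(tensorflow::EagerContext* ctx) {  // hypothetical
      std::thread worker([ctx] {
        TF_CHECK_OK(ctx->SetAsyncForThread(true));
        CHECK(ctx->Async());  // override is visible on this thread only
      });
      worker.join();
      CHECK(!ctx->Async());  // this thread still sees async_default_ == false
    }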