From 72c55090f6365b8b3846b09bc749ce92bf43479a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 9 May 2018 07:27:30 -0700 Subject: [PATCH] Automated g4 rollback of changelist 195120627 PiperOrigin-RevId: 195966744 --- tensorflow/core/common_runtime/device.h | 11 +++++++++++ tensorflow/core/common_runtime/device_mgr.cc | 3 +++ .../core/common_runtime/process_function_library_runtime.cc | 3 ++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 5918cd9..b537666 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -51,6 +51,8 @@ limitations under the License. namespace tensorflow { +class DeviceMgr; + class Device : public DeviceBase { public: Device(Env* env, const DeviceAttributes& device_attributes); @@ -133,6 +135,10 @@ class Device : public DeviceBase { // Returns the resource manager associated w/ this device. virtual ResourceMgr* resource_manager() { return rmgr_; } + // Returns the device manager that owns this device, or nullptr if this Device + // is not owned by a device manager. + DeviceMgr* device_mgr() const { return device_mgr_; } + // Summarizes the status of this Device, for debugging. string DebugString() const { return ProtoDebugString(device_attributes_); } @@ -158,6 +164,11 @@ class Device : public DeviceBase { } private: + friend class DeviceMgr; + + // Pointer to the device manager that owns this device. Not owned. + DeviceMgr* device_mgr_ = nullptr; + const DeviceAttributes device_attributes_; DeviceNameUtils::ParsedName parsed_name_; diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index a77601b..470abc1 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -27,6 +27,9 @@ namespace tensorflow { DeviceMgr::DeviceMgr(const std::vector& devices) : name_backing_store_(128) { for (Device* d : devices) { + CHECK(d->device_mgr_ == nullptr); + d->device_mgr_ = this; + devices_.push_back(d); // Register under the (1) full name and (2) canonical name. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index e61ed8c..668ce87 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -144,7 +144,8 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( } Device* device = flr->device(); string device_type = device->parsed_name().type; - if (device_type == "CPU" || device_type == "TPU_SYSTEM") { + if (device_type == "CPU" || device_type == "TPU_SYSTEM" || + device_type == "TPU") { // "TPU_SYSTEM" indicates that `device` is a CPU. return Status::OK(); } -- 2.7.4