namespace tensorflow {
+class DeviceMgr;
+
class Device : public DeviceBase {
public:
Device(Env* env, const DeviceAttributes& device_attributes);
// Returns the resource manager associated w/ this device.
virtual ResourceMgr* resource_manager() { return rmgr_; }
+ // Returns the device manager that owns this device, or nullptr if this Device
+ // is not owned by a device manager.
+ DeviceMgr* device_mgr() const { return device_mgr_; }
+
// Summarizes the status of this Device, for debugging.
string DebugString() const { return ProtoDebugString(device_attributes_); }
}
private:
+ friend class DeviceMgr;
+
+ // Pointer to the device manager that owns this device. Not owned.
+ DeviceMgr* device_mgr_ = nullptr;
+
const DeviceAttributes device_attributes_;
DeviceNameUtils::ParsedName parsed_name_;
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
: name_backing_store_(128) {
for (Device* d : devices) {
+ CHECK(d->device_mgr_ == nullptr);
+ d->device_mgr_ = this;
+
devices_.push_back(d);
// Register under the (1) full name and (2) canonical name.
}
Device* device = flr->device();
string device_type = device->parsed_name().type;
- if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
+ if (device_type == "CPU" || device_type == "TPU_SYSTEM" ||
+ device_type == "TPU") {
// "TPU_SYSTEM" indicates that `device` is a CPU.
return Status::OK();
}