From 7ace7f14caf81c9acbac2e3ba26a754cbe78ead5 Mon Sep 17 00:00:00 2001
From: Guangda Lai <laigd@google.com>
Date: Fri, 9 Feb 2018 22:47:30 -0800
Subject: [PATCH] Fix grappler to use CudaGpuId instead of TfGpuId to query
 device states.

PiperOrigin-RevId: 185233116
---
 tensorflow/core/grappler/clusters/BUILD           |  6 +++---
 .../core/grappler/clusters/single_machine.cc      |  9 +++++++--
 tensorflow/core/grappler/clusters/utils.cc        | 15 +++++++++++----
 tensorflow/core/grappler/clusters/utils.h         |  3 ++-
 tensorflow/core/grappler/costs/BUILD              |  1 +
 tensorflow/core/grappler/costs/utils.cc           |  3 ++-
 6 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index 5b8ce373bc..b15a709c5b 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -26,13 +26,12 @@ config_setting(
 tf_cuda_library(
     name = "utils",
     srcs = ["utils.cc"],
-    hdrs = [
-        "utils.h",
-    ],
+    hdrs = ["utils.h"],
     visibility = ["//visibility:public"],
     deps = [
         "//third_party/eigen3",
         "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_id",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
     ] + select({
@@ -104,6 +103,7 @@ cc_library(
         "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:direct_session",
         "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_id",
         "//tensorflow/core:lib",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/kernels:ops_util",
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index 862ce4ae88..3e97b31f2c 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/cc/training/queue_runner.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
 #include "tensorflow/core/grappler/clusters/utils.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/kernels/ops_util.h"
@@ -79,13 +80,17 @@ Status SingleMachine::Provision() {
 
   std::vector<DeviceAttributes> devices;
   TF_RETURN_IF_ERROR(session_->ListDevices(&devices));
-  int gpu_id = 0;
   for (const auto& dev : devices) {
     DeviceProperties attr;
     if (dev.device_type() == "CPU") {
       attr = GetLocalCPUInfo();
     } else if (dev.device_type() == "GPU") {
-      attr = GetLocalGPUInfo(gpu_id++);
+      DeviceNameUtils::ParsedName parsed;
+      if (!DeviceNameUtils::ParseFullName(dev.name(), &parsed)) {
+        return errors::InvalidArgument(
+            strings::StrCat("Not able to parse GPU device name: ", dev.name()));
+      }
+      attr = GetLocalGPUInfo(TfGpuId(parsed.id));
     } else {
       attr.set_type(dev.device_type());
     }
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
index aacd2ccb72..3e7a7a3356 100644
--- a/tensorflow/core/grappler/clusters/utils.cc
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -27,6 +27,8 @@ limitations under the License.
 #include "include/libxsmm.h"
 #endif
 
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/cpu_info.h"
@@ -66,13 +68,14 @@ DeviceProperties GetLocalCPUInfo() {
   return device;
 }
 
-DeviceProperties GetLocalGPUInfo(int gpu_id) {
+DeviceProperties GetLocalGPUInfo(TfGpuId tf_gpu_id) {
   DeviceProperties device;
   device.set_type("GPU");
 
 #if GOOGLE_CUDA
   cudaDeviceProp properties;
-  cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id);
+  CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id);
+  cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value());
   if (error == cudaSuccess) {
     device.set_vendor("NVidia");
     device.set_model(properties.name);
@@ -94,6 +97,10 @@ DeviceProperties GetLocalGPUInfo(int gpu_id) {
     // double data rate (DDR).
     device.set_bandwidth(properties.memoryBusWidth / 8 *
                          properties.memoryClockRate * 2);
+  } else {
+    LOG(ERROR) << "Failed to get device properties, error code: " << error;
+    device.set_type("UNKNOWN");
+    return device;
   }
 
   (*device.mutable_environment())["architecture"] =
@@ -110,9 +117,9 @@ DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) {
     return GetLocalCPUInfo();
   } else if (device.type == "GPU") {
     if (device.has_id) {
-      return GetLocalGPUInfo(device.id);
+      return GetLocalGPUInfo(TfGpuId(device.id));
     } else {
-      return GetLocalGPUInfo(0);
+      return GetLocalGPUInfo(TfGpuId(0));
     }
   }
   DeviceProperties result;
diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h
index 191942040a..4ea7e98390 100644
--- a/tensorflow/core/grappler/clusters/utils.h
+++ b/tensorflow/core/grappler/clusters/utils.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_
 #define TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_
 
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
 #include "tensorflow/core/protobuf/device_properties.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
@@ -27,7 +28,7 @@ DeviceProperties GetLocalCPUInfo();
 
 // Returns the DeviceProperties for the specified GPU attached to the server on
 // which grappler is running.
-DeviceProperties GetLocalGPUInfo(int gpu_id);
+DeviceProperties GetLocalGPUInfo(TfGpuId tf_gpu_id);
 
 // Returns the DeviceProperties of the specified device
 DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device);
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 0fe01e9c9e..5336df1f51 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -142,6 +142,7 @@ tf_cuda_library(
         "//third_party/eigen3",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
+        "//tensorflow/core:gpu_id",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_proto_parsing",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 602f69f12e..ac30090607 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "cuda/include/cudnn.h"
 #endif
 
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/op.h"
@@ -203,7 +204,7 @@ DeviceProperties GetDeviceInfo(const string& device_str) {
   DeviceNameUtils::ParsedName parsed;
   if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
     if (parsed.type == "GPU") {
-      return GetLocalGPUInfo(parsed.id);
+      return GetLocalGPUInfo(TfGpuId(parsed.id));
     } else if (parsed.type == "CPU") {
       return GetLocalCPUInfo();
     }
-- 
2.34.1