From 3af99b657f23e52d9c291d488fa3bb2a68e90022 Mon Sep 17 00:00:00 2001
From: Guangda Lai
Date: Mon, 26 Feb 2018 10:59:54 -0800
Subject: [PATCH] Automated g4 rollback of changelist 185324160

PiperOrigin-RevId: 187048135
---
 tensorflow/contrib/cmake/tf_core_cpu.cmake      |   7 ++
 tensorflow/contrib/makefile/Makefile            |   1 +
 .../core/common_runtime/gpu/gpu_id_manager.cc   |  50 +++++++++--
 .../core/common_runtime/gpu/gpu_id_manager.h    |  14 ++-
 tensorflow/core/grappler/clusters/BUILD         |  26 +++++-
 .../core/grappler/clusters/single_machine.cc    |  17 +++-
 tensorflow/core/grappler/clusters/utils.cc      |  71 +++++++++------
 tensorflow/core/grappler/clusters/utils.h       |   3 +-
 tensorflow/core/grappler/clusters/utils_test.cc | 100 +++++++++++++++++++++
 tensorflow/core/grappler/costs/BUILD            |   1 +
 tensorflow/core/grappler/costs/utils.cc         |  18 +++-
 11 files changed, 262 insertions(+), 46 deletions(-)
 create mode 100644 tensorflow/core/grappler/clusters/utils_test.cc

diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake
index 96ac60d..a54cbff 100644
--- a/tensorflow/contrib/cmake/tf_core_cpu.cmake
+++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake
@@ -63,6 +63,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs
     "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
     "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc"
 )
+file(GLOB_RECURSE tf_core_cpu_whitelisted_srcs
+    "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.h"
+    "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.cc"
+    "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc"
+)
+list(REMOVE_ITEM tf_core_cpu_exclude_srcs ${tf_core_cpu_whitelisted_srcs})
 list(REMOVE_ITEM tf_core_cpu_srcs ${tf_core_cpu_exclude_srcs})
 
 if (tensorflow_ENABLE_GPU)
@@ -79,6 +85,7 @@ if (tensorflow_ENABLE_GPU)
     "${tensorflow_source_dir}/tensorflow/core/*test*.cc"
   )
   list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_gpu_exclude_srcs})
+  list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_cpu_whitelisted_srcs})
   list(APPEND tf_core_cpu_srcs ${tf_core_gpu_srcs})
 endif()
 
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index 8132740..05e8d90 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -677,6 +677,7 @@ endif  # TEGRA
 TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
 # Add in any extra files that don't fit the patterns easily
 TF_CC_SRCS += tensorflow/contrib/makefile/downloads/fft2d/fftsg.c
+TF_CC_SRCS += tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
 # Also include the op and kernel definitions.
 TF_CC_SRCS += $(shell cat $(MAKEFILE_DIR)/tf_op_files.txt)
 PBT_CC_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_pb_text_files.txt)
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
index 207afdc..7dfff32 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
@@ -18,7 +18,10 @@ limitations under the License.
 #include <unordered_map>
 
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
 
 namespace tensorflow {
@@ -27,8 +30,8 @@ namespace {
 class TfToCudaGpuIdMap {
  public:
   static TfToCudaGpuIdMap* singleton() {
-    static auto* manager = new TfToCudaGpuIdMap;
-    return manager;
+    static auto* id_map = new TfToCudaGpuIdMap;
+    return id_map;
   }
 
   void InsertOrDie(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id)
@@ -47,18 +50,41 @@ class TfToCudaGpuIdMap {
     }
   }
 
-  int32 FindOrDie(TfGpuId tf_gpu_id) const LOCKS_EXCLUDED(mu_) {
+  CudaGpuId FindOrDie(TfGpuId tf_gpu_id) const LOCKS_EXCLUDED(mu_) {
     mutex_lock lock(mu_);
+    return FindOrDieLocked(tf_gpu_id);
+  }
+
+  bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const
+      LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    if (id_map_.count(tf_gpu_id.value()) == 0) return false;
+    *cuda_gpu_id = FindOrDieLocked(tf_gpu_id);
+    return true;
+  }
+
+ private:
+  TfToCudaGpuIdMap() = default;
+
+  CudaGpuId FindOrDieLocked(TfGpuId tf_gpu_id) const
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     auto result = id_map_.find(tf_gpu_id.value());
     CHECK(result != id_map_.end())
         << "Could not find the mapping for TfGpuId: " << tf_gpu_id;
-    return result->second;
+    return CudaGpuId(result->second);
+  }
+
+  void TestOnlyReset() LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    id_map_.clear();
   }
 
- private:
   using IdMapType = std::unordered_map<int32, int32>;
   mutable mutex mu_;
   IdMapType id_map_ GUARDED_BY(mu_);
+
+  friend class ::tensorflow::GpuIdManager;
+  TF_DISALLOW_COPY_AND_ASSIGN(TfToCudaGpuIdMap);
 };
 
 }  // namespace
@@ -67,8 +93,20 @@ void GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id,
   TfToCudaGpuIdMap::singleton()->InsertOrDie(tf_gpu_id, cuda_gpu_id);
 }
 
+Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
+  if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) {
+    return Status::OK();
+  }
+  return errors::NotFound("TF GPU device with id ", tf_gpu_id.value(),
+                          " was not registered");
+}
+
 CudaGpuId GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id) {
-  return CudaGpuId(TfToCudaGpuIdMap::singleton()->FindOrDie(tf_gpu_id));
+  return TfToCudaGpuIdMap::singleton()->FindOrDie(tf_gpu_id);
+}
+
+void GpuIdManager::TestOnlyReset() {
+  TfToCudaGpuIdMap::singleton()->TestOnlyReset();
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
index 33925d8..2b54cc1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
@@ -17,15 +17,25 @@ limitations under the License.
 #define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_MANAGER_H_
 
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 
-// Class that manages the translation between Tensorflow GPU ids and CUDA GPU
-// ids.
+// Class that maintains a map from TfGpuId to CudaGpuId, and manages the
+// translation between them.
 class GpuIdManager {
  public:
+  // Adds a mapping from tf_gpu_id to cuda_gpu_id.
   static void InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id);
+
+  // Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found.
+  static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id);
+  // Similar to the above version, but returns the result directly, and
+  // CHECK-fails if no result is found.
   static CudaGpuId TfToCudaGpuId(TfGpuId tf_gpu_id);
+
+  // Clears the map. Used in unit tests only.
+  static void TestOnlyReset();
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index b8f8e13..b653f90 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -1,7 +1,12 @@
 licenses(["notice"])  # Apache 2.0
 
+load("//tensorflow:tensorflow.bzl", "if_cuda")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
 
 filegroup(
     name = "all_files",
@@ -26,13 +31,12 @@ config_setting(
 tf_cuda_library(
     name = "utils",
     srcs = ["utils.cc"],
-    hdrs = [
-        "utils.h",
-    ],
+    hdrs = ["utils.h"],
     visibility = ["//visibility:public"],
     deps = [
         "//third_party/eigen3",
         "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_id",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
     ] + select({
@@ -41,6 +45,21 @@ tf_cuda_library(
     }),
 )
 
+tf_cc_test(
+    name = "utils_test",
+    srcs = ["utils_test.cc"],
+    linkstatic = if_cuda(1, 0),
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":utils",
+        "//tensorflow/core:gpu_id",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 cc_library(
     name = "cluster",
     srcs = ["cluster.cc"],
@@ -104,6 +123,7 @@ cc_library(
         "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:direct_session",
         "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_id",
         "//tensorflow/core:lib",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/kernels:ops_util",
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index cc7f418..8e236c9 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/cc/training/queue_runner.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/kernels/ops_util.h" @@ -80,13 +82,24 @@ Status SingleMachine::Provision() { std::vector devices; TF_RETURN_IF_ERROR(session_->ListDevices(&devices)); - int gpu_id = 0; for (const auto& dev : devices) { DeviceProperties attr; if (dev.device_type() == "CPU") { attr = GetLocalCPUInfo(); } else if (dev.device_type() == "GPU") { - attr = GetLocalGPUInfo(gpu_id++); + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(dev.name(), &parsed)) { + return errors::InvalidArgument( + strings::StrCat("Not able to parse GPU device name: ", dev.name())); + } + TfGpuId tf_gpu_id(parsed.id); + CudaGpuId cuda_gpu_id; + Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + if (!s.ok()) { + return errors::Unavailable("Unknown TF GPU device with id ", + tf_gpu_id.value(), ": ", s.ToString()); + } + attr = GetLocalGPUInfo(cuda_gpu_id); } else if (dev.device_type().find("XLA") == string::npos) { // Filter out the fake XLA devices to avoid double counting the actual // hardware resources that are available. diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index 607e10e..b54b349 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -27,6 +27,9 @@ limitations under the License. #include "include/libxsmm.h" #endif +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" @@ -66,36 +69,40 @@ DeviceProperties GetLocalCPUInfo() { return device; } -DeviceProperties GetLocalGPUInfo(int gpu_id) { +DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id) { DeviceProperties device; device.set_type("GPU"); #if GOOGLE_CUDA cudaDeviceProp properties; - cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id); - if (error == cudaSuccess) { - device.set_vendor("NVidia"); - device.set_model(properties.name); - device.set_frequency(properties.clockRate * 1e-3); - device.set_num_cores(properties.multiProcessorCount); - device.set_num_registers(properties.regsPerMultiprocessor); - // For compute capability less than 5, l1 cache size is configurable to - // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For - // compute capability larger or equal to 5, l1 cache (unified with texture - // cache) size is 24 KB. This number may need to be updated for future - // compute capabilities. - device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024); - device.set_l2_cache_size(properties.l2CacheSize); - device.set_l3_cache_size(0); - device.set_shared_memory_size_per_multiprocessor( - properties.sharedMemPerMultiprocessor); - device.set_memory_size(properties.totalGlobalMem); - // 8 is the number of bits per byte. 2 is accounted for - // double data rate (DDR). 
-    device.set_bandwidth(properties.memoryBusWidth / 8 *
-                         properties.memoryClockRate * 2);
+  cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value());
+  if (error != cudaSuccess) {
+    device.set_type("UNKNOWN");
+    LOG(ERROR) << "Failed to get device properties, error code: " << error;
+    return device;
   }
+  device.set_vendor("NVIDIA");
+  device.set_model(properties.name);
+  device.set_frequency(properties.clockRate * 1e-3);
+  device.set_num_cores(properties.multiProcessorCount);
+  device.set_num_registers(properties.regsPerMultiprocessor);
+  // For compute capability less than 5, l1 cache size is configurable to
+  // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For
+  // compute capability larger or equal to 5, l1 cache (unified with texture
+  // cache) size is 24 KB. This number may need to be updated for future
+  // compute capabilities.
+  device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024);
+  device.set_l2_cache_size(properties.l2CacheSize);
+  device.set_l3_cache_size(0);
+  device.set_shared_memory_size_per_multiprocessor(
+      properties.sharedMemPerMultiprocessor);
+  device.set_memory_size(properties.totalGlobalMem);
+  // 8 is the number of bits per byte. 2 is accounted for
+  // double data rate (DDR).
+  device.set_bandwidth(properties.memoryBusWidth / 8 *
+                       properties.memoryClockRate * 2);
+
   (*device.mutable_environment())["architecture"] =
       strings::StrCat(properties.major, ".", properties.minor);
   (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
@@ -106,18 +113,26 @@ DeviceProperties GetLocalGPUInfo(int gpu_id) {
 }
 
 DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) {
+  DeviceProperties unknown;
+  unknown.set_type("UNKNOWN");
+
   if (device.type == "CPU") {
     return GetLocalCPUInfo();
   } else if (device.type == "GPU") {
     if (device.has_id) {
-      return GetLocalGPUInfo(device.id);
+      TfGpuId tf_gpu_id(device.id);
+      CudaGpuId cuda_gpu_id;
+      Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+      if (!s.ok()) {
+        LOG(ERROR) << s;
+        return unknown;
+      }
+      return GetLocalGPUInfo(cuda_gpu_id);
     } else {
-      return GetLocalGPUInfo(0);
+      return GetLocalGPUInfo(CudaGpuId(0));
     }
   }
-  DeviceProperties result;
-  result.set_type("UNKNOWN");
-  return result;
+  return unknown;
 }
 
 }  // end namespace grappler
diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h
index 1919420..df8e7dc 100644
--- a/tensorflow/core/grappler/clusters/utils.h
+++ b/tensorflow/core/grappler/clusters/utils.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_
 #define TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_
 
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
 #include "tensorflow/core/protobuf/device_properties.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
@@ -27,7 +28,7 @@ DeviceProperties GetLocalCPUInfo();
 
 // Returns the DeviceProperties for the specified GPU attached to the server on
 // which grappler is running.
-DeviceProperties GetLocalGPUInfo(int gpu_id);
+DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id);
 
 // Returns the DeviceProperties of the specified device
 DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device);
diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc
new file mode 100644
index 0000000..74218ad
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/utils_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/clusters/utils.h"
+
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/device_properties.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(UtilsTest, GetLocalGPUInfo) {
+  GpuIdManager::TestOnlyReset();
+#if GOOGLE_CUDA
+  LOG(INFO) << "CUDA is enabled.";
+  DeviceProperties properties;
+
+  // Invalid CUDA GPU id.
+  properties = GetLocalGPUInfo(CudaGpuId(100));
+  EXPECT_EQ("UNKNOWN", properties.type());
+
+  // Succeed when a valid CUDA GPU id is passed.
+  properties = GetLocalGPUInfo(CudaGpuId(0));
+  EXPECT_EQ("GPU", properties.type());
+  EXPECT_EQ("NVIDIA", properties.vendor());
+#else
+  LOG(INFO) << "CUDA is not enabled.";
+  DeviceProperties properties;
+
+  properties = GetLocalGPUInfo(CudaGpuId(0));
+  EXPECT_EQ("GPU", properties.type());
+
+  properties = GetLocalGPUInfo(CudaGpuId(100));
+  EXPECT_EQ("GPU", properties.type());
+#endif
+}
+
+TEST(UtilsTest, GetDeviceInfo) {
+  GpuIdManager::TestOnlyReset();
+  DeviceNameUtils::ParsedName device;
+  DeviceProperties properties;
+
+  // Invalid type.
+  properties = GetDeviceInfo(device);
+  EXPECT_EQ("UNKNOWN", properties.type());
+
+  // CPU info.
+  device.type = "CPU";
+  properties = GetDeviceInfo(device);
+  EXPECT_EQ("CPU", properties.type());
+
+  // No TF GPU id provided.
+  device.type = "GPU";
+  device.has_id = false;
+  properties = GetDeviceInfo(device);
+  EXPECT_EQ("GPU", properties.type());
+#if GOOGLE_CUDA
+  EXPECT_EQ("NVIDIA", properties.vendor());
+#endif
+
+  // TF to CUDA GPU id mapping entry doesn't exist.
+  device.has_id = true;
+  device.id = 0;
+  properties = GetDeviceInfo(device);
+  EXPECT_EQ("UNKNOWN", properties.type());
+
+#if GOOGLE_CUDA
+  // Invalid CUDA GPU id.
+  GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(0), CudaGpuId(100));
+  properties = GetDeviceInfo(device);
+  EXPECT_EQ("UNKNOWN", properties.type());
+
+  // Valid CUDA GPU id.
+  GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(1), CudaGpuId(0));
+  device.id = 1;
+  properties = GetDeviceInfo(device);
+  EXPECT_EQ("GPU", properties.type());
+  EXPECT_EQ("NVIDIA", properties.vendor());
+#endif
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 0fe01e9..5336df1 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -142,6 +142,7 @@ tf_cuda_library(
         "//third_party/eigen3",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
+        "//tensorflow/core:gpu_id",
        "//tensorflow/core:lib",
         "//tensorflow/core:lib_proto_parsing",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 602f69f..076945d 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -26,6 +26,8 @@ limitations under the License.
 #include "cuda/include/cudnn.h"
 #endif
 
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/op.h"
@@ -200,17 +202,25 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
 }
 
 DeviceProperties GetDeviceInfo(const string& device_str) {
+  DeviceProperties unknown;
+  unknown.set_type("UNKNOWN");
+
   DeviceNameUtils::ParsedName parsed;
   if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
     if (parsed.type == "GPU") {
-      return GetLocalGPUInfo(parsed.id);
+      TfGpuId tf_gpu_id(parsed.id);
+      CudaGpuId cuda_gpu_id;
+      Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+      if (!s.ok()) {
+        LOG(ERROR) << s;
+        return unknown;
+      }
+      return GetLocalGPUInfo(cuda_gpu_id);
     } else if (parsed.type == "CPU") {
       return GetLocalCPUInfo();
    }
   }
-  DeviceProperties device;
-  device.set_type("UNKNOWN");
-  return device;
+  return unknown;
 }
 
 DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
-- 
2.7.4
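
Editor's note (outside the patch proper; `git am` ignores everything after the
"-- " signature line above): the behavioral core of this rollback is that
grappler now resolves a TF-visible GPU id to a physical CUDA GPU id through
GpuIdManager instead of assuming the two coincide. Below is a minimal sketch
of the caller-side pattern the patch installs; LookupGpu is a hypothetical
helper name, while TfGpuId, CudaGpuId, GpuIdManager::TfToCudaGpuId,
grappler::GetLocalGPUInfo, and DeviceProperties all come from the patch
itself.

    #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
    #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
    #include "tensorflow/core/grappler/clusters/utils.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {

    // Hypothetical helper mirroring the lookup pattern used in
    // SingleMachine::Provision() and the two GetDeviceInfo() overloads.
    DeviceProperties LookupGpu(int tf_device_ordinal) {
      TfGpuId tf_gpu_id(tf_device_ordinal);
      CudaGpuId cuda_gpu_id;
      // Status-returning overload added by this patch: a missing mapping
      // comes back as errors::NotFound instead of CHECK-failing the process.
      Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
      if (!s.ok()) {
        LOG(ERROR) << s;
        DeviceProperties unknown;
        unknown.set_type("UNKNOWN");
        return unknown;
      }
      // Only the resolved CUDA id is meaningful to cudaGetDeviceProperties().
      return grappler::GetLocalGPUInfo(cuda_gpu_id);
    }

    }  // namespace tensorflow

The CHECK-failing overload, static CudaGpuId TfToCudaGpuId(TfGpuId), remains
for callers that can guarantee the mapping was registered; grappler's
read-only consumers take the recoverable path above.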