GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id);
int numa_node = dev_locality.numa_node();
- Bytes allocated_bytes = static_cast<Bytes>(memory_limit);
gpu::StreamExecutor* se =
GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
const gpu::DeviceDescription& desc = se->GetDeviceDescription();
- LOG(INFO) << "Creating TensorFlow device (" << device_name << " with "
- << (memory_limit >> 20) << " MB memory) -> physical GPU ("
- << GetShortDeviceDescription(cuda_gpu_id, desc) << ")";
ProcessState* process_state = ProcessState::singleton();
+ Allocator* gpu_allocator = process_state->GetGPUAllocator(
+ options.config.gpu_options(), tf_gpu_id, memory_limit);
+ if (gpu_allocator == nullptr) {
+ return errors::Internal("Failed to get memory allocator for TF GPU ",
+ tf_gpu_id.value(), " with ", memory_limit,
+ " bytes of memory.");
+ }
+ AllocatorStats stats;
+ gpu_allocator->GetStats(&stats);
+ // 'memory_limit' is the required memory size, but if the allocator with given
+ // tf_gpu_id was created before, we'll use it instead of creating a new one
+ // (as TF gpu device is a shared resource), in which case the actual memory
+ // limit represented by 'stats.bytes_limit' used by that allocator may be
+ // different (which should be an error).
+ //
+ // TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
BaseGPUDevice* gpu_device = CreateGPUDevice(
- options, device_name, allocated_bytes, dev_locality, tf_gpu_id,
- GetShortDeviceDescription(cuda_gpu_id, desc),
- process_state->GetGPUAllocator(options.config.gpu_options(), tf_gpu_id,
- memory_limit),
+ options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
+ tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator,
process_state->GetCPUAllocator(numa_node));
+ LOG(INFO) << "Created TensorFlow device (" << device_name << " with "
+ << (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
+ << GetShortDeviceDescription(cuda_gpu_id, desc) << ")";
TF_RETURN_IF_ERROR(gpu_device->Init(options));
devices->push_back(gpu_device);