From d0f9424e22eb438f3d846fa62feaf331797e62c4 Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Wed, 30 May 2018 18:43:40 -0700 Subject: [PATCH] Automated g4 rollback of changelist 195379693 PiperOrigin-RevId: 198654780 --- .../compiler/xla/service/hlo_module_group_metadata.cc | 7 +++++++ tensorflow/compiler/xla/service/hlo_module_group_metadata.h | 3 +++ tensorflow/compiler/xla/service/service.cc | 13 ++++++++++--- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 7d706b5..f6fa45a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -247,6 +247,13 @@ tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( return device; } +int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { + return std::count_if(modules_.begin(), modules_.end(), + [](const HloModule* module) { + return !module->config().is_host_module(); + }); +} + Status HloModuleGroupMetadata::RecordInstructions() { const auto visitor = [this](HloInstruction* hlo) -> Status { if (hlo->opcode() == HloOpcode::kWhile) { diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 5f5bf27..f68d402 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -155,6 +155,9 @@ class HloModuleGroupMetadata { tensorflow::gtl::optional GetInstructionDevice( const HloInstruction& instruction) const; + // Returns the number of modules for devices (excluding the host module). + int64 GetDeviceModulesCount() const; + // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index cb0f76e..5a813dc 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -624,9 +624,16 @@ Service::ExecuteParallelAndRegisterResult( // profiled. std::map index_to_profiled_streams; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), executables.size())); + // Build DeviceAssignment for all cores based on the provided device handles. + DeviceAssignment device_assignment(options_.number_of_replicas(), + executables.size()); + for (int64 i = 0; i < executables.size(); i++) { + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, i) = replicas[replica]->device_ordinal(); + } + } for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. -- 2.7.4