[XLA:GPU] Implement trivial (one-replica) cross-replica-sum on XLA:GPU.
author    Justin Lebar <jlebar@google.com>
          Tue, 22 May 2018 03:41:26 +0000 (20:41 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Tue, 22 May 2018 03:43:56 +0000 (20:43 -0700)
Also fix the CPU implementation to work in the case when there are
multiple operands to the cross-replica-sum op.
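
For reference, a cross-replica-sum with replica_count == 1 is the identity
function: with a single operand it returns its input, and with multiple
operands it returns a tuple of its inputs.  A minimal HLO sketch of the
multi-operand case (this is the module exercised by the new test below):

  HloModule test
  ENTRY test_computation {
    p0 = f32[3] parameter(0)
    p1 = f32[2] parameter(1)
    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
  }

With one replica, this evaluates to the tuple (p0, p1).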

PiperOrigin-RevId: 197506311

tensorflow/compiler/xla/service/cpu/ir_emitter.cc
tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/cross_replica_sum_test.cc [new file with mode: 0644]
tensorflow/compiler/xla/tests/hlo_test_base.cc
tensorflow/compiler/xla/tests/hlo_test_base.h

diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 23fcb9c..f6c8593 100644
@@ -1186,16 +1186,45 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
 }
 
 Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
-  if (hlo_module_config_.replica_count() == 1) {
-    // When there is a single replica, a cross replica sum is the identity
-    // function, and the buffer assignment expects a copy (we could eliminate
-    // these at the HLO level as an optimization).
-    TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
+  if (hlo_module_config_.replica_count() != 1) {
+    // TODO(b/33011107): Support nontrivial cross replica sum on CPU.
+    return Unimplemented(
+        "CrossReplicaSum with >1 replica is not implemented on CPU.");
+  }
+
+  // When there is a single replica, a cross replica sum is the identity
+  // function, and the buffer assignment expects a copy.
+  //
+  // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
+  // in algebraic-simplifier, but currently on some platforms
+  // HloModuleConfig::replica_count changes between when the module is compiled
+  // and when it's run.
+  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
+
+  // CRS with one operand and one replica is simply the identity function.
+  if (crs->operand_count() == 1) {
     return EmitMemcpy(*crs->operand(0), *crs);
   }
 
-  // TODO(b/33011107): Support cross replica sum on CPU.
-  return Unimplemented("CrossReplicaSum is not implemented on CPU.");
+  // CRS with multiple operands and one replica produces a (one-deep) tuple.
+  std::vector<llvm::Value*> operand_ptrs;
+  for (int64 i = 0; i < crs->operand_count(); ++i) {
+    llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
+    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
+                        assignment_.GetUniqueSlice(crs, {i}));
+
+    const Shape& operand_shape = crs->operand(i)->shape();
+    CHECK(ShapeUtil::IsArray(operand_shape))
+        << "Operands to cross-replica-sum must be arrays: " << crs->ToString();
+    operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
+
+    // TODO(b/63762267): Be more aggressive about specifying alignment.
+    ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
+                             /*SrcAlign=*/1,
+                             ShapeUtil::ByteSizeOf(operand_shape));
+  }
+  llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_);
+  return Status::OK();
 }
 
 // Fills up the free variables in 'index_with_free_var' with values from
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 9db85bc..d956077 100644
@@ -84,8 +84,13 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
       for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
         TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
       }
-    } else if (ImplementedAsLibraryCall(*hlo)) {
-      // For all other library calls, materialize all the operands into memory.
+    } else if (ImplementedAsLibraryCall(*hlo) ||
+               hlo->opcode() == HloOpcode::kCrossReplicaSum) {
+      // For all other library calls and cross-replica-sum, materialize all the
+      // operands into memory.  (Cross-replica-sum's constant args get
+      // materialized even though it's not implemented as a libcall; this
+      // simplifies the implementation.  It's slower, but constant args can be
+      // constant-folded away anyway, so it only needs to be correct.)
       for (int64 i = 0; i < hlo->operand_count(); ++i) {
         TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
       }
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 957733f..55d4c1d 100644
@@ -1927,6 +1927,52 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
   return IrEmitter::HandleSelect(select);
 }
 
+Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
+  if (hlo_module_config_.replica_count() != 1) {
+    // TODO(b/33011107): Support nontrivial cross replica sum on GPU.
+    return Unimplemented(
+        "CrossReplicaSum with >1 replica is not implemented on GPU.");
+  }
+
+  // CRS with one operand and one replica is simply the identity function.
+  // Buffer assignment expects a copy, so that's what we do.
+  //
+  // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
+  // in algebraic-simplifier, but currently on some platforms
+  // HloModuleConfig::replica_count changes between when the module is compiled
+  // and when it's run.
+  if (crs->operand_count() == 1) {
+    CHECK(ShapeUtil::IsArray(crs->operand(0)->shape()))
+        << "Operands to cross-replica-sum must be arrays: " << crs->ToString();
+    thunk_sequence_->push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+        /*source_address=*/GetAllocationSlice(*crs->operand(0)),
+        /*destination_buffer=*/GetAllocationSlice(*crs),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
+    return Status::OK();
+  }
+
+  // One-replica CRS with multiple operands produces a tuple of the inputs.
+  // Again, buffer assignment expects us to copy each.
+  std::vector<std::unique_ptr<Thunk>> thunks;
+  std::vector<BufferAllocation::Slice> tuple_element_buffers;
+  for (int64 i = 0; i < crs->operand_count(); ++i) {
+    tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
+                                        .GetUniqueSlice(crs, {i})
+                                        .ValueOrDie());
+    thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+        /*source_address=*/GetAllocationSlice(*crs->operand(i)),
+        /*destination_buffer=*/tuple_element_buffers.back(),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs));
+  }
+
+  // Output a tuple of the buffers above.
+  thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
+                                          GetAllocationSlice(*crs), crs));
+  thunk_sequence_->push_back(
+      MakeUnique<SequentialThunk>(std::move(thunks), crs));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
   thunk_sequence_->emplace_back(BuildInfeedThunk(infeed));
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index b842f48..14780de 100644
@@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter {
   Status HandleInfeed(HloInstruction* xla_infeed) override;
   Status HandleRng(HloInstruction* random) override;
   Status HandleSelect(HloInstruction* select) override;
+  Status HandleCrossReplicaSum(HloInstruction* crs) override;
 
   Status EmitTargetElementLoop(
       const HloInstruction& hlo,
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 95acfe5..4883380 100644
@@ -1496,6 +1496,30 @@ xla_test(
 )
 
 xla_test(
+    name = "cross_replica_sum_test",
+    srcs = ["cross_replica_sum_test.cc"],
+    deps = [
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core:test",
+    ],
+)
+
+xla_test(
     name = "bitcast_convert_test",
     srcs = ["bitcast_convert_test.cc"],
     tags = [
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
new file mode 100644
index 0000000..b159887
--- /dev/null
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+
+class TrivialCrossReplicaSumTest : public HloTestBase {};
+
+// Currently the CPU and GPU backends only support CrossReplicaSum with one
+// replica, but we can at least check that the one-replica case works.
+
+XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
+  const char* module_str = R"(
+  HloModule test
+  ENTRY test_computation {
+    p = f32[3] parameter(0)
+    ROOT crs = f32[3] cross-replica-sum(p)
+  })";
+  auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+  auto literal = Literal::CreateR1<float>({1, 2, 3});
+  EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
+}
+
+XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
+  const char* module_str = R"(
+  HloModule test
+  ENTRY test_computation {
+    p0 = f32[3] parameter(0)
+    p1 = f32[2] parameter(1)
+    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+  })";
+  auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+  auto literal0 = Literal::CreateR1<float>({1, 2, 3});
+  auto literal1 = Literal::CreateR1<float>({10, 20});
+  EXPECT_EQ(
+      *Literal::MakeTuple({literal0.get(), literal1.get()}),
+      *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
+}
+
+// On the GPU backend, constants get special handling.  Someone might pass a
+// constant to CRS, e.g. to count the number of replicas, so we need to make
+// sure it works.
+XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
+  const char* module_str = R"(
+  HloModule test
+  ENTRY test_computation {
+    p0 = f32[3] parameter(0)
+    p1 = f32[2] constant({10, 20})
+    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+  })";
+  auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+  auto literal0 = Literal::CreateR1<float>({1, 2, 3});
+  auto literal1 = Literal::CreateR1<float>({10, 20});
+  EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}),
+            *ExecuteAndTransfer(std::move(module), {literal0.get()}));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index d964875..36e19e6 100644
@@ -94,9 +94,8 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
 
 /* static */
 std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
-  HloModuleConfig config;
-  config.set_debug_options(GetDebugOptionsForTest());
-  return MakeUnique<HloModule>(name, VersionedComputationHandle(), config);
+  return MakeUnique<HloModule>(name, VersionedComputationHandle(),
+                               GetModuleConfigForTest());
 }
 
 /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 9539ae0..eb3a2ea 100644
@@ -93,6 +93,13 @@ class HloTestBase : public ::testing::Test {
   // DebugOptions, e.g. when creating a module from a string or a file.
   static DebugOptions GetDebugOptionsForTest();
 
+  // Gets an HloModuleConfig with options appropriate for tests.
+  static HloModuleConfig GetModuleConfigForTest() {
+    HloModuleConfig config;
+    config.set_debug_options(GetDebugOptionsForTest());
+    return config;
+  }
+
   // Executes the given module and return the result as a Literal.
   StatusOr<std::unique_ptr<Literal>> Execute(
       std::unique_ptr<HloModule> module,