}
Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
- if (hlo_module_config_.replica_count() == 1) {
- // When there is a single replica, a cross replica sum is the identity
- // function, and the buffer assignment expects a copy (we could eliminate
- // these at the HLO level as an optimization).
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
+ if (hlo_module_config_.replica_count() != 1) {
+ // TODO(b/33011107): Support nontrivial cross replica sum on CPU.
+ return Unimplemented(
+ "CrossReplicaSum with >1 replica is not implemented on CPU.");
+ }
+
+ // When there is a single replica, a cross replica sum is the identity
+ // function, and the buffer assignment expects a copy.
+ //
+ // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
+ // in algebraic-simplifier, but currently on some platforms
+ // HloModuleConfig::num_replicas changes between when the module is compiled
+ // and when it's run.
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
+
+ // CRS with one operand and one replica is simply the identity function.
+ if (crs->operand_count() == 1) {
return EmitMemcpy(*crs->operand(0), *crs);
}
- // TODO(b/33011107): Support cross replica sum on CPU.
- return Unimplemented("CrossReplicaSum is not implemented on CPU.");
+ // CRS with multiple operands and one replica produces a (one-deep) tuple.
+ std::vector<llvm::Value*> operand_ptrs;
+ for (int64 i = 0; i < crs->operand_count(); ++i) {
+ llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
+ TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
+ assignment_.GetUniqueSlice(crs, {i}));
+
+ const Shape& operand_shape = crs->operand(i)->shape();
+ CHECK(ShapeUtil::IsArray(operand_shape))
+ << "Operands to cross-replica-sum must be arrays: " << crs->ToString();
+ operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
+
+ // TODO(b/63762267): Be more aggressive about specifying alignment.
+ ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
+ /*SrcAlign=*/1,
+ ShapeUtil::ByteSizeOf(operand_shape));
+ }
+ llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_);
+ return Status::OK();
}
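Note: the llvm_ir::EmitTuple call above relies on XLA's tuple representation: a
tuple's buffer is simply an array of pointers, one per element, each pointing at
that element's own buffer. A minimal sketch of that core layout follows
(EmitTupleSketch is a hypothetical name, not the real helper, which also
maintains IrArray metadata):

#include <vector>

#include "llvm/IR/IRBuilder.h"

void EmitTupleSketch(llvm::Value* tuple_base,
                     const std::vector<llvm::Value*>& element_ptrs,
                     llvm::IRBuilder<>* b) {
  llvm::Type* i8_ptr_ty = b->getInt8PtrTy();
  // View the tuple buffer as an array of i8* slots.
  llvm::Value* slots =
      b->CreatePointerCast(tuple_base, i8_ptr_ty->getPointerTo());
  for (size_t i = 0; i < element_ptrs.size(); ++i) {
    // Slot i holds the address of element buffer i.
    llvm::Value* slot = b->CreateInBoundsGEP(slots, b->getInt64(i));
    b->CreateStore(b->CreatePointerCast(element_ptrs[i], i8_ptr_ty), slot);
  }
}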
// Fills up the free variables in 'index_with_free_var' with values from
for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
}
- } else if (ImplementedAsLibraryCall(*hlo)) {
- // For all other library calls, materialize all the operands into memory.
+ } else if (ImplementedAsLibraryCall(*hlo) ||
+ hlo->opcode() == HloOpcode::kCrossReplicaSum) {
+ // For all other library calls and cross-replica-sum, materialize all the
+ // operands into memory. (Cross-replica-sum gets its constant args
+ // materialized even if it's not implemented as a libcall to simplify the
+ // implementation. It's slower, but we can constant fold away constant
+ // args *anyway*, so we just need to make it work.)
for (int64 i = 0; i < hlo->operand_count(); ++i) {
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
}
return IrEmitter::HandleSelect(select);
}
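Note: copy_operand_if_constant is defined earlier in this pass (outside this
hunk). A rough sketch of what such a helper does, assuming the standard HLO
mutation APIs (hypothetical, simplified body):

  // If operand i of 'hlo' is a constant, splice in a kCopy so the libcall
  // (or cross-replica-sum) reads from an ordinary device buffer instead of
  // the constant's storage.
  auto copy_operand_if_constant = [&](int64 i) -> Status {
    HloInstruction* operand = hlo->mutable_operand(i);
    if (operand->opcode() == HloOpcode::kConstant) {
      HloInstruction* copy = hlo->parent()->AddInstruction(
          HloInstruction::CreateUnary(operand->shape(), HloOpcode::kCopy,
                                      operand));
      TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy));
    }
    return Status::OK();
  };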
+Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
+ if (hlo_module_config_.replica_count() != 1) {
+ // TODO(b/33011107): Support nontrivial cross replica sum on GPU.
+ return Unimplemented(
+ "CrossReplicaSum with >1 replica is not implemented on GPU.");
+ }
+
+ // CRS with one operand and one replica is simply the identity function.
+ // Buffer assignment expects a copy, so that's what we do.
+ //
+ // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
+ // in algebraic-simplifier, but currently on some platforms
+ // HloModuleConfig::num_replicas changes between when the module is compiled
+ // and when it's run.
+ if (crs->operand_count() == 1) {
+ CHECK(ShapeUtil::IsArray(crs->operand(0)->shape()))
+ << "Operands to cross-replica-sum must be arrays: " << crs->ToString();
+ thunk_sequence_->push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ /*source_address=*/GetAllocationSlice(*crs->operand(0)),
+ /*destination_buffer=*/GetAllocationSlice(*crs),
+ /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
+ return Status::OK();
+ }
+
+ // One-replica CRS with multiple operands produces a tuple of the inputs.
+ // Again, buffer assignment expects us to copy each.
+ std::vector<std::unique_ptr<Thunk>> thunks;
+ std::vector<BufferAllocation::Slice> tuple_element_buffers;
+ for (int64 i = 0; i < crs->operand_count(); ++i) {
+ tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
+ .GetUniqueSlice(crs, {i})
+ .ValueOrDie());
+ thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ /*source_address=*/GetAllocationSlice(*crs->operand(i)),
+ /*destination_buffer=*/tuple_element_buffers.back(),
+ /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs));
+ }
+
+ // Output a tuple of the buffers above.
+ thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
+ GetAllocationSlice(*crs), crs));
+ thunk_sequence_->push_back(
+ MakeUnique<SequentialThunk>(std::move(thunks), crs));
+ return Status::OK();
+}
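Note: at execution time each thunk enqueues work on the stream. Roughly what
the thunks used above amount to (hypothetical, simplified class names; the real
implementations also handle profiling annotations):

Status CopyThunkSketch::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream) {
  // Resolve the slices recorded at compile time to runtime device addresses,
  // then enqueue an async device-to-device copy.
  se::DeviceMemoryBase destination =
      buffer_allocations.GetDeviceAddress(destination_buffer_);
  se::DeviceMemoryBase source =
      buffer_allocations.GetDeviceAddress(source_address_);
  stream->ThenMemcpyD2D(&destination, source, mem_size_);
  return Status::OK();
}

Status SequentialThunkSketch::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream) {
  // Run the child thunks (here: one copy per operand, then the tuple) in
  // order on the same stream.
  for (const auto& thunk : thunks_) {
    TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
  }
  return Status::OK();
}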
+
Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
thunk_sequence_->emplace_back(BuildInfeedThunk(infeed));
return Status::OK();
Status HandleInfeed(HloInstruction* xla_infeed) override;
Status HandleRng(HloInstruction* random) override;
Status HandleSelect(HloInstruction* select) override;
+ Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status EmitTargetElementLoop(
const HloInstruction& hlo,
)
xla_test(
+ name = "cross_replica_sum_test",
+ srcs = ["cross_replica_sum_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
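The xla_test macro generates one test target per enabled backend; assuming the
usual per-backend suffix convention, the new test can be run with, e.g.:

  bazel test //tensorflow/compiler/xla/tests:cross_replica_sum_test_cpu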
+
+xla_test(
name = "bitcast_convert_test",
srcs = ["bitcast_convert_test.cc"],
tags = [
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+
+class TrivialCrossReplicaSumTest : public HloTestBase {};
+
+// Currently the CPU and GPU backends only support CrossReplicaSum with one
+// replica. But we can at least check this.
+
+XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
+ const char* module_str = R"(
+ HloModule test
+ ENTRY test_computation {
+ p = f32[3] parameter(0)
+ ROOT crs = f32[3] cross-replica-sum(p)
+ })";
+ auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto literal = Literal::CreateR1<float>({1, 2, 3});
+ EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
+}
+
+XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
+ const char* module_str = R"(
+ HloModule test
+ ENTRY test_computation {
+ p0 = f32[3] parameter(0)
+ p1 = f32[2] parameter(1)
+ ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+ })";
+ auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto literal0 = Literal::CreateR1<float>({1, 2, 3});
+ auto literal1 = Literal::CreateR1<float>({10, 20});
+ EXPECT_EQ(
+ *Literal::MakeTuple({literal0.get(), literal1.get()}),
+ *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
+}
+
+// On the GPU backend, constants get special handling. Someone might pass a
+// constant to CRS to e.g. count the number of replicas -- we need to make sure
+// it works. (Because p1 is a constant rather than a parameter, the module
+// below takes a single parameter, so execution passes only literal0.)
+XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
+ const char* module_str = R"(
+ HloModule test
+ ENTRY test_computation {
+ p0 = f32[3] parameter(0)
+ p1 = f32[2] constant({10, 20})
+ ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+ })";
+ auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto literal0 = Literal::CreateR1<float>({1, 2, 3});
+ auto literal1 = Literal::CreateR1<float>({10, 20});
+ EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}),
+ *ExecuteAndTransfer(std::move(module), {literal0.get()}));
+}
+
+} // namespace
+} // namespace xla
/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
- HloModuleConfig config;
- config.set_debug_options(GetDebugOptionsForTest());
- return MakeUnique<HloModule>(name, VersionedComputationHandle(), config);
+ return MakeUnique<HloModule>(name, VersionedComputationHandle(),
+ GetModuleConfigForTest());
}
/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
// DebugOptions, e.g. when creating a module from a string or a file.
static DebugOptions GetDebugOptionsForTest();
+ // Gets an HloModuleConfig with options appropriate for tests.
+ static HloModuleConfig GetModuleConfigForTest() {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ return config;
+ }
+
// Executes the given module and return the result as a Literal.
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,