From 9a7a63aff142658db6d54027815a54a267be808a Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Thu, 29 Mar 2018 10:30:31 -0700
Subject: [PATCH] [XLA:GPU] Assume that tuple sub-buffers are available at runtime.

Previously we assumed this was not the case, and allowed front-ends to pass
in a pointer to a tuple without also passing in pointers to its sub-buffers.

This mostly worked: Whenever we wanted a tuple sub-buffer, we'd just chase
the tuple's pointers in our emitted kernel. But this doesn't work if we ever
need a pointer to that sub-buffer on the host, which we do if e.g. the
sub-buffer is an input to a cudnn call.

There are various ways to make this work, but by far the simplest and most
efficient is simply to specify away this problem, and say that the front-end
*must* give us all the pointers we want. This is what the earlier change,
"Assert that all buffers and sub-buffers passed to XLA have an explicit
pointer", did.

This change adds a testcase and lets us skip some pointer chasing when we
have a tuple whose sub-buffers are known statically.

PiperOrigin-RevId: 190949743
---
 .../xla/service/gpu/ir_emitter_unnested.cc   | 108 ++++++++-------------
 .../compiler/xla/tests/dot_operation_test.cc |  19 ++++
 2 files changed, 60 insertions(+), 67 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 199e6b7..d29cc21 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -145,37 +145,6 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
       llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
 }
 
-// Tries to get a Slice for the given instruction at the given index, but
-// returns nullopt if we might not know the slice's address at runtime without
-// dereferencing a containing tuple.
-//
-// In particular, when XLA accepts a parameter of tuple type, the caller has
-// the option of telling XLA what are the values inside of the tuple, or just
-// giving XLA a pointer to the top-level tuple and letting us chase the
-// pointers on the GPU. We therefore cannot rely on having these pointers to
-// parameter sub-buffers being present when we run the program.
-optional<BufferAllocation::Slice> GetKnownAtRuntimeSlice(
-    const HloInstruction* instr, const ShapeIndex& index,
-    const BufferAssignment& buffer_assn) {
-  auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index);
-  if (!maybe_slice.ok()) {
-    return nullopt;
-  }
-  // BufferAllocation gives a slice and alloc to every buffer accessed by XLA,
-  // but we don't necessarily know the runtime address of sub-buffers of input
-  // parameters.
-  const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie();
-  const BufferAllocation* alloc = slice.allocation();
-  if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() &&
-      !alloc->param_shape_index().empty()) {
-    return nullopt;
-  }
-
-  // Otherwise, we will know the address of this slice at runtime without
-  // having to dereference a tuple.
-  return slice;
-}
-
 }  // namespace
 
 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
@@ -206,7 +175,7 @@ bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment,
   return hlo.opcode() == HloOpcode::kCopy &&
          hlo.operand(0)->opcode() == HloOpcode::kConstant &&
          ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
-         GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value();
+         buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok();
 }
 
 bool ImplementedAsDeviceToDeviceMemcpy(
@@ -216,13 +185,13 @@ bool ImplementedAsDeviceToDeviceMemcpy(
   //
   // 1. `hlo` is a kCopy instruction.
   // 2. `hlo` and its operand have the same shape (thus the same layout too).
-  // 3. The operand to `hlo` has a buffer assignment (constants do not, for
-  //    instance) which means the source buffer also resides on the device.
+  // 3. `hlo` and its operand have a statically-known buffer assignment
+  //    (constants do not, for instance), which means the source buffer also
+  //    resides on the device.
   return hlo.opcode() == HloOpcode::kCopy &&
          ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
-         GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() &&
-         GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment)
-             .has_value();
+         buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() &&
+         buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok();
 }
 
 }  // namespace
@@ -1959,49 +1928,54 @@ GetHloBufferSlices(const HloInstruction* hlo,
       -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
     // Simple, common case: Is the buffer for instr known at runtime? If so,
     // we're done.
-    auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn);
-    if (slice.has_value()) {
-      return {{*slice, ShapeIndex()}};
+    auto slice = buffer_assn.GetUniqueSlice(instr, index);
+    if (slice.ok()) {
+      return {{slice.ValueOrDie(), ShapeIndex()}};
     }
 
-    // If we don't know the buffer for instr at index, see if we know the buffer
-    // for instr at index without its last element. If so, we can dynamically
-    // find the buffer for instr by dereferencing a pointer in that buffer.
-    // Continue looking this way until we run out of elements in 'index'.
-    ShapeIndex new_index = index;
-    ShapeIndex gte_indices;
-    while (!new_index.empty()) {
-      gte_indices.push_front(new_index.back());
-      new_index.pop_back();
-      auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn);
-      if (slice.has_value()) {
-        return {{*slice, gte_indices}};
-      }
-    }
-
-    // If *that* didn't work, walk up any bitcasts that we might see. These
-    // must appear before any GTE instructions, because it's illegal to bitcast
-    // to a tuple type.
+    // If that didn't work, walk up any bitcasts that we might see. These must
+    // appear before any GTE instructions, because it's illegal to bitcast to a
+    // tuple type.
     const HloInstruction* parent = instr;
     while (parent->opcode() == HloOpcode::kBitcast) {
       parent = parent->operand(0);
-      auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
-      if (slice.has_value()) {
-        return {{*slice, gte_indices}};
+      auto slice = buffer_assn.GetUniqueSlice(parent, {});
+      if (slice.ok()) {
+        return {{slice.ValueOrDie(), ShapeIndex()}};
       }
     }
 
-    // Finally, check whether instr is a GTE instruction. If it is, see if we
-    // can get a buffer for its parent, and continue walking up parents until we
-    // find a defined buffer or we hit something that's not a GTE.
+    // Check whether instr is a GTE instruction. If it is, see if we can get a
+    // buffer for its parent, and continue walking up parents until we find a
+    // defined buffer or we hit something that's not a GTE.
+    ShapeIndex gte_indices;
     while (parent->opcode() == HloOpcode::kGetTupleElement) {
       gte_indices.push_front(parent->tuple_index());
       parent = parent->operand(0);
-      auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
-      if (slice.has_value()) {
-        return {{*slice, gte_indices}};
+      auto slice = buffer_assn.GetUniqueSlice(parent, {});
+      if (slice.ok()) {
+        return {{slice.ValueOrDie(), gte_indices}};
+      }
+    }
+
+    // Finally, if we don't know the buffer for instr at index, see if we know
+    // the buffer for instr at index without its last element. If so, we can
+    // dynamically find the buffer for instr by dereferencing a pointer in that
+    // buffer. Continue looking this way until we run out of elements in
+    // 'index'.
+    //
+    // We can almost always get a buffer without resorting to this. The only
+    // exception is for cases where the relevant sub-buffer is truly unknowable,
+    // for example the sub-buffer of a tuple-shaped select.
+    ShapeIndex new_index = index;
+    while (!new_index.empty()) {
+      gte_indices.push_front(new_index.back());
+      new_index.pop_back();
+      auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
+      if (slice.ok()) {
+        return {{slice.ValueOrDie(), gte_indices}};
       }
     }
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 09b1dd2..7b994a4 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -54,6 +54,25 @@ using TypesF16F32F64CF64 =
 #error "Situation not handled yet"
 #endif
 
+// Check that we can safely pass an input tuple's elements to a dot operation.
+TEST_F(DotOperationTest, DotOfInputTupleElem) {
+  ComputationBuilder builder(client_, TestName());
+
+  ComputationDataHandle param;
+  auto param_data = CreateParameterAndTransferLiteral(
+      0,
+      *Literal::MakeTuple({Literal::CreateR2<float>({{1, 2}, {3, 4}}).get(),
+                           Literal::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+      "arg0", &builder, &param);
+  auto lhs = builder.GetTupleElement(param, 0);
+  auto rhs = builder.GetTupleElement(param, 1);
+  builder.Dot(lhs, rhs);
+
+  ComputeAndCompareLiteral(&builder,
+                           *Literal::CreateR2<float>({{19, 22}, {43, 50}}),
+                           {param_data.get()});
+}
+
 template <typename T>
 class DotOperationTest_F16F32F64CF64 : public DotOperationTest {};
 TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64);
-- 
2.7.4