llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
}
-// Tries to get a Slice for the given instruction at the given index, but
-// returns nullopt if we might not know the slice's address at runtime without
-// dereferencing a containing tuple.
-//
-// In particular, when XLA accepts a parameter of tuple type, the caller has the
-// option of telling XLA what are the values inside of the tuple, or just giving
-// XLA a pointer to the top-level tuple and letting us chase the pointers on the
-// GPU. We therefore cannot rely having these pointers to parameter sub-buffers
-// being present when we run the program.
-optional<BufferAllocation::Slice> GetKnownAtRuntimeSlice(
- const HloInstruction* instr, const ShapeIndex& index,
- const BufferAssignment& buffer_assn) {
- auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index);
- if (!maybe_slice.ok()) {
- return nullopt;
- }
- // BufferAllocation gives a slice and alloc to every buffer accessed by XLA,
- // but we don't necessarily know the runtime address of sub-buffers of input
- // parameters.
- const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie();
- const BufferAllocation* alloc = slice.allocation();
- if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() &&
- !alloc->param_shape_index().empty()) {
- return nullopt;
- }
-
- // Otherwise, we will know the address of this slice at runtime without having
- // to dereference a tuple.
- return slice;
-}
-
} // namespace
IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
return hlo.opcode() == HloOpcode::kCopy &&
hlo.operand(0)->opcode() == HloOpcode::kConstant &&
ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
- GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value();
+ buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok();
}
bool ImplementedAsDeviceToDeviceMemcpy(
//
// 1. `hlo` is a kCopy instruction.
// 2. `hlo` and its operand have the same shape (thus the same layout too).
- // 3. The operand to `hlo` has a buffer assignment (constants do not, for
- // instance) which means the source buffer also resides on the device.
+ // 3. `hlo` and its operand have a statically-known buffer assignment
+ // (constants do not, for instance), which means the source buffer also
+ // resides on the device.
return hlo.opcode() == HloOpcode::kCopy &&
ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
- GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() &&
- GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment)
- .has_value();
+ buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() &&
+ buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok();
}
} // namespace
-> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
// Simple, common case: Is the buffer for instr known at runtime? If so,
// we're done.
- auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn);
- if (slice.has_value()) {
- return {{*slice, ShapeIndex()}};
+ auto slice = buffer_assn.GetUniqueSlice(instr, index);
+ if (slice.ok()) {
+ return {{slice.ValueOrDie(), ShapeIndex()}};
}
- // If we don't know the buffer for instr at index, see if we know the buffer
- // for instr at index without its last element. If so, we can dynamically
- // find the buffer for instr by dereferencing a pointer in that buffer.
- // Continue looking this way until we run out of elements in 'index'.
- ShapeIndex new_index = index;
- ShapeIndex gte_indices;
- while (!new_index.empty()) {
- gte_indices.push_front(new_index.back());
- new_index.pop_back();
- auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn);
- if (slice.has_value()) {
- return {{*slice, gte_indices}};
- }
- }
-
- // If *that* didn't work, walk up any bitcasts that we might see. These
- // must appear before any GTE instructions, because it's illegal to bitcast
- // to a tuple type.
+ // If that didn't work, walk up any bitcasts that we might see. These must
+ // appear before any GTE instructions, because it's illegal to bitcast to a
+ // tuple type.
const HloInstruction* parent = instr;
while (parent->opcode() == HloOpcode::kBitcast) {
parent = parent->operand(0);
- auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
- if (slice.has_value()) {
- return {{*slice, gte_indices}};
+ auto slice = buffer_assn.GetUniqueSlice(parent, {});
+ if (slice.ok()) {
+ return {{slice.ValueOrDie(), ShapeIndex()}};
}
}
- // Finally, check whether instr is a GTE instruction. If it is, see if we
- // can get a buffer for its parent, and continue walking up parents until we
- // find a defined buffer or we hit something that's not a GTE.
+ // Check whether instr is a GTE instruction. If it is, see if we can get a
+ // buffer for its parent, and continue walking up parents until we find a
+ // defined buffer or we hit something that's not a GTE.
+ ShapeIndex gte_indices;
while (parent->opcode() == HloOpcode::kGetTupleElement) {
gte_indices.push_front(parent->tuple_index());
parent = parent->operand(0);
- auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
- if (slice.has_value()) {
- return {{*slice, gte_indices}};
+ auto slice = buffer_assn.GetUniqueSlice(parent, {});
+ if (slice.ok()) {
+ return {{slice.ValueOrDie(), gte_indices}};
+ }
+ }
+
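+ // Finally, if we don't know the buffer for instr at index, see if we know
+ // the buffer for instr at index without its last element. If so, we can
+ // dynamically find the buffer for instr by dereferencing a pointer in that
+ // buffer. Continue looking this way until we run out of elements in
+ // 'index'.
+ //
+ // We can almost always get a buffer without resorting to this. The only
+ // exception is for cases where the relevant sub-buffer is truly unknowable,
+ // for example the sub-buffer of a tuple-shaped select.
+ //
+ // NOTE(review): gte_indices is not cleared before this loop. If the GTE walk
+ // above ran without finding a slice, the indices it pushed (which are
+ // relative to `parent`, not `instr`) stay at the tail of the path returned
+ // below; confirm this is intended.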
+ ShapeIndex new_index = index;
+ while (!new_index.empty()) {
+ gte_indices.push_front(new_index.back());
+ new_index.pop_back();
+ auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
+ if (slice.ok()) {
+ return {{slice.ValueOrDie(), gte_indices}};
}
}