[XLA:GPU] Assume that tuple sub-buffers are available at runtime.
authorJustin Lebar <jlebar@google.com>
Thu, 29 Mar 2018 17:30:31 +0000 (10:30 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 17:33:34 +0000 (10:33 -0700)
Previously we assumed this was not the case, and allowed front-ends to
pass in a pointer to tuple without also passing in pointers to
sub-buffers.

This mostly worked: Whenever we wanted a tuple sub-buffer, we'd just
chase the tuple's pointers in our emitted kernel.

But this doesn't work if we ever need a pointer to that sub-buffer on
the host.  Which we do if e.g. the sub-buffer is an input to a cudnn
call.

There are various ways to make this work, but by far the simplest and
most efficient is simply to specify away this problem, and say that the
front-end *must* give us all the pointers we want.  This is what the
earlier change, "Assert that all buffers and sub-buffers passed to XLA
have an explicit pointer" did.

This change adds a testcase and lets us skip some pointer chasing when
we have a tuple whose sub-buffers are known statically.

PiperOrigin-RevId: 190949743

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
tensorflow/compiler/xla/tests/dot_operation_test.cc

index 199e6b7..d29cc21 100644 (file)
@@ -145,37 +145,6 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
        llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
 }
 
-// Tries to get a Slice for the given instruction at the given index, but
-// returns nullopt if we might not know the slice's address at runtime without
-// dereferencing a containing tuple.
-//
-// In particular, when XLA accepts a parameter of tuple type, the caller has the
-// option of telling XLA what are the values inside of the tuple, or just giving
-// XLA a pointer to the top-level tuple and letting us chase the pointers on the
-// GPU.  We therefore cannot rely having these pointers to parameter sub-buffers
-// being present when we run the program.
-optional<BufferAllocation::Slice> GetKnownAtRuntimeSlice(
-    const HloInstruction* instr, const ShapeIndex& index,
-    const BufferAssignment& buffer_assn) {
-  auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index);
-  if (!maybe_slice.ok()) {
-    return nullopt;
-  }
-  // BufferAllocation gives a slice and alloc to every buffer accessed by XLA,
-  // but we don't necessarily know the runtime address of sub-buffers of input
-  // parameters.
-  const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie();
-  const BufferAllocation* alloc = slice.allocation();
-  if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() &&
-      !alloc->param_shape_index().empty()) {
-    return nullopt;
-  }
-
-  // Otherwise, we will know the address of this slice at runtime without having
-  // to dereference a tuple.
-  return slice;
-}
-
 }  // namespace
 
 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
@@ -206,7 +175,7 @@ bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment,
   return hlo.opcode() == HloOpcode::kCopy &&
          hlo.operand(0)->opcode() == HloOpcode::kConstant &&
          ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
-         GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value();
+         buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok();
 }
 
 bool ImplementedAsDeviceToDeviceMemcpy(
@@ -216,13 +185,13 @@ bool ImplementedAsDeviceToDeviceMemcpy(
   //
   // 1. `hlo` is a kCopy instruction.
   // 2. `hlo` and its operand have the same shape (thus the same layout too).
-  // 3. The operand to `hlo` has a buffer assignment (constants do not, for
-  //    instance) which means the source buffer also resides on the device.
+  // 3. `hlo` and its operand have a statically-known buffer assignment
+  //     (constants do not, for instance), which means the source buffer also
+  //     resides on the device.
   return hlo.opcode() == HloOpcode::kCopy &&
          ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
-         GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() &&
-         GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment)
-             .has_value();
+         buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() &&
+         buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok();
 }
 }  // namespace
 
@@ -1959,49 +1928,54 @@ GetHloBufferSlices(const HloInstruction* hlo,
       -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
     // Simple, common case: Is the buffer for instr known at runtime?  If so,
     // we're done.
-    auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn);
-    if (slice.has_value()) {
-      return {{*slice, ShapeIndex()}};
+    auto slice = buffer_assn.GetUniqueSlice(instr, index);
+    if (slice.ok()) {
+      return {{slice.ValueOrDie(), ShapeIndex()}};
     }
 
-    // If we don't know the buffer for instr at index, see if we know the buffer
-    // for instr at index without its last element.  If so, we can dynamically
-    // find the buffer for instr by dereferencing a pointer in that buffer.
-    // Continue looking this way until we run out of elements in 'index'.
-    ShapeIndex new_index = index;
-    ShapeIndex gte_indices;
-    while (!new_index.empty()) {
-      gte_indices.push_front(new_index.back());
-      new_index.pop_back();
-      auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn);
-      if (slice.has_value()) {
-        return {{*slice, gte_indices}};
-      }
-    }
-
-    // If *that* didn't work, walk up any bitcasts that we might see.  These
-    // must appear before any GTE instructions, because it's illegal to bitcast
-    // to a tuple type.
+    // If that didn't work, walk up any bitcasts that we might see.  These must
+    // appear before any GTE instructions, because it's illegal to bitcast to a
+    // tuple type.
     const HloInstruction* parent = instr;
     while (parent->opcode() == HloOpcode::kBitcast) {
       parent = parent->operand(0);
 
-      auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
-      if (slice.has_value()) {
-        return {{*slice, gte_indices}};
+      auto slice = buffer_assn.GetUniqueSlice(parent, {});
+      if (slice.ok()) {
+        return {{slice.ValueOrDie(), ShapeIndex()}};
       }
     }
 
-    // Finally, check whether instr is a GTE instruction.  If it is, see if we
-    // can get a buffer for its parent, and continue walking up parents until we
-    // find a defined buffer or we hit something that's not a GTE.
+    // Check whether instr is a GTE instruction.  If it is, see if we can get a
+    // buffer for its parent, and continue walking up parents until we find a
+    // defined buffer or we hit something that's not a GTE.
+    ShapeIndex gte_indices;
     while (parent->opcode() == HloOpcode::kGetTupleElement) {
       gte_indices.push_front(parent->tuple_index());
       parent = parent->operand(0);
 
-      auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
-      if (slice.has_value()) {
-        return {{*slice, gte_indices}};
+      auto slice = buffer_assn.GetUniqueSlice(parent, {});
+      if (slice.ok()) {
+        return {{slice.ValueOrDie(), gte_indices}};
+      }
+    }
+
+    // Finally, if we don't know the buffer for instr at index, see if we know
+    // the buffer for instr at index without its last element.  If so, we can
+    // dynamically find the buffer for instr by dereferencing a pointer in that
+    // buffer.  Continue looking this way until we run out of elements in
+    // 'index'.
+    //
+    // We can almost always get a buffer without resorting to this.  The only
+    // exception is for cases where the relevant sub-buffer is truly unknowable,
+    // for example the sub-buffer of a tuple-shaped select.
+    ShapeIndex new_index = index;
+    while (!new_index.empty()) {
+      gte_indices.push_front(new_index.back());
+      new_index.pop_back();
+      auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
+      if (slice.ok()) {
+        return {{slice.ValueOrDie(), gte_indices}};
       }
     }
 
index 09b1dd2..7b994a4 100644 (file)
@@ -54,6 +54,25 @@ using TypesF16F32F64CF64 =
 #error "Situation not handled yet"
 #endif
 
+// Check that we can safely pass an input tuple's elements to a dot operation.
+TEST_F(DotOperationTest, DotOfInputTupleElem) {
+  ComputationBuilder builder(client_, TestName());
+
+  ComputationDataHandle param;
+  auto param_data = CreateParameterAndTransferLiteral(
+      0,
+      *Literal::MakeTuple({Literal::CreateR2<float>({{1, 2}, {3, 4}}).get(),
+                           Literal::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+      "arg0", &builder, &param);
+  auto lhs = builder.GetTupleElement(param, 0);
+  auto rhs = builder.GetTupleElement(param, 1);
+  builder.Dot(lhs, rhs);
+
+  ComputeAndCompareLiteral(&builder,
+                           *Literal::CreateR2<float>({{19, 22}, {43, 50}}),
+                           {param_data.get()});
+}
+
 template <typename T>
 class DotOperationTest_F16F32F64CF64 : public DotOperationTest {};
 TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64);