Use the new gather HLO in the bridge when lowering TF gather ops; NFC
author Sanjoy Das <sanjoy@google.com>
Wed, 18 Apr 2018 20:47:17 +0000 (13:47 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 18 Apr 2018 20:52:45 +0000 (13:52 -0700)
After gather expansion this should boil down to a while loop very similar to
what we emit from the bridge today.
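
For illustration, here is the 1-D case taken from the new code's comments: a
gather with axis=1 that pulls two [3,1] slices out of a [3,3] tensor now
lowers to a single gather HLO (a sketch of the emitted instruction, with the
dimension numbers as constructed in XlaGather below):

  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  gather = s32[3,2] gather(operand, indices),
       output_window_dims={0},
       elided_window_dims={1},
       gather_dims_to_operand_dims={1},
       index_vector_dim=1,
       window_bounds={3, 1}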

PiperOrigin-RevId: 193410095

tensorflow/compiler/tf2xla/kernels/gather_op.cc
tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h

index 7945c05..0b79cb0 100644
@@ -29,52 +29,54 @@ namespace tensorflow {
 Status XlaGather(const xla::ComputationDataHandle& input,
                  const TensorShape& input_shape,
                  const xla::ComputationDataHandle& indices,
-                 TensorShape indices_shape, int64 axis, bool indices_are_nd,
-                 DataType dtype, DataType index_type,
+                 const TensorShape& indices_shape, int64 axis,
+                 bool indices_are_nd, DataType dtype, DataType index_type,
                  xla::ComputationBuilder* builder,
                  xla::ComputationDataHandle* gather_output) {
+  // There is no deep reason why we need this precondition, but this is the only
+  // combination that is used and tested today.
+  CHECK(!indices_are_nd || axis == 0);
+
+  // num_index_dims is the number of components in each index in the indices
+  // tensor.
+  //
+  // num_indices is the total number of (n-dimensional or scalar) indices in the
+  // indices tensor.
+  //
   // If the indices are N-dimensional, then the minor dimension of indices
   // should be of size N and correspond to the N indices.
-  int64 num_index_dims = 1;
+  int64 num_index_dims;
+  int64 num_indices = 1;
   if (indices_are_nd) {
     CHECK_GE(indices_shape.dims(), 1);
     num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
-    indices_shape.RemoveLastDims(1);
+    for (int64 i = 0, e = indices_shape.dims() - 1; i < e; i++) {
+      num_indices *= indices_shape.dim_size(i);
+    }
+  } else {
+    num_index_dims = 1;
+    for (int64 i = 0, e = indices_shape.dims(); i < e; i++) {
+      num_indices *= indices_shape.dim_size(i);
+    }
   }
 
-  // Although the indices Tensor is flattened into rank 1 during the lookup,
-  // and each scalar entry is used as an index into the first dimension of the
-  // input, the output is returned with shape:
-  // input.shape[:axis] + indices.shape + input.shape[axis+1:]
-
-  const int64 num_indices = indices_shape.num_elements();
-  TensorShape input_shape_pre_axis(input_shape);
-  input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
-  TensorShape input_shape_post_axis(input_shape);
-  input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
-  // Each slice of the input tensor has shape:
-  // [<input_shape_pre_axis>, 1, ..., 1, <input shape_post_axis>]
-  TensorShape slice_shape(input_shape);
-  for (int64 i = 0; i < num_index_dims; ++i) {
-    slice_shape.set_dim(axis + i, 1);
-  }
+  // Degenerate case: empty indices.
+  if (num_indices == 0) {
+    TensorShape input_shape_pre_axis{input_shape};
+    input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
+    TensorShape input_shape_post_axis{input_shape};
+    input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
 
-  TensorShape loop_out_shape;
-  loop_out_shape.AppendShape(input_shape_pre_axis);
-  loop_out_shape.AddDim(num_indices);
-  loop_out_shape.AppendShape(input_shape_post_axis);
-  TensorShape loop_out_slice_shape;
-  loop_out_slice_shape.AppendShape(input_shape_pre_axis);
-  loop_out_slice_shape.AddDim(1);
-  loop_out_slice_shape.AppendShape(input_shape_post_axis);
+    TensorShape indices_shape_no_index_vectors{indices_shape};
+    if (indices_are_nd) {
+      indices_shape_no_index_vectors.RemoveLastDims(1);
+    }
 
-  TensorShape out_shape;
-  out_shape.AppendShape(input_shape_pre_axis);
-  out_shape.AppendShape(indices_shape);
-  out_shape.AppendShape(input_shape_post_axis);
+    TensorShape out_shape;
+    out_shape.AppendShape(input_shape_pre_axis);
+    out_shape.AppendShape(indices_shape_no_index_vectors);
+    out_shape.AppendShape(input_shape_post_axis);
 
-  // Degenerate case: empty indices.
-  if (num_indices == 0) {
     *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
                                         out_shape.dim_sizes());
     return Status::OK();
@@ -88,76 +90,61 @@ Status XlaGather(const xla::ComputationDataHandle& input,
     }
   }
 
-  // Flatten the major dimensions of indices into a single dimension for ease of
-  // iteration. If there is an axis dimension, we must leave it alone.
-  std::vector<int64> flat_indices_shape = {num_indices};
-  if (indices_are_nd) {
-    flat_indices_shape.push_back(num_index_dims);
-  }
-
-  // Specify the shape of the loop-carried Tensor tuple.
-
-  // Construct the initial values of the loop-carried Tensors.
-  auto flat_indices = builder->Reshape(indices, flat_indices_shape);
-  auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
-                                     loop_out_shape.dim_sizes());
-  auto init = {input, flat_indices, init_out};
-
-  // Construct the while loop body's function. The implementation of gather is:
-  // for i in range(num_indices):
-  //   index = dynamic-slice(indices, i)
-  //   xi = dynamic-slice(input, index)
-  //   output = dynamic-update-slice(output, xi, i)
-  auto body_fn = [&](xla::ComputationDataHandle i,
-                     gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
-                     xla::ComputationBuilder* bodyb) {
-    auto input = loop_vars[0];
-    auto indices = loop_vars[1];
-    auto output = loop_vars[2];
-
-    auto zero_index = XlaHelpers::Zero(bodyb, index_type);
-
-    // Slice the i-th index from the indices array.
-    xla::ComputationDataHandle index;
-    auto indices_offset = bodyb->Reshape(i, {1});
-    if (indices_are_nd) {
-      // Slice out the entire nd index, if applicable.
-      indices_offset = bodyb->Pad(indices_offset, zero_index,
-                                  xla::MakeEdgePaddingConfig({{0, 1}}));
-      index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims});
-      index = bodyb->Collapse(index, {0, 1});
+  // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a
+  // tensor of shape [3,3].
+  //
+  //  operand = s32[3,3] parameter(0)
+  //  indices = s32[2] parameter(1)
+  //  gather = s32[3,2] gather(operand, indices),
+  //       output_window_dims={0},
+  //       elided_window_dims={1},
+  //       gather_dims_to_operand_dims={1},
+  //       index_vector_dim=1,
+  //       window_bounds={3, 1}
+  //
+  //
+  // Example of an N-D gather pulling slices of shape [1,1,2] out of a
+  // tensor of shape [3,3,2].
+  //
+  //  operand = s32[3,3,2] parameter(0)
+  //  indices = s32[2,2] parameter(1)
+  //  gather = s32[2,2] gather(operand, indices),
+  //       output_window_dims={1},
+  //       elided_window_dims={0,1},
+  //       gather_dims_to_operand_dims={0,1},
+  //       index_vector_dim=1,
+  //       window_bounds={1,1,2}
+
+  xla::GatherDimensionNumbers dim_numbers;
+  std::vector<int64> window_bounds;
+  window_bounds.reserve(input_shape.dims());
+  for (int64 i = 0; i < input_shape.dims(); i++) {
+    int64 window_bound;
+    if (axis <= i && i < (axis + num_index_dims)) {
+      dim_numbers.add_elided_window_dims(i);
+      window_bound = 1;
     } else {
-      index = bodyb->DynamicSlice(indices, indices_offset, {1});
+      window_bound = input_shape.dim_size(i);
+    }
+
+    window_bounds.push_back(window_bound);
+
+    if (i < axis) {
+      dim_numbers.add_output_window_dims(i);
+    } else if (i >= (axis + num_index_dims)) {
+      int64 indices_rank =
+          indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
+      dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims);
     }
+  }
+
+  dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
+                                                  : indices_shape.dims());
+  for (int64 i = axis; i < axis + num_index_dims; i++) {
+    dim_numbers.add_gather_dims_to_operand_dims(i);
+  }
 
-    // Slice the corresponding data from the input array.
-    auto start_indices = bodyb->Pad(
-        index, zero_index,
-        xla::MakeEdgePaddingConfig(
-            {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}}));
-    auto slice_i = bodyb->Reshape(
-        bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()),
-        loop_out_slice_shape.dim_sizes());
-
-    // Construct the index into the output Tensor 0, ..., <index>, 0, ...
-    std::vector<xla::ComputationDataHandle> out_index_vals(
-        loop_out_shape.dims(), bodyb->Reshape(zero_index, {1}));
-    out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1});
-    auto out_index = bodyb->ConcatInDim(out_index_vals, 0);
-
-    // Update the output Tensor
-    auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index);
-
-    return std::vector<xla::ComputationDataHandle>{input, indices,
-                                                   updated_output};
-  };
-
-  // Construct the While loop, extract and reshape the output.
-  xla::PrimitiveType ptype;
-  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype));
-  TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn,
-                                                    init, "gather", builder));
-  *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes());
+  *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds);
   return Status::OK();
 }
 
index bd8b92c..f9376f0 100644
@@ -36,8 +36,8 @@ namespace tensorflow {
 Status XlaGather(const xla::ComputationDataHandle& input,
                  const TensorShape& input_shape,
                  const xla::ComputationDataHandle& indices,
-                 TensorShape indices_shape, int64 axis, bool indices_are_nd,
-                 DataType dtype, DataType index_type,
+                 const TensorShape& indices_shape, int64 axis,
+                 bool indices_are_nd, DataType dtype, DataType index_type,
                  xla::ComputationBuilder* builder,
                  xla::ComputationDataHandle* gather_output);
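
Usage sketch (a hypothetical caller, not part of this change): the declared
signature is exercised roughly as follows, assuming `input`, `input_shape`,
`indices`, `indices_shape`, and `builder` are already in scope and the caller
returns a Status; the dtype and index_type values are illustrative.

  // Gather along axis 1 of `input` using rank-1 (scalar) indices; on
  // success the resulting handle is written into `gather_output`.
  xla::ComputationDataHandle gather_output;
  TF_RETURN_IF_ERROR(XlaGather(input, input_shape, indices, indices_shape,
                               /*axis=*/1, /*indices_are_nd=*/false,
                               /*dtype=*/DT_FLOAT, /*index_type=*/DT_INT32,
                               builder, &gather_output));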