Fix some edge cases around scalar indices in the gather expander
author    Sanjoy Das <sanjoy@google.com>
Tue, 20 Mar 2018 06:15:42 +0000 (23:15 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 06:20:15 +0000 (23:20 -0700)
I discovered these when changing the tf2xla bridge to directly emit gather
operations.

 - DeScalarizeGatherIndices was assuming that gather_indices must be of at least
   rank 1.  Fix this to be more general.

 - We were passing in the wrong version of the gather indices to
   ExpandFirstDimIntoNDims.  We don't strictly need to pass in
   transposed_gather_indices: if transposed_gather_indices is rank 1 then the
   transpose has to be an identity transpose, so passing in
   descalarized_gather_indices would also have been fine.  Still,
   transposed_gather_indices seems more uniform.

 - ExpandGatherDimsInAccumulator was assuming that gather_indices must be of at
   least rank 1 (by calling CollapseFirstNDims).  Fix this to be more general.

 - We were trying to go through with emitting zero-sized gather operations.  I
   don't think it is worth dealing with all of the edge cases this would
   expose, so now we just punt to ZeroSizedHloElimination.

PiperOrigin-RevId: 189696444

tensorflow/compiler/xla/service/gather_expander.cc
tensorflow/compiler/xla/service/hlo_creation_utils.cc
tensorflow/compiler/xla/service/hlo_creation_utils.h
tensorflow/compiler/xla/tests/gather_operation_test.cc

tensorflow/compiler/xla/service/gather_expander.cc
index 58c62d8..488bed3 100644
@@ -53,9 +53,14 @@ static StatusOr<HloInstruction*> DeScalarizeGatherIndices(
     return gather_indices;
   }
 
-  int64 last_index = gather_indices_shape.dimensions(
-      gather_indices_shape.dimensions_size() - 1);
-  return ExpandLastDimIntoNDims(gather_indices, {last_index, 1});
+  DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size());
+
+  std::vector<int64> result_shape_dims;
+  c_copy(gather_indices_shape.dimensions(),
+         std::back_inserter(result_shape_dims));
+  result_shape_dims.push_back(1);
+
+  return MakeReshapeHlo(result_shape_dims, gather_indices);
 }
 
 // Canonicalizes the gather_indices tensors so that we only have to deal with some
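
For illustration, here is a minimal standalone sketch of the shape arithmetic
the new DeScalarizeGatherIndices code above performs, using plain
std::vector<int64_t> as a hypothetical stand-in for xla::Shape
(DeScalarizedDims is not a real XLA function):

#include <cstdint>
#include <iostream>
#include <vector>

// Copy every existing dimension and append a trailing 1.  A scalar index
// ([] -> [1]) and a rank-1 index vector ([5] -> [5,1]) both go through the
// same code path, which is what makes the new version rank-agnostic.
std::vector<int64_t> DeScalarizedDims(const std::vector<int64_t>& dims) {
  std::vector<int64_t> result(dims.begin(), dims.end());
  result.push_back(1);
  return result;
}

int main() {
  for (const std::vector<int64_t>& dims :
       {std::vector<int64_t>{}, std::vector<int64_t>{5}}) {
    std::cout << "rank " << dims.size() << " -> rank "
              << DeScalarizedDims(dims).size() << "\n";  // 0 -> 1, 1 -> 2
  }
}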
@@ -81,16 +86,17 @@ static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
   // all of the non-index-vector dimensions.
   const Shape& shape = transposed_gather_indices->shape();
   if (shape.dimensions_size() == 1) {
-    return ExpandFirstDimIntoNDims(gather_indices, {1, shape.dimensions(0)});
+    return ExpandFirstDimIntoNDims(transposed_gather_indices,
+                                   {1, shape.dimensions(0)});
   } else {
     return CollapseFirstNDims(transposed_gather_indices,
                               shape.dimensions_size() - 1);
   }
 }
 
-// Expands out the gather dimensions in the accumulator produced by the while
-// loop.
-static StatusOr<HloInstruction*> ExpandGatherDimsInAccumulator(
+// Expands out or contracts away the gather dimensions in the accumulator
+// produced by the while loop.
+static StatusOr<HloInstruction*> AdjustGatherDimsInAccumulator(
     const Shape& gather_indices_shape, HloInstruction* accumulator,
     int64 index_vector_dim) {
   std::vector<int64> output_gather_dim_bounds;
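
The rank-1 branch of CanonicalizeGatherIndices above is also the one the fixed
call exercises: a rank-1 transpose is necessarily the identity, so
transposed_gather_indices and descalarized_gather_indices coincide there, and
the reshape just adds a leading degenerate dimension.  A sketch of that shape
math, with ExpandFirstDim as a hypothetical stand-in for
ExpandFirstDimIntoNDims:

#include <cassert>
#include <cstdint>
#include <vector>

// Replace the first dimension with `expanded` (whose product must equal the
// first dimension), keeping the remaining dimensions.  For the rank-1 case
// in CanonicalizeGatherIndices this turns [N] into [1, N]: a single "row" of
// gather indices holding one index vector of length N, i.e. an effectively
// dynamic-slice gather.
std::vector<int64_t> ExpandFirstDim(const std::vector<int64_t>& dims,
                                    const std::vector<int64_t>& expanded) {
  assert(!dims.empty());
  int64_t product = 1;
  for (int64_t d : expanded) product *= d;
  assert(product == dims[0]);
  std::vector<int64_t> result(expanded);
  result.insert(result.end(), dims.begin() + 1, dims.end());
  return result;
}

int main() {
  assert((ExpandFirstDim({7}, {1, 7}) == std::vector<int64_t>{1, 7}));
}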
@@ -103,9 +109,14 @@ static StatusOr<HloInstruction*> ExpandGatherDimsInAccumulator(
 
   if (output_gather_dim_bounds.empty()) {
     // If output_gather_dim_bounds is empty we must be lowering an (effectively)
-    // dynamic-slice.
+    // dynamic-slice.  In that case, there is a leading degenerate gather
+    // dimension that we added to make this special case play well with the
+    // general while loop; we need to remove it now.
     CHECK_EQ(accumulator->shape().dimensions(0), 1);
-    return CollapseFirstNDims(accumulator, 2);
+    ArraySlice<int64> reshaped_dim_sizes =
+        AsInt64Slice(accumulator->shape().dimensions());
+    reshaped_dim_sizes.remove_prefix(1);
+    return MakeReshapeHlo(reshaped_dim_sizes, accumulator);
   }
 
   return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds);
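
A sketch of the dynamic-slice arm of this function (DropLeadingDegenerateDim
is hypothetical; the real code reshapes an HLO instruction rather than a
dimension vector): the leading degenerate gather dimension added during
canonicalization is contracted away, e.g. [1, 3, 2] -> [3, 2], while the
general arm instead expands the leading dimension into
output_gather_dim_bounds.

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the remove_prefix(1) above: drop the leading size-1 dimension that
// canonicalization introduced for the (effectively) dynamic-slice case.
std::vector<int64_t> DropLeadingDegenerateDim(std::vector<int64_t> dims) {
  assert(!dims.empty() && dims[0] == 1);
  dims.erase(dims.begin());
  return dims;
}

int main() {
  assert((DropLeadingDegenerateDim({1, 3, 2}) == std::vector<int64_t>{3, 2}));
}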
@@ -290,6 +301,8 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
 
 StatusOr<HloInstruction*> GatherExpander::ExpandGather(
     HloInstruction* gather_instr) {
+  CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape()));
+
   HloComputation* computation = gather_instr->parent();
   HloInstruction* operand = gather_instr->mutable_operand(0);
   HloInstruction* gather_indices = gather_instr->mutable_operand(1);
@@ -331,7 +344,7 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * accumulator_with_output_gather_dims_decanonicalized,
-      ExpandGatherDimsInAccumulator(gather_indices->shape(),
+      AdjustGatherDimsInAccumulator(gather_indices->shape(),
                                     accumulator_with_window_dims_elided,
                                     dim_numbers.index_vector_dim()));
 
@@ -341,12 +354,17 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
 }
 
 StatusOr<bool> GatherExpander::Run(HloModule* module) {
+  auto is_nontrivial_gather = [](HloInstruction* inst) {
+    return inst->opcode() == HloOpcode::kGather &&
+           // Avoid expanding gather ops that produce zero-sized tensors;
+           // punt these to ZeroSizedHloElimination instead.
+           !ShapeUtil::HasZeroElements(inst->shape());
+  };
+
   std::vector<HloInstruction*> gather_instrs;
   for (HloComputation* computation : module->MakeNonfusionComputations()) {
     c_copy_if(computation->instructions(), std::back_inserter(gather_instrs),
-              [](HloInstruction* inst) {
-                return inst->opcode() == HloOpcode::kGather;
-              });
+              is_nontrivial_gather);
   }
 
   for (HloInstruction* inst : gather_instrs) {
tensorflow/compiler/xla/service/hlo_creation_utils.cc
index fbe71f8..b186767 100644
@@ -201,24 +201,6 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
   return MakeReshapeHlo(new_shape, operand);
 }
 
-StatusOr<HloInstruction*> ExpandLastDimIntoNDims(
-    HloInstruction* operand, ArraySlice<int64> expanded_dims) {
-  CHECK_GT(operand->shape().dimensions_size(), 0);
-  CHECK_EQ(operand->shape().dimensions(operand->shape().dimensions_size() - 1),
-           Product(expanded_dims));
-
-  std::vector<int64> expanded_shape_dim_bounds;
-  expanded_shape_dim_bounds.reserve(expanded_dims.size() +
-                                    operand->shape().dimensions_size() - 1);
-  std::copy(operand->shape().dimensions().begin(),
-            operand->shape().dimensions().end() - 1,
-            std::back_inserter(expanded_shape_dim_bounds));
-  c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
-  Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
-                                         expanded_shape_dim_bounds);
-  return MakeReshapeHlo(new_shape, operand);
-}
-
 StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
                                               ArraySlice<int64> dims_to_elide) {
   CHECK(c_is_sorted(dims_to_elide));
tensorflow/compiler/xla/service/hlo_creation_utils.h
index 6032eba..d99e32a 100644
@@ -119,16 +119,6 @@ StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n);
 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
     HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> expanded_dims);
 
-// Expands (via reshape) the last (logical) dimension of `operand` into a
-// sequence of `expanded_dims` dimensions.  `operand` must at least be of rank 1
-// and the number of elements in its last dimension must be equal to the
-// product of `expanded_dims`.
-//
-// For instance if `operand` has shape f32[9,7,200] and expanded_dims is
-// {2,5,20} the result is `operand` reshaped to [9,7,2,5,20].
-StatusOr<HloInstruction*> ExpandLastDimIntoNDims(
-    HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> expanded_dims);
-
 // Elides (via reshape) a set of degenerate dimensions (dimensions containing
 // exactly one element), `dims_to_elide` from `operand`.  Every dimension in
 // `dims_to_elide` must be a degenerate dimension.  `dims_to_elide` must be
tensorflow/compiler/xla/tests/gather_operation_test.cc
index 0830e9c..8ba9194 100644
@@ -335,5 +335,67 @@ ENTRY main {
           {operand.get(), gather_indices.get(), in_bounds_mask.get()});
 }
 
+XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
+  const char* hlo_text = R"(
+HloModule OneScalarIndex
+
+ENTRY main {
+  operand = s32[2,3,2]{2,1,0} parameter(0)
+  index = s32[] parameter(1)
+  ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index),
+      output_window_dims={0,1,2},
+      elided_window_dims={},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=0,
+      window_bounds={1,3,2}
+}
+)";
+  std::unique_ptr<Literal> operand = Literal::CreateR3<int32>(
+      {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, ScalarResult) {
+  const char* hlo_text = R"(
+HloModule ScalarResult
+
+ENTRY main {
+  operand = s32[4]{0} parameter(0)
+  index = s32[] parameter(1)
+  ROOT gather = s32[] gather(operand, index),
+      output_window_dims={},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=0,
+      window_bounds={1}
+}
+)";
+  std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({1, 2, 3, 4});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
+  const string hlo_text = R"(
+HloModule ZeroSizedResult
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[0] parameter(1)
+  ROOT gather = s32[0,3] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1, 3}
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
 }  // namespace
 }  // namespace xla