From e2e67c528316be8ea4f624af8757e80d7f00b5b6 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 19 Mar 2018 23:15:42 -0700 Subject: [PATCH] Fix some edge cases around scalar indices in the gather expander I discovered these when changing the tf2xla bridge to directly emit gather operations. - DeScalarizeGatherIndices was assuming that gather_indices must be of at least rank 1. Fix this to be more general. - We were passing in the wrong version of gather indices to ExpandFirstDimIntoNDims. We don't strictly need to pass in transposed_gather_indices (since if transposed_gather_indices is rank 1 then the transpose has to be an identity transpose), passing in descalarized_gather_indices would also have been fine, but transposed_gather_indices seems more uniform. - ExpandGatherDimsInAccumulator was assuming that gather_indices must be of at least rank 1 (by calling CollapseFirstNDims). Fix this to be more general. - We were trying to go through with emitting zero sized gather operations. I don't think it is worth dealing with all of the edge cases this would expose so now we just punt to ZeroSizedHloElimination. PiperOrigin-RevId: 189696444 --- tensorflow/compiler/xla/service/gather_expander.cc | 44 ++++++++++----- .../compiler/xla/service/hlo_creation_utils.cc | 18 ------- .../compiler/xla/service/hlo_creation_utils.h | 10 ---- .../compiler/xla/tests/gather_operation_test.cc | 62 ++++++++++++++++++++++ 4 files changed, 93 insertions(+), 41 deletions(-) diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 58c62d8..488bed3 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -53,9 +53,14 @@ static StatusOr DeScalarizeGatherIndices( return gather_indices; } - int64 last_index = gather_indices_shape.dimensions( - gather_indices_shape.dimensions_size() - 1); - return ExpandLastDimIntoNDims(gather_indices, {last_index, 1}); + DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size()); + + std::vector result_shape_dims; + c_copy(gather_indices_shape.dimensions(), + std::back_inserter(result_shape_dims)); + result_shape_dims.push_back(1); + + return MakeReshapeHlo(result_shape_dims, gather_indices); } // Canonicalizes the gather_indices tensors so that we only have deal with some @@ -81,16 +86,17 @@ static StatusOr CanonicalizeGatherIndices( // all of the non-index-vector dimensions. const Shape& shape = transposed_gather_indices->shape(); if (shape.dimensions_size() == 1) { - return ExpandFirstDimIntoNDims(gather_indices, {1, shape.dimensions(0)}); + return ExpandFirstDimIntoNDims(transposed_gather_indices, + {1, shape.dimensions(0)}); } else { return CollapseFirstNDims(transposed_gather_indices, shape.dimensions_size() - 1); } } -// Expands out the gather dimensions in the accumulator produced by the while -// loop. -static StatusOr ExpandGatherDimsInAccumulator( +// Expands out or contracts away the gather dimensions in the accumulator +// produced by the while loop. +static StatusOr AdjustGatherDimsInAccumulator( const Shape& gather_indices_shape, HloInstruction* accumulator, int64 index_vector_dim) { std::vector output_gather_dim_bounds; @@ -103,9 +109,14 @@ static StatusOr ExpandGatherDimsInAccumulator( if (output_gather_dim_bounds.empty()) { // If output_gather_dim_bounds is empty we must be lowering a (effectively) - // dynamic-slice. + // dynamic-slice. In that case, there is a leading degenerate gather + // dimension that we added to make this special case play well with the + // general while loop which we need to remove now. CHECK_EQ(accumulator->shape().dimensions(0), 1); - return CollapseFirstNDims(accumulator, 2); + ArraySlice reshaped_dim_sizes = + AsInt64Slice(accumulator->shape().dimensions()); + reshaped_dim_sizes.remove_prefix(1); + return MakeReshapeHlo(reshaped_dim_sizes, accumulator); } return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); @@ -290,6 +301,8 @@ static StatusOr PermuteGatherAndWindowDims( StatusOr GatherExpander::ExpandGather( HloInstruction* gather_instr) { + CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); HloInstruction* gather_indices = gather_instr->mutable_operand(1); @@ -331,7 +344,7 @@ StatusOr GatherExpander::ExpandGather( TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_with_output_gather_dims_decanonicalized, - ExpandGatherDimsInAccumulator(gather_indices->shape(), + AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_with_window_dims_elided, dim_numbers.index_vector_dim())); @@ -341,12 +354,17 @@ StatusOr GatherExpander::ExpandGather( } StatusOr GatherExpander::Run(HloModule* module) { + auto is_nontrivial_gather = [](HloInstruction* inst) { + return inst->opcode() == HloOpcode::kGather && + // Avoid expanding gather ops that produce zero sized tensors, + // instead punt these to ZeroSizedHloElimination. + !ShapeUtil::HasZeroElements(inst->shape()); + }; + std::vector gather_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), - [](HloInstruction* inst) { - return inst->opcode() == HloOpcode::kGather; - }); + is_nontrivial_gather); } for (HloInstruction* inst : gather_instrs) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index fbe71f8..b186767 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -201,24 +201,6 @@ StatusOr ExpandFirstDimIntoNDims( return MakeReshapeHlo(new_shape, operand); } -StatusOr ExpandLastDimIntoNDims( - HloInstruction* operand, ArraySlice expanded_dims) { - CHECK_GT(operand->shape().dimensions_size(), 0); - CHECK_EQ(operand->shape().dimensions(operand->shape().dimensions_size() - 1), - Product(expanded_dims)); - - std::vector expanded_shape_dim_bounds; - expanded_shape_dim_bounds.reserve(expanded_dims.size() + - operand->shape().dimensions_size() - 1); - std::copy(operand->shape().dimensions().begin(), - operand->shape().dimensions().end() - 1, - std::back_inserter(expanded_shape_dim_bounds)); - c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); - Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), - expanded_shape_dim_bounds); - return MakeReshapeHlo(new_shape, operand); -} - StatusOr ElideDegenerateDims(HloInstruction* operand, ArraySlice dims_to_elide) { CHECK(c_is_sorted(dims_to_elide)); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 6032eba..d99e32a 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -119,16 +119,6 @@ StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n); StatusOr ExpandFirstDimIntoNDims( HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); -// Expands (via reshape) the last (logical) dimension of `operand` into a -// sequence of `expanded_dims` dimensions. `operand` must at least be of rank 1 -// and the number of elements in its last dimension must be equal to the -// product of `expanded_dims`. -// -// For instance if `operand` has shape f32[9,7,200] and expanded_dims is -// {2,5,20} the result is `operand` reshaped to [9,7,2,5,20]. -StatusOr ExpandLastDimIntoNDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); - // Elides (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_elide` from `operand`. Every dimension in // `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 0830e9c..8ba9194 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -335,5 +335,67 @@ ENTRY main { {operand.get(), gather_indices.get(), in_bounds_mask.get()}); } +XLA_TEST_F(GatherOperationTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), + output_window_dims={0,1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1,3,2} +} +)"; + std::unique_ptr operand = Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ScalarResult) { + const char* hlo_text = R"( +HloModule ScalarResult + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[] gather(operand, index), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1} +} +)"; + std::unique_ptr operand = Literal::CreateR1({1, 2, 3, 4}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { + const string hlo_text = R"( +HloModule ZeroSizedResult + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[0] parameter(1) + ROOT gather = s32[0,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + } // namespace } // namespace xla -- 2.7.4