From: Sanjoy Das Date: Tue, 20 Mar 2018 18:13:48 +0000 (-0700) Subject: Use 32 bit induction variable in gather expander X-Git-Tag: tflite-v0.1.7~145^2^2~9 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1c4e42b39fd9ae2da14d7eb323bedc144a6e659b;p=platform%2Fupstream%2Ftensorflow.git Use 32 bit induction variable in gather expander Right now this is unconditional (and we fail with Unimplemented() if a 32 bit induction variable is not large enough), but eventually we may want to be smarter about this. PiperOrigin-RevId: 189773581 --- diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 43c5648..d4d6787 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1276,6 +1276,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gather_expander_test", + srcs = ["gather_expander_test.cc"], + deps = [ + ":gather_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:test_macros_header", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + cc_library( name = "conditional_simplifier", srcs = ["conditional_simplifier.cc"], diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 488bed3..221ff79 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -306,18 +306,33 @@ StatusOr GatherExpander::ExpandGather( HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); HloInstruction* gather_indices = gather_instr->mutable_operand(1); + const Shape& gather_indices_shape = gather_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); const GatherDimensionNumbers& dim_numbers = gather_instr->gather_dimension_numbers(); + int64 gather_loop_trip_count = 1; + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + gather_loop_trip_count *= gather_indices_shape.dimensions(i); + } + } + + if (!IsInt32(gather_loop_trip_count)) { + return Unimplemented( + "Gather operations with more than 2147483647 gather indices are not " + "supported. This error occurred for %s.", + gather_instr->ToString().c_str()); + } + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, CanonicalizeGatherIndices( gather_indices, dim_numbers.index_vector_dim())); - const int64 gather_loop_trip_count = - canonical_gather_indices->shape().dimensions(0); + CHECK_EQ(gather_loop_trip_count, + canonical_gather_indices->shape().dimensions(0)); TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_init, diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc new file mode 100644 index 0000000..ba41ee8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { +TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) { + const string hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2147483647,5] parameter(1) + ROOT gather = s32[2147483647,3,5] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + + Status status = GatherExpander{}.Run(module.get()).status(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + + ASSERT_THAT( + status.error_message(), + ::testing::HasSubstr("Gather operations with more than 2147483647 gather " + "indices are not supported.")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 8cd5882..bd07941 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -142,23 +142,23 @@ WhileUtil::MakeInstructionsLiveIn( static StatusOr> MakeCountedLoopConditionComputation(const Shape& loop_state_shape, - int64 trip_count) { + int32 trip_count) { Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); - Shape scalar_s64 = ShapeUtil::MakeShape(S64, {}); TF_ASSIGN_OR_RETURN(std::unique_ptr cond_computation, CreateComputationWithSignature( {&loop_state_shape}, scalar_pred, "while_cond")); HloInstruction* trip_count_constant = cond_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); + HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); HloInstruction* param = cond_computation->parameter_instruction(0); - TF_ASSIGN_OR_RETURN(HloInstruction * counter, + TF_ASSIGN_OR_RETURN(HloInstruction * indvar, MakeGetTupleElementHlo(param, 0)); + TF_ASSIGN_OR_RETURN( HloInstruction * compare, - MakeBinaryHlo(HloOpcode::kLt, counter, trip_count_constant)); + MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant)); cond_computation->set_root_instruction(compare); return std::move(cond_computation); } @@ -171,8 +171,7 @@ static StatusOr> MakeCountedLoopBodyComputation( CreateComputationWithSignature( {&loop_state_shape}, loop_state_shape, "while_body")); HloInstruction* one = body_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); - + HloInstruction::CreateConstant(Literal::CreateR0(1))); HloInstruction* param = body_computation->parameter_instruction(0); TF_ASSIGN_OR_RETURN(HloInstruction * indvar, MakeGetTupleElementHlo(param, 0)); @@ -200,7 +199,7 @@ static StatusOr MakeInitTupleFromInitValues( std::vector init_values_with_indvar; init_values_with_indvar.reserve(init_values.size() + 1); HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); init_values_with_indvar.push_back(zero); c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( @@ -210,16 +209,18 @@ static StatusOr MakeInitTupleFromInitValues( static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { std::vector loop_state_shape_components; loop_state_shape_components.reserve(init_values.size() + 1); - loop_state_shape_components.push_back(ShapeUtil::MakeShape(S64, {})); + loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); c_transform(init_values, std::back_inserter(loop_state_shape_components), [](HloInstruction* instr) { return instr->shape(); }); return ShapeUtil::MakeTupleShape(loop_state_shape_components); } /*static*/ StatusOr WhileUtil::MakeCountedLoop( - HloComputation* computation, int64 trip_count, + HloComputation* computation, int32 trip_count, const WhileUtil::LoopStateTy& init_values, const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) { + CHECK_GE(trip_count, 0); + Shape loop_state_shape = MakeLoopStateShape(init_values); TF_ASSIGN_OR_RETURN( std::unique_ptr cond, diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 80f7e16..1688d46 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -71,7 +71,7 @@ class WhileUtil { // return loop_state; // } static StatusOr MakeCountedLoop( - HloComputation* computation, int64 trip_count, + HloComputation* computation, int32 trip_count, const LoopStateTy& init_values, const LoopBodyGeneratorTy& loop_body_generator); }; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 025ac12..04a9c1e 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -676,7 +676,9 @@ xla_test( name = "gather_operation_test", srcs = ["gather_operation_test.cc"], deps = [ + ":client_library_test_base", ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 8ba9194..9db68ff 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" @@ -397,5 +399,63 @@ ENTRY main { RunTest(hlo_text, operand.get(), gather_indices.get()); } +class GatherClientLibraryTest : public ClientLibraryTestBase {}; + +// TODO(b/30671675): Asynchronous execution on stream is not yet supported on +// GPU and CPU_PARALLEL. +XLA_TEST_F(GatherClientLibraryTest, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) { + // We create this HLO, but using the ComputationBuilder API. + // + // ENTRY main { + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // ROOT gather = s32[2,3] gather(operand, indices), + // output_window_dims={1}, + // elided_window_dims={0}, + // gather_dims_to_operand_dims={0}, + // index_vector_dim=1, + // window_bounds={1, 3} + // } + + ComputationBuilder builder(client_, "gather_basic"); + + Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + + auto operand = builder.Parameter(0, operand_shape, "operand"); + auto indices = builder.Parameter(1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_output_window_dims(1); + dim_numbers.add_elided_window_dims(0); + dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.set_index_vector_dim(1); + builder.Gather(operand, indices, dim_numbers, {1, 3}); + + std::vector expected = {}; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, + client_->TransferToServer(*Literal::CreateR2( + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr indices_arg, + client_->TransferToServer(*Literal::CreateR1({0, 2}))); + TF_ASSERT_OK_AND_ASSIGN(std::vector devices, + client_->GetDeviceHandles(1)); + xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + *execution_options.add_device_handles() = devices[0]; + TF_ASSERT_OK_AND_ASSIGN(Computation computation, builder.Build()); + std::vector computation_instances = { + {computation, + {operand_arg.get(), indices_arg.get()}, + execution_options, + /*execution_profile=*/nullptr}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector> result_data, + client_->ExecuteParallel(computation_instances)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + client_->Transfer(*(result_data[0]))); + LiteralTestUtil::ExpectEqual( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}})); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index ff99d37..2da9f9e 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -519,6 +519,15 @@ int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); return std::distance(c.begin(), it); } + +// Returns true if `x` fits in 32-bits. +template +bool IsInt32(T x) { + // Following conversion rules: "the value is unchanged if it can be + // represented in the destination type (and bit-field width); otherwise, the + // value is implementation-defined." + return static_cast(x) == x; +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \