From 8477e7cdd0dafb2e9f9f1c1ad3929b15a29a5ada Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Fri, 27 Apr 2018 14:36:24 -0700 Subject: [PATCH] [XLA:CPU] Implement fusion for the Gather HLO PiperOrigin-RevId: 194594759 --- tensorflow/compiler/xla/service/cpu/BUILD | 1 + .../compiler/xla/service/cpu/cpu_compiler.cc | 3 +- .../xla/service/cpu/cpu_instruction_fusion.cc | 1 + .../xla/service/cpu/cpu_instruction_fusion_test.cc | 149 +++++++++++++++++ .../compiler/xla/service/elemental_ir_emitter.cc | 86 ++++++++++ tensorflow/compiler/xla/service/llvm_ir/ir_array.h | 4 + .../compiler/xla/tests/gather_operation_test.cc | 178 +++++++++++++++++++++ 7 files changed, 421 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index cef4eba..2fc6c6b 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -624,6 +624,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 3c0c367..150c12e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -258,7 +258,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); - pipeline.AddPass(); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -287,6 +286,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + pipeline.AddPass(); + ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 0fc5a74..b40d264 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -34,6 +34,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || + hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kReverse || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 6ed1cd3..a98e85a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -697,6 +698,154 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { Not(op::Fusion())); } +struct GatherLoopFusionTestSpec { + string test_name; + string hlo_computation_text; + + static string Name( + const ::testing::TestParamInfo& info) { + return info.param.test_name; + } +}; + +class GatherLoopFusionTest + : public OpcodeFusionTest, + public ::testing::WithParamInterface {}; + +TEST_P(GatherLoopFusionTest, GatherLoopFusion) { + const GatherLoopFusionTestSpec& spec = GetParam(); + string hlo_string = tensorflow::strings::StrCat( + "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); +} + +std::vector GetGatherLoopFusionTestSpecs() { + std::vector result; + + result.push_back({"FusedTensorFlowGatherV2", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[3,2] broadcast(one), dimensions={} + ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} + ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNd_0", R"( +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNd_1", R"( +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedDynamicSlice", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[1,1] broadcast(one), dimensions={} + ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedBatchDynamicSlice", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} + ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) +} +)"}); + + return result; +} + +INSTANTIATE_TEST_CASE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest, + ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), + GatherLoopFusionTestSpec::Name); } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 38b5efa..4b01c87 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1587,6 +1587,92 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } return operand_to_generator.at(input_hlo)(input_index); }; + + case HloOpcode::kGather: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + const Shape& operand_shape = hlo->operand(0)->shape(); + const Shape& indices_shape = hlo->operand(1)->shape(); + const Shape& output_shape = hlo->shape(); + + const GatherDimensionNumbers& dim_numbers = + hlo->gather_dimension_numbers(); + + const llvm_ir::ElementGenerator& operand_generator = + operand_to_generator.at(hlo->operand(0)); + const llvm_ir::ElementGenerator& indices_generator = + operand_to_generator.at(hlo->operand(1)); + + // This is the index into `operand` that holds the element we want to + // generate. This index "unsafe" as in the components in here may be + // out of bounds. + IrArray::Index unsafe_operand_index; + + // First copy in the window indices to unsafe_operand_index. + for (int64 i = 0, e = operand_shape.dimensions_size(), + unsafe_operand_index_dim = 0; + i < e; i++) { + if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + } else { + unsafe_operand_index.push_back(index[dim_numbers.output_window_dims( + unsafe_operand_index_dim++)]); + } + } + + // This is the index of the index vector in the gather_indices tensor. + IrArray::Index gather_index_index; + { + std::vector gather_index_index_components; + for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + gather_index_index.push_back(index[i]); + } + } + + if (gather_index_index.size() != indices_shape.dimensions_size()) { + gather_index_index.InsertAt(dim_numbers.index_vector_dim(), + nullptr); + } + } + + auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, + int64 dim) { + llvm::Value* gather_dim_component_extended = + ir_builder_->CreateSExtOrTrunc(index_component, + ir_builder_->getInt64Ty()); + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = + ir_builder_->CreateAdd( + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims( + dim)], + gather_dim_component_extended); + }; + + if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, 0); + } else { + int64 index_vector_size = + indices_shape.dimensions(dim_numbers.index_vector_dim()); + for (int64 i = 0; i < index_vector_size; i++) { + gather_index_index[dim_numbers.index_vector_dim()] = + ir_builder_->getInt64(i); + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, i); + } + } + + IrArray::Index safe_operand_index; + for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { + safe_operand_index.push_back(ir_builder_->CreateURem( + unsafe_operand_index[i], + ir_builder_->getInt64(operand_shape.dimensions(i)))); + } + + return operand_generator(safe_operand_index); + }; case HloOpcode::kDynamicUpdateSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 06cfb2a..4c3195c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -97,6 +97,10 @@ class IrArray { llvm::Value*& operator[](size_t i) { return multidim()[i]; } void push_back(llvm::Value* value) { multidim().push_back(value); } + void InsertAt(int64 index, llvm::Value* value) { + CHECK_LE(index, size()); + multidim().insert(multidim().begin() + index, value); + } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 4dd3acd..130456e 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -399,6 +399,184 @@ ENTRY main { RunTest(hlo_text, operand.get(), gather_indices.get()); } +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[3,2] broadcast(one), dimensions={} + ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} + ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, + FusedTensorFlowGatherNdNonDefaultIndexVectorDim) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { + const char* hlo_text = R"( +HloModule FusedDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[1,1] broadcast(one), dimensions={} + ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { + const string hlo_text = R"( +HloModule FusedBatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} + ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + class GatherClientLibraryTest : public ClientLibraryTestBase {}; XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { -- 2.7.4