[XLA:CPU] Implement fusion for the Gather HLO
authorSanjoy Das <sanjoy@google.com>
Fri, 27 Apr 2018 21:36:24 +0000 (14:36 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 27 Apr 2018 21:39:21 +0000 (14:39 -0700)
PiperOrigin-RevId: 194594759

tensorflow/compiler/xla/service/cpu/BUILD
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
tensorflow/compiler/xla/service/elemental_ir_emitter.cc
tensorflow/compiler/xla/service/llvm_ir/ir_array.h
tensorflow/compiler/xla/tests/gather_operation_test.cc

index cef4eba..2fc6c6b 100644 (file)
@@ -624,6 +624,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla/service:transpose_folding",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
         "//tensorflow/core:lib",
     ],
 )
index 3c0c367..150c12e 100644 (file)
@@ -258,7 +258,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
         /*rewrite_inference_op=*/true,
         /*rewrite_grad_op=*/true,
         /*use_fusion=*/false);
-    pipeline.AddPass<GatherExpander>();
     pass.AddPass<AlgebraicSimplifier>(
         /*is_layout_sensitive=*/false,
         [](const Shape&, const Shape&) { return false; },
@@ -287,6 +286,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
   pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
   pipeline.AddPass<CpuInstructionFusion>();
 
+  pipeline.AddPass<GatherExpander>();
+
   ReducePrecisionInsertion::AddPasses(
       &pipeline, module->config().debug_options(),
       ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
index 0fc5a74..b40d264 100644 (file)
@@ -34,6 +34,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
          hlo.opcode() == HloOpcode::kConcatenate ||
          hlo.opcode() == HloOpcode::kDynamicSlice ||
          hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
+         hlo.opcode() == HloOpcode::kGather ||
          hlo.opcode() == HloOpcode::kPad ||
          hlo.opcode() == HloOpcode::kReshape ||
          hlo.opcode() == HloOpcode::kReverse ||
index 6ed1cd3..a98e85a 100644 (file)
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 
 namespace op = xla::testing::opcode_matchers;
@@ -697,6 +698,154 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
               Not(op::Fusion()));
 }
 
+struct GatherLoopFusionTestSpec {
+  string test_name;
+  string hlo_computation_text;
+
+  static string Name(
+      const ::testing::TestParamInfo<GatherLoopFusionTestSpec>& info) {
+    return info.param.test_name;
+  }
+};
+
+class GatherLoopFusionTest
+    : public OpcodeFusionTest,
+      public ::testing::WithParamInterface<GatherLoopFusionTestSpec> {};
+
+TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
+  const GatherLoopFusionTestSpec& spec = GetParam();
+  string hlo_string = tensorflow::strings::StrCat(
+      "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text);
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          tools::Parse(hlo_string));
+
+  RunFusionAndCheckOpcodesWereFused(
+      module.get(),
+      {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast,
+       HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter});
+}
+
+std::vector<GatherLoopFusionTestSpec> GetGatherLoopFusionTestSpecs() {
+  std::vector<GatherLoopFusionTestSpec> result;
+
+  result.push_back({"FusedTensorFlowGatherV2", R"(
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  gather = s32[3,2] gather(operand, indices),
+      output_window_dims={0},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=1,
+      window_bounds={3, 1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[3,2] broadcast(one), dimensions={}
+  ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
+}
+)"});
+
+  result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"(
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2,2] parameter(1)
+  gather = s32[2,3,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=2,
+      window_bounds={3, 1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
+  ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
+}
+)"});
+
+  result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"(
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2,2,2] parameter(1)
+  gather = s32[2,2] gather(operand, indices),
+      output_window_dims={},
+      elided_window_dims={0,1},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=2,
+      window_bounds={1, 1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,2] broadcast(one), dimensions={}
+  ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
+}
+)"});
+
+  result.push_back({"FusedTensorFlowGatherNd_0", R"(
+ENTRY main {
+  operand = s32[3,3,2] parameter(0)
+  indices = s32[2,2] parameter(1)
+  gather = s32[2,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0,1},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=1,
+      window_bounds={1,1,2}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,2] broadcast(one), dimensions={}
+  ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
+}
+)"});
+
+  result.push_back({"FusedTensorFlowGatherNd_1", R"(
+ENTRY main {
+  operand = s32[3,3,2] parameter(0)
+  indices = s32[2,2] parameter(1)
+  gather = s32[2,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0,1},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1,2}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,2] broadcast(one), dimensions={}
+  ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
+}
+)"});
+
+  result.push_back({"FusedDynamicSlice", R"(
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  gather = s32[1,1] gather(operand, indices),
+      output_window_dims={0,1},
+      elided_window_dims={},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[1,1] broadcast(one), dimensions={}
+  ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
+}
+)"});
+
+  result.push_back({"FusedBatchDynamicSlice", R"(
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2,2] parameter(1)
+  gather = s32[2,1,1] gather(operand, indices),
+      output_window_dims={1,2},
+      elided_window_dims={},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
+  ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
+}
+)"});
+
+  return result;
+}
+
+INSTANTIATE_TEST_CASE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest,
+                        ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()),
+                        GatherLoopFusionTestSpec::Name);
 }  // namespace
 }  // namespace cpu
 }  // namespace xla
index 38b5efa..4b01c87 100644 (file)
@@ -1587,6 +1587,92 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
         }
         return operand_to_generator.at(input_hlo)(input_index);
       };
+
+    case HloOpcode::kGather:
+      return [this, hlo, &operand_to_generator](
+                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+        const Shape& operand_shape = hlo->operand(0)->shape();
+        const Shape& indices_shape = hlo->operand(1)->shape();
+        const Shape& output_shape = hlo->shape();
+
+        const GatherDimensionNumbers& dim_numbers =
+            hlo->gather_dimension_numbers();
+
+        const llvm_ir::ElementGenerator& operand_generator =
+            operand_to_generator.at(hlo->operand(0));
+        const llvm_ir::ElementGenerator& indices_generator =
+            operand_to_generator.at(hlo->operand(1));
+
+        // This is the index into `operand` that holds the element we want to
+        // generate.  This index "unsafe" as in the components in here may be
+        // out of bounds.
+        IrArray::Index unsafe_operand_index;
+
+        // First copy in the window indices to unsafe_operand_index.
+        for (int64 i = 0, e = operand_shape.dimensions_size(),
+                   unsafe_operand_index_dim = 0;
+             i < e; i++) {
+          if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+            unsafe_operand_index.push_back(ir_builder_->getInt64(0));
+          } else {
+            unsafe_operand_index.push_back(index[dim_numbers.output_window_dims(
+                unsafe_operand_index_dim++)]);
+          }
+        }
+
+        // This is the index of the index vector in the gather_indices tensor.
+        IrArray::Index gather_index_index;
+        {
+          std::vector<llvm::Value*> gather_index_index_components;
+          for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
+            if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+              gather_index_index.push_back(index[i]);
+            }
+          }
+
+          if (gather_index_index.size() != indices_shape.dimensions_size()) {
+            gather_index_index.InsertAt(dim_numbers.index_vector_dim(),
+                                        nullptr);
+          }
+        }
+
+        auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
+                                               int64 dim) {
+          llvm::Value* gather_dim_component_extended =
+              ir_builder_->CreateSExtOrTrunc(index_component,
+                                             ir_builder_->getInt64Ty());
+          unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
+              ir_builder_->CreateAdd(
+                  unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(
+                      dim)],
+                  gather_dim_component_extended);
+        };
+
+        if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
+          TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
+                              indices_generator(gather_index_index));
+          add_to_unsafe_operand_index(gather_dim_component, 0);
+        } else {
+          int64 index_vector_size =
+              indices_shape.dimensions(dim_numbers.index_vector_dim());
+          for (int64 i = 0; i < index_vector_size; i++) {
+            gather_index_index[dim_numbers.index_vector_dim()] =
+                ir_builder_->getInt64(i);
+            TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
+                                indices_generator(gather_index_index));
+            add_to_unsafe_operand_index(gather_dim_component, i);
+          }
+        }
+
+        IrArray::Index safe_operand_index;
+        for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
+          safe_operand_index.push_back(ir_builder_->CreateURem(
+              unsafe_operand_index[i],
+              ir_builder_->getInt64(operand_shape.dimensions(i))));
+        }
+
+        return operand_generator(safe_operand_index);
+      };
     case HloOpcode::kDynamicUpdateSlice:
       return [this, hlo, &operand_to_generator](
                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
index 06cfb2a..4c3195c 100644 (file)
@@ -97,6 +97,10 @@ class IrArray {
     llvm::Value*& operator[](size_t i) { return multidim()[i]; }
 
     void push_back(llvm::Value* value) { multidim().push_back(value); }
+    void InsertAt(int64 index, llvm::Value* value) {
+      CHECK_LE(index, size());
+      multidim().insert(multidim().begin() + index, value);
+    }
 
     using iterator = std::vector<llvm::Value*>::iterator;
     using const_iterator = std::vector<llvm::Value*>::const_iterator;
index 4dd3acd..130456e 100644 (file)
@@ -399,6 +399,184 @@ ENTRY main {
   RunTest(hlo_text, operand.get(), gather_indices.get());
 }
 
+XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
+  const string hlo_text = R"(
+HloModule FusedTensorFlowGatherV2
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  gather = s32[3,2] gather(operand, indices),
+      output_window_dims={0},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=1,
+      window_bounds={3, 1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[3,2] broadcast(one), dimensions={}
+  ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
+  const string hlo_text = R"(
+HloModule FusedTensorFlowGatherMultipleBatchDims
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2,2] parameter(1)
+  gather = s32[2,3,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=2,
+      window_bounds={3, 1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
+  ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
+  const string hlo_text = R"(
+HloModule FusedTensorFlowGatherNdMultipleBatchDims
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2,2,2] parameter(1)
+  gather = s32[2,2] gather(operand, indices),
+      output_window_dims={},
+      elided_window_dims={0,1},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=2,
+      window_bounds={1, 1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,2] broadcast(one), dimensions={}
+  ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
+  const string hlo_text = R"(
+HloModule FusedTensorFlowGatherNd
+
+ENTRY main {
+  operand = s32[3,3,2] parameter(0)
+  indices = s32[2,2] parameter(1)
+  gather = s32[2,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0,1},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=1,
+      window_bounds={1,1,2}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,2] broadcast(one), dimensions={}
+  ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
+                                {{-4, 4}, {-5, 5}, {-6, 6}},  //
+                                {{-7, 7}, {-8, 8}, {-9, 9}}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest,
+           FusedTensorFlowGatherNdNonDefaultIndexVectorDim) {
+  const string hlo_text = R"(
+HloModule FusedTensorFlowGatherNd
+
+ENTRY main {
+  operand = s32[3,3,2] parameter(0)
+  indices = s32[2,2] parameter(1)
+  gather = s32[2,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0,1},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1,2}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,2] broadcast(one), dimensions={}
+  ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
+                                {{-4, 4}, {-5, 5}, {-6, 6}},  //
+                                {{-7, 7}, {-8, 8}, {-9, 9}}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
+  const char* hlo_text = R"(
+HloModule FusedDynamicSlice
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  gather = s32[1,1] gather(operand, indices),
+      output_window_dims={0,1},
+      elided_window_dims={},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[1,1] broadcast(one), dimensions={}
+  ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
+  const string hlo_text = R"(
+HloModule FusedBatchDynamicSlice
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2,2] parameter(1)
+  gather = s32[2,1,1] gather(operand, indices),
+      output_window_dims={1,2},
+      elided_window_dims={},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1}
+  one = s32[] constant(1)
+  one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
+  ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
+}
+)";
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+  RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
 class GatherClientLibraryTest : public ClientLibraryTestBase {};
 
 XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {