Use 32 bit induction variable in gather expander
authorSanjoy Das <sanjoy@google.com>
Tue, 20 Mar 2018 18:13:48 +0000 (11:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 18:18:12 +0000 (11:18 -0700)
Right now this is unconditional (and we fail with Unimplemented() if a 32 bit
induction variable is not large enough), but eventually we may want to be
smarter about this.

PiperOrigin-RevId: 189773581

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/gather_expander.cc
tensorflow/compiler/xla/service/gather_expander_test.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/while_util.cc
tensorflow/compiler/xla/service/while_util.h
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/gather_operation_test.cc
tensorflow/compiler/xla/util.h

index 43c5648..d4d6787 100644 (file)
@@ -1276,6 +1276,18 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "gather_expander_test",
+    srcs = ["gather_expander_test.cc"],
+    deps = [
+        ":gather_expander",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:test_macros_header",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+    ],
+)
+
 cc_library(
     name = "conditional_simplifier",
     srcs = ["conditional_simplifier.cc"],
index 488bed3..221ff79 100644 (file)
@@ -306,18 +306,33 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
   HloComputation* computation = gather_instr->parent();
   HloInstruction* operand = gather_instr->mutable_operand(0);
   HloInstruction* gather_indices = gather_instr->mutable_operand(1);
+  const Shape& gather_indices_shape = gather_indices->shape();
   const Shape& output_shape = gather_instr->shape();
   int64 output_rank = output_shape.dimensions_size();
 
   const GatherDimensionNumbers& dim_numbers =
       gather_instr->gather_dimension_numbers();
 
+  int64 gather_loop_trip_count = 1;
+  for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+    if (i != dim_numbers.index_vector_dim()) {
+      gather_loop_trip_count *= gather_indices_shape.dimensions(i);
+    }
+  }
+
+  if (!IsInt32(gather_loop_trip_count)) {
+    return Unimplemented(
+        "Gather operations with more than 2147483647 gather indices are not "
+        "supported. This error occurred for %s.",
+        gather_instr->ToString().c_str());
+  }
+
   TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices,
                       CanonicalizeGatherIndices(
                           gather_indices, dim_numbers.index_vector_dim()));
 
-  const int64 gather_loop_trip_count =
-      canonical_gather_indices->shape().dimensions(0);
+  CHECK_EQ(gather_loop_trip_count,
+           canonical_gather_indices->shape().dimensions(0));
 
   TF_ASSIGN_OR_RETURN(
       HloInstruction * accumulator_init,
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
new file mode 100644 (file)
index 0000000..ba41ee8
--- /dev/null
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gather_expander.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
+  const string hlo_text = R"(
+HloModule TensorFlowGatherMultipleBatchDims
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2147483647,5] parameter(1)
+  ROOT gather = s32[2147483647,3,5] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=2,
+      window_bounds={3, 1}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          tools::Parse(hlo_text));
+
+  Status status = GatherExpander{}.Run(module.get()).status();
+  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
+
+  ASSERT_THAT(
+      status.error_message(),
+      ::testing::HasSubstr("Gather operations with more than 2147483647 gather "
+                           "indices are not supported."));
+}
+
+}  // namespace
+}  // namespace xla
index 8cd5882..bd07941 100644 (file)
@@ -142,23 +142,23 @@ WhileUtil::MakeInstructionsLiveIn(
 
 static StatusOr<std::unique_ptr<HloComputation>>
 MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
-                                    int64 trip_count) {
+                                    int32 trip_count) {
   Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
-  Shape scalar_s64 = ShapeUtil::MakeShape(S64, {});
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
                       CreateComputationWithSignature(
                           {&loop_state_shape}, scalar_pred, "while_cond"));
 
   HloInstruction* trip_count_constant = cond_computation->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0<int64>(trip_count)));
+      HloInstruction::CreateConstant(Literal::CreateR0<int32>(trip_count)));
 
   HloInstruction* param = cond_computation->parameter_instruction(0);
-  TF_ASSIGN_OR_RETURN(HloInstruction * counter,
+  TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
                       MakeGetTupleElementHlo(param, 0));
+
   TF_ASSIGN_OR_RETURN(
       HloInstruction * compare,
-      MakeBinaryHlo(HloOpcode::kLt, counter, trip_count_constant));
+      MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant));
   cond_computation->set_root_instruction(compare);
   return std::move(cond_computation);
 }
@@ -171,8 +171,7 @@ static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
                       CreateComputationWithSignature(
                           {&loop_state_shape}, loop_state_shape, "while_body"));
   HloInstruction* one = body_computation->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0<int64>(1)));
-
+      HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
   HloInstruction* param = body_computation->parameter_instruction(0);
   TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
                       MakeGetTupleElementHlo(param, 0));
@@ -200,7 +199,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
   std::vector<HloInstruction*> init_values_with_indvar;
   init_values_with_indvar.reserve(init_values.size() + 1);
   HloInstruction* zero = computation->AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0<int64>(0)));
+      HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
   init_values_with_indvar.push_back(zero);
   c_copy(init_values, std::back_inserter(init_values_with_indvar));
   return computation->AddInstruction(
@@ -210,16 +209,18 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
 static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
   std::vector<Shape> loop_state_shape_components;
   loop_state_shape_components.reserve(init_values.size() + 1);
-  loop_state_shape_components.push_back(ShapeUtil::MakeShape(S64, {}));
+  loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
   c_transform(init_values, std::back_inserter(loop_state_shape_components),
               [](HloInstruction* instr) { return instr->shape(); });
   return ShapeUtil::MakeTupleShape(loop_state_shape_components);
 }
 
 /*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
-    HloComputation* computation, int64 trip_count,
+    HloComputation* computation, int32 trip_count,
     const WhileUtil::LoopStateTy& init_values,
     const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) {
+  CHECK_GE(trip_count, 0);
+
   Shape loop_state_shape = MakeLoopStateShape(init_values);
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<HloComputation> cond,
index 80f7e16..1688d46 100644 (file)
@@ -71,7 +71,7 @@ class WhileUtil {
   //    return loop_state;
   //  }
   static StatusOr<LoopStateTy> MakeCountedLoop(
-      HloComputation* computation, int64 trip_count,
+      HloComputation* computation, int32 trip_count,
       const LoopStateTy& init_values,
       const LoopBodyGeneratorTy& loop_body_generator);
 };
index 025ac12..04a9c1e 100644 (file)
@@ -676,7 +676,9 @@ xla_test(
     name = "gather_operation_test",
     srcs = ["gather_operation_test.cc"],
     deps = [
+        ":client_library_test_base",
         ":hlo_test_base",
+        "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
index 8ba9194..9db68ff 100644 (file)
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
@@ -397,5 +399,63 @@ ENTRY main {
   RunTest(hlo_text, operand.get(), gather_indices.get());
 }
 
+class GatherClientLibraryTest : public ClientLibraryTestBase {};
+
+// TODO(b/30671675): Asynchronous execution on stream is not yet supported on
+// GPU and CPU_PARALLEL.
+XLA_TEST_F(GatherClientLibraryTest,
+           DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) {
+  // We create this HLO, but using the ComputationBuilder API.
+  //
+  // ENTRY main {
+  //   operand = s32[3,3] parameter(0)
+  //   indices = s32[2] parameter(1)
+  //   ROOT gather = s32[2,3] gather(operand, indices),
+  //       output_window_dims={1},
+  //       elided_window_dims={0},
+  //       gather_dims_to_operand_dims={0},
+  //       index_vector_dim=1,
+  //       window_bounds={1, 3}
+  // }
+
+  ComputationBuilder builder(client_, "gather_basic");
+
+  Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
+  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
+
+  auto operand = builder.Parameter(0, operand_shape, "operand");
+  auto indices = builder.Parameter(1, indices_shape, "indices");
+  GatherDimensionNumbers dim_numbers;
+  dim_numbers.add_output_window_dims(1);
+  dim_numbers.add_elided_window_dims(0);
+  dim_numbers.add_gather_dims_to_operand_dims(0);
+  dim_numbers.set_index_vector_dim(1);
+  builder.Gather(operand, indices, dim_numbers, {1, 3});
+
+  std::vector<int32> expected = {};
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> operand_arg,
+                          client_->TransferToServer(*Literal::CreateR2<int32>(
+                              {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<GlobalData> indices_arg,
+      client_->TransferToServer(*Literal::CreateR1<int32>({0, 2})));
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
+                          client_->GetDeviceHandles(1));
+  xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
+  *execution_options.add_device_handles() = devices[0];
+  TF_ASSERT_OK_AND_ASSIGN(Computation computation, builder.Build());
+  std::vector<xla::Client::ComputationInstance> computation_instances = {
+      {computation,
+       {operand_arg.get(), indices_arg.get()},
+       execution_options,
+       /*execution_profile=*/nullptr}};
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<std::unique_ptr<xla::GlobalData>> result_data,
+      client_->ExecuteParallel(computation_instances));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
+                          client_->Transfer(*(result_data[0])));
+  LiteralTestUtil::ExpectEqual(
+      *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}));
+}
 }  // namespace
 }  // namespace xla
index ff99d37..2da9f9e 100644 (file)
@@ -519,6 +519,15 @@ int64 FindIndex(const C& c, Value&& value) {
   auto it = c_find(c, std::forward<Value>(value));
   return std::distance(c.begin(), it);
 }
+
+// Returns true if `x` fits in 32-bits.
+template <typename T>
+bool IsInt32(T x) {
+  // Following conversion rules: "the value is unchanged if it can be
+  // represented in the destination type (and bit-field width); otherwise, the
+  // value is implementation-defined."
+  return static_cast<int32>(x) == x;
+}
 }  // namespace xla
 
 #define XLA_LOG_LINES(SEV, STRING) \