],
)
+# Unit tests for the GatherExpander pass (gather -> while-loop rewrite),
+# including the error path for trip counts that do not fit in an int32.
+tf_cc_test(
+ name = "gather_expander_test",
+ srcs = ["gather_expander_test.cc"],
+ deps = [
+ ":gather_expander",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/tests:test_macros_header",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ ],
+)
+
cc_library(
name = "conditional_simplifier",
srcs = ["conditional_simplifier.cc"],
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
HloInstruction* gather_indices = gather_instr->mutable_operand(1);
+ const Shape& gather_indices_shape = gather_indices->shape();
const Shape& output_shape = gather_instr->shape();
int64 output_rank = output_shape.dimensions_size();
const GatherDimensionNumbers& dim_numbers =
gather_instr->gather_dimension_numbers();
+ int64 gather_loop_trip_count = 1;
+ for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ if (i != dim_numbers.index_vector_dim()) {
+ gather_loop_trip_count *= gather_indices_shape.dimensions(i);
+ }
+ }
+
+ if (!IsInt32(gather_loop_trip_count)) {
+ return Unimplemented(
+ "Gather operations with more than 2147483647 gather indices are not "
+ "supported. This error occurred for %s.",
+ gather_instr->ToString().c_str());
+ }
+
TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices,
CanonicalizeGatherIndices(
gather_indices, dim_numbers.index_vector_dim()));
- const int64 gather_loop_trip_count =
- canonical_gather_indices->shape().dimensions(0);
+ CHECK_EQ(gather_loop_trip_count,
+ canonical_gather_indices->shape().dimensions(0));
TF_ASSIGN_OR_RETURN(
HloInstruction * accumulator_init,
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gather_expander.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+// Verifies that GatherExpander rejects a gather whose flattened index count
+// exceeds INT32_MAX with an UNIMPLEMENTED status (rather than silently
+// truncating the loop trip count).  Here the non-index_vector_dim dimensions
+// of `indices` multiply to 2147483647 * 5, which does not fit in an int32.
+TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
+ const string hlo_text = R"(
+HloModule TensorFlowGatherMultipleBatchDims
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2147483647,5] parameter(1)
+ ROOT gather = s32[2147483647,3,5] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=2,
+ window_bounds={3, 1}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_text));
+
+ // The pass must fail cleanly (a returned Status, not a CHECK failure).
+ Status status = GatherExpander{}.Run(module.get()).status();
+ EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
+
+ ASSERT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Gather operations with more than 2147483647 gather "
+ "indices are not supported."));
+}
+
+} // namespace
+} // namespace xla
+// Builds the while-loop condition computation for a counted loop: it reads
+// the S32 induction variable (tuple element 0 of the loop state) and returns
+// the predicate (indvar < trip_count).
static StatusOr<std::unique_ptr<HloComputation>>
MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
- int64 trip_count) {
+ int32 trip_count) {
Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
- Shape scalar_s64 = ShapeUtil::MakeShape(S64, {});
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> cond_computation,
CreateComputationWithSignature(
{&loop_state_shape}, scalar_pred, "while_cond"));
+ // The trip count is materialized as an S32 constant so it matches the S32
+ // induction variable carried in the loop state.
HloInstruction* trip_count_constant = cond_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(trip_count)));
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(trip_count)));
HloInstruction* param = cond_computation->parameter_instruction(0);
- TF_ASSIGN_OR_RETURN(HloInstruction * counter,
+ TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
MakeGetTupleElementHlo(param, 0));
+
TF_ASSIGN_OR_RETURN(
HloInstruction * compare,
- MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant));
+ MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant));
cond_computation->set_root_instruction(compare);
return std::move(cond_computation);
}
CreateComputationWithSignature(
{&loop_state_shape}, loop_state_shape, "while_body"));
HloInstruction* one = body_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(1)));
-
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
HloInstruction* param = body_computation->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
MakeGetTupleElementHlo(param, 0));
std::vector<HloInstruction*> init_values_with_indvar;
init_values_with_indvar.reserve(init_values.size() + 1);
HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(0)));
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
init_values_with_indvar.push_back(zero);
c_copy(init_values, std::back_inserter(init_values_with_indvar));
return computation->AddInstruction(
+// Returns the tuple shape of the counted-loop state: an S32 induction
+// variable in element 0, followed by the shapes of `init_values`.
static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
std::vector<Shape> loop_state_shape_components;
loop_state_shape_components.reserve(init_values.size() + 1);
- loop_state_shape_components.push_back(ShapeUtil::MakeShape(S64, {}));
+ loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
c_transform(init_values, std::back_inserter(loop_state_shape_components),
[](HloInstruction* instr) { return instr->shape(); });
return ShapeUtil::MakeTupleShape(loop_state_shape_components);
}
/*static*/ StatusOr<WhileUtil::LoopStateTy> WhileUtil::MakeCountedLoop(
- HloComputation* computation, int64 trip_count,
+ HloComputation* computation, int32 trip_count,
const WhileUtil::LoopStateTy& init_values,
const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) {
+ CHECK_GE(trip_count, 0);
+
Shape loop_state_shape = MakeLoopStateShape(init_values);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> cond,
// return loop_state;
// }
static StatusOr<LoopStateTy> MakeCountedLoop(
- HloComputation* computation, int64 trip_count,
+ HloComputation* computation, int32 trip_count,
const LoopStateTy& init_values,
const LoopBodyGeneratorTy& loop_body_generator);
};
name = "gather_operation_test",
srcs = ["gather_operation_test.cc"],
deps = [
+ ":client_library_test_base",
":hlo_test_base",
+ "//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
RunTest(hlo_text, operand.get(), gather_indices.get());
}
+// End-to-end gather test driven through the client library (as opposed to
+// the HLO-text tests above).
+class GatherClientLibraryTest : public ClientLibraryTestBase {};
+
+// TODO(b/30671675): Asynchronous execution on stream is not yet supported on
+// GPU and CPU_PARALLEL.
+XLA_TEST_F(GatherClientLibraryTest,
+ DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) {
+ // We create this HLO, but using the ComputationBuilder API.
+ //
+ // ENTRY main {
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // ROOT gather = s32[2,3] gather(operand, indices),
+ // output_window_dims={1},
+ // elided_window_dims={0},
+ // gather_dims_to_operand_dims={0},
+ // index_vector_dim=1,
+ // window_bounds={1, 3}
+ // }
+
+ ComputationBuilder builder(client_, "gather_basic");
+
+ Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
+ Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
+
+ auto operand = builder.Parameter(0, operand_shape, "operand");
+ auto indices = builder.Parameter(1, indices_shape, "indices");
+ GatherDimensionNumbers dim_numbers;
+ dim_numbers.add_output_window_dims(1);
+ dim_numbers.add_elided_window_dims(0);
+ dim_numbers.add_gather_dims_to_operand_dims(0);
+ dim_numbers.set_index_vector_dim(1);
+ builder.Gather(operand, indices, dim_numbers, {1, 3});
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> operand_arg,
+ client_->TransferToServer(*Literal::CreateR2<int32>(
+ {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> indices_arg,
+ client_->TransferToServer(*Literal::CreateR1<int32>({0, 2})));
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
+ client_->GetDeviceHandles(1));
+ xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
+ *execution_options.add_device_handles() = devices[0];
+ TF_ASSERT_OK_AND_ASSIGN(Computation computation, builder.Build());
+ std::vector<xla::Client::ComputationInstance> computation_instances = {
+ {computation,
+ {operand_arg.get(), indices_arg.get()},
+ execution_options,
+ /*execution_profile=*/nullptr}};
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::vector<std::unique_ptr<xla::GlobalData>> result_data,
+ client_->ExecuteParallel(computation_instances));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
+ client_->Transfer(*(result_data[0])));
+ // Gathering rows 0 and 2 of the operand.
+ LiteralTestUtil::ExpectEqual(
+ *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}));
+}
} // namespace
} // namespace xla
auto it = c_find(c, std::forward<Value>(value));
return std::distance(c.begin(), it);
}
+
+// Returns true if `x` fits in 32-bits.
+// NOTE(review): for out-of-range values the narrowing conversion below is
+// implementation-defined before C++20; this relies on the usual
+// two's-complement wraparound behavior — confirm against supported
+// toolchains.
+template <typename T>
+bool IsInt32(T x) {
+ // Following conversion rules: "the value is unchanged if it can be
+ // represented in the destination type (and bit-field width); otherwise, the
+ // value is implementation-defined."
+ return static_cast<int32>(x) == x;
+}
} // namespace xla
#define XLA_LOG_LINES(SEV, STRING) \