return Status::OK();
}
+Status Literal::CopyElementFrom(const Literal& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index) {
+ DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
+ const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ src_literal.shape(), src_index);
+ const int64 dest_linear_index =
+ IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
+ const int64 primitive_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+ char* dest_address =
+ static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
+ const char* source_address =
+ static_cast<const char*>(src_literal.untyped_data()) +
+ src_linear_index * primitive_size;
+ if (dest_address != source_address) {
+ memcpy(dest_address, source_address, primitive_size);
+ }
+ return Status::OK();
+}
+
std::vector<Literal> Literal::DecomposeTuple() {
CHECK(ShapeUtil::IsTuple(shape()));
std::vector<Literal> elements;
DimensionVector result_dimensions;
for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
CHECK_GE(start_indices[dnum], 0);
- CHECK_LE(limit_indices[dnum], shape().dimensions(dnum));
+ CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
+ << "dnum = " << dnum;
int64 dimension = limit_indices[dnum] - start_indices[dnum];
- CHECK_GE(dimension, 0);
+ CHECK_GE(dimension, 0) << "dnum = " << dnum;
result_dimensions.push_back(dimension);
}
const auto result_shape =
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size);
+ // Copies one element from src_literal[src_index] to (*this)[dest_index].
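+  // For example (illustrative): calling this on literal A as
+  // A.CopyElementFrom(B, {1, 2}, {0, 0}) copies B's element at {1, 2} into
+  // A's element at {0, 0}; both literals must share the same element type.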
+ Status CopyElementFrom(const Literal& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index);
+
// Returns a vector containing the tuple elements of this Literal as separate
// Literals. This Literal must be tuple-shaped and can be a nested tuple. The
// elements are moved into the new Literals; no data is copied. Upon return
return Status::OK();
}
+// Returns a ShapeUtil::IndexIterationSpace that iterates over the output
+// gather dimensions while keeping the rest of the output dimensions clamped to
+// 0.
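+//
+// Example (illustrative): for an output shape [2,3] with
+// output_window_dims={1}, dimension 0 is a gather dimension and dimension 1 a
+// window dimension, so the returned space is {base={0,0}, count={2,1},
+// incr={1,1}}.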
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices(
+ const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
+ int64 output_rank = output_shape.dimensions_size();
+ std::vector<int64> index_base(output_rank, 0);
+ std::vector<int64> index_count;
+ index_count.reserve(output_rank);
+ for (int64 i = 0; i < output_rank; i++) {
+ bool is_output_gather_dim =
+ !c_binary_search(dim_numbers.output_window_dims(), i);
+ index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i)
+ : 1);
+ }
+
+ return {std::move(index_base), std::move(index_count),
+ std::vector<int64>(output_rank, 1)};
+}
+
+// Returns a ShapeUtil::IndexIterationSpace that iterates over the output window
+// dimensions while keeping the rest of the output dimensions clamped to 0.
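+//
+// Example (illustrative): for a rank-2 output with output_window_dims={1},
+// elided_window_dims={0} and window_bounds={1,3}, the non-elided bound 3
+// lands on output dimension 1, giving {base={0,0}, count={1,3}, incr={1,1}}.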
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
+ int64 output_rank, ArraySlice<int64> window_bounds,
+ const GatherDimensionNumbers& dim_numbers) {
+ std::vector<int64> index_base(output_rank, 0);
+ std::vector<int64> index_count(output_rank, 1);
+ int64 window_bounds_idx = 0;
+ for (int64 i = 0; i < output_rank; i++) {
+ bool is_output_window_dim =
+ c_binary_search(dim_numbers.output_window_dims(), i);
+ if (is_output_window_dim) {
+ while (c_binary_search(dim_numbers.elided_window_dims(),
+ window_bounds_idx)) {
+ window_bounds_idx++;
+ }
+ index_count[i] = window_bounds[window_bounds_idx++];
+ }
+ }
+
+ return {std::move(index_base), std::move(index_count),
+ std::vector<int64>(output_rank, 1)};
+}
+
+// This functor computes the contribution of gather_indices to an input index
+// corresponding to an output index. That is, given an output index I, it picks
+// out the gather output indices in I and uses them to look up a gather index,
+// G, from the gather indices tensor, and expands G into the input space
+// according to gather_dims_to_operand_dims.
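+//
+// Example (illustrative): with output_window_dims={1}, index_vector_dim=1 and
+// gather_dims_to_operand_dims={0}, the output index {1, 0} selects the index
+// vector gather_indices[1, :], whose single component is scattered into
+// dimension 0 of the input index (all other input dimensions stay 0).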
+class OutputGatherIndexToInputIndex {
+ public:
+ // The constructor does some setup work that is amortized across all
+ // iterations.
+ explicit OutputGatherIndexToInputIndex(
+ const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
+ const Shape& output_shape, const Literal* gather_indices)
+ : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) {
+ for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
+ output_dim_is_gather_dims_.push_back(
+ !c_binary_search(dim_numbers_.output_window_dims(), i));
+ }
+
+ for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
+ int64 index_of_input_dim_in_index_vector =
+ std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(),
+ c_find(dim_numbers_.gather_dims_to_operand_dims(), i));
+ if (index_of_input_dim_in_index_vector ==
+ dim_numbers_.gather_dims_to_operand_dims_size()) {
+ input_dim_value_to_index_vector_.push_back(-1);
+ } else {
+ input_dim_value_to_index_vector_.push_back(
+ index_of_input_dim_in_index_vector);
+ }
+ }
+
+ index_vector_index_.resize(gather_indices_.shape().dimensions_size());
+ input_index_.resize(input_shape.dimensions_size());
+ int64 index_vector_size =
+ gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
+ index_vector_.resize(index_vector_size);
+ }
+
+ // Returns the contribution of gather_indices to the input index corresponding
+ // to output_index. See gather_inner_loop_body.
+ //
+ // This is conceptually a stateless transformation from output_index to the
+ // gather input index, but:
+ //
+ // - Instead of allocating memory to represent the gather input index on
+ // every invocation we reuse the same storage for the result
+ // (input_index_), mutating it in place.
+ // - Instead of allocating buffers for temporary values like
+  //    index_vector_index_ and index_vector_ on every invocation, we reuse the
+ // same storage for all invocations.
+ //
+  // This returns an ArraySlice into memory owned by the class.
+ StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+ PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
+ TF_RETURN_IF_ERROR(FetchIndexVector());
+ PropagateIndexVectorToInputIndex();
+ return ArraySlice<int64>(input_index_);
+ }
+
+ private:
+ // Propagates the gather index dimensions from the output index into
+ // index_vector_index_ by mutating index_vector_index_ in place. Does not
+ // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
+ // we iterate over in FetchIndexVector.
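+  //
+  // For example (illustrative): with index_vector_dim=1 and output gather
+  // dimensions at positions 0 and 2, output_index[0] and output_index[2] land
+  // in index_vector_index_[0] and index_vector_index_[2], skipping slot 1.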
+ void PropagateOutputIndexGatherDimsToIndexVectorIndex(
+ ArraySlice<int64> output_index) {
+ int64 index_vector_index_i = 0;
+ for (int64 i = 0, e = output_index.size(); i < e; i++) {
+ if (!output_dim_is_gather_dims_[i]) {
+ continue;
+ }
+
+ if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
+ index_vector_index_i++;
+ }
+
+ index_vector_index_[index_vector_index_i++] = output_index[i];
+ }
+ }
+
+ // Populates index_vector_ by iterating over gather_indices_ according to
+ // index_vector_index_.
+ Status FetchIndexVector() {
+ int64 index_vector_dim = dim_numbers_.index_vector_dim();
+ for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
+ index_vector_index_[index_vector_dim] = i;
+ TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64(
+ index_vector_index_));
+ }
+ return Status::OK();
+ }
+
+ // Populates input_index_.
+ void PropagateIndexVectorToInputIndex() {
+ for (int64 i = 0, e = input_index_.size(); i < e; i++) {
+ if (input_dim_value_to_index_vector_[i] != -1) {
+ input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
+ }
+
+ // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
+ // remains 0, as set by the constructor.
+ }
+ }
+
+ // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
+ // the input index from the index vector. See
+ // PropagateIndexVectorToInputIndex.
+ std::vector<int64> input_dim_value_to_index_vector_;
+
+  // output_dim_is_gather_dims_[i] is true iff output dimension i is a gather
+  // dimension.
+ std::vector<bool> output_dim_is_gather_dims_;
+
+ // The buffer into which we construct an index into gather_indices_ to fetch
+ // the index vector.
+ std::vector<int64> index_vector_index_;
+
+ // The index vector fetched from gather_indices_.
+ std::vector<int64> index_vector_;
+
+ // The result computed by this functor. operator() returns an ArraySlice into
+ // this vector.
+ std::vector<int64> input_index_;
+
+ const GatherDimensionNumbers& dim_numbers_;
+ const Literal& gather_indices_;
+};
+
+// This functor computes the contribution of the window indices in an output
+// index to an input index. That is, given an output index I, it picks out the
+// output window indices in I and expands them into a window index into the
+// input shape.
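+//
+// Example (illustrative): with output_window_dims={1} and
+// elided_window_dims={0} on a rank-2 input, output dimension 1 maps directly
+// to input dimension 1, while input dimension 0 (elided) always contributes 0.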
+class OutputWindowIndexToInputIndex {
+ public:
+ // The constructor does some setup work that is amortized across all
+ // iterations.
+ explicit OutputWindowIndexToInputIndex(
+ const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
+ const Shape& output_shape) {
+ std::vector<int64> window_index_to_output_index;
+ int64 output_index_count = 0;
+ for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
+ if (c_binary_search(dim_numbers.output_window_dims(), i)) {
+ window_index_to_output_index.push_back(output_index_count++);
+ } else {
+ output_index_count++;
+ }
+ }
+
+ int64 window_dim_count = 0;
+ for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
+ if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ input_dim_value_to_output_index_.push_back(-1);
+ } else {
+ input_dim_value_to_output_index_.push_back(
+ window_index_to_output_index[window_dim_count++]);
+ }
+ }
+
+ input_index_.resize(input_shape.dimensions_size());
+ }
+
+ // Returns the contribution of the window indices to the input index
+ // corresponding to output_index. See gather_inner_loop_body.
+ //
+ // This is conceptually a stateless transformation from output_index to the
+ // window input index, but instead of allocating memory to represent the
+ // gather input index on every invocation we reuse the same storage for the
+ // result (input_index_), mutating it in place.
+ //
+  // This returns an ArraySlice into memory owned by the class.
+ StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+ PropagateOutputIndexWindowDimsToInputIndex(output_index);
+ return ArraySlice<int64>(input_index_);
+ }
+
+ private:
+ // Propagates window dimensions from the output index to input_index_ by
+ // mutating input_index_ in place.
+ void PropagateOutputIndexWindowDimsToInputIndex(
+ ArraySlice<int64> output_index) {
+ for (int64 i = 0, e = input_index_.size(); i < e; i++) {
+ if (input_dim_value_to_output_index_[i] != -1) {
+ input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
+ }
+
+      // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
+      // remains 0, as set by the constructor.
+ }
+ }
+
+  // input_dim_value_to_output_index_[i] tells us how to compute dimension i of
+  // the input index from the output index. See
+  // PropagateOutputIndexWindowDimsToInputIndex.
+ std::vector<int64> input_dim_value_to_output_index_;
+
+ // The result computed by this functor. operator() returns an ArraySlice into
+ // this vector.
+ std::vector<int64> input_index_;
+};
+
+// Reshapes the gather indices input to have a trailing degenerate `1` dimension
+// if necessary. Hands over the ownership of the newly created literal (if
+// there is one) to `reshaped_gather_indices`.
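+//
+// Example (illustrative): with index_vector_dim=1, gather indices of shape
+// s32[2] are reshaped to s32[2,1] so that the trailing degenerate dimension
+// can act as the index vector dimension.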
+static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
+ int64 index_vector_dim, const Literal& gather_indices,
+ std::unique_ptr<Literal>* reshaped_gather_indices) {
+ if (gather_indices.shape().dimensions_size() != index_vector_dim) {
+ return std::cref(gather_indices);
+ }
+
+ std::vector<int64> new_shape(gather_indices.shape().dimensions().begin(),
+ gather_indices.shape().dimensions().end());
+ new_shape.push_back(1);
+ TF_ASSIGN_OR_RETURN(*reshaped_gather_indices,
+ gather_indices.Reshape(new_shape));
+ return std::cref(**reshaped_gather_indices);
+}
+
+Status HloEvaluator::HandleGather(HloInstruction* gather) {
+ std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
+ const Shape& shape = gather->shape();
+ const GatherDimensionNumbers& dim_numbers =
+ gather->gather_dimension_numbers();
+ const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
+ std::unique_ptr<Literal> reshaped_gather_indices;
+ TF_ASSIGN_OR_RETURN(
+ const Literal& gather_indices,
+ ReshapedGatherIndices(dim_numbers.index_vector_dim(),
+ GetEvaluatedLiteralFor(gather->operand(1)),
+ &reshaped_gather_indices));
+
+ // We iterate over the gather dimensions in the output shape in an outer loop
+ // nest, and iterate over the window dimensions in the output shape in an
+ // inner loop nest.
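+  //
+  // For example (illustrative), for a gather producing s32[2,3] with
+  // output_window_dims={1}, the outer loop visits output indices {0,0} and
+  // {1,0} and the inner loop visits offsets {0,0}, {0,1}, {0,2}; the two are
+  // added element-wise to form each output index.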
+
+ ShapeUtil::IndexIterationSpace gather_indices_iteration_space =
+ IterationSpaceForOutputGatherIndices(shape, dim_numbers);
+ ShapeUtil::IndexIterationSpace window_indices_iteration_space =
+ IterationSpaceForOutputWindowIndices(
+ shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers);
+
+ // Scratch buffers that hold an index in the output shape and the
+ // corresponding index in the input shape.
+ std::vector<int64> input_index(operand.shape().dimensions_size());
+ std::vector<int64> output_index(gather->shape().dimensions_size());
+
+ OutputGatherIndexToInputIndex output_gather_index_to_input_index(
+ &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
+ /*output_shape=*/shape, &gather_indices);
+ OutputWindowIndexToInputIndex output_window_index_to_input_index(
+ gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
+ /*output_shape=*/shape);
+
+ auto gather_inner_loop_body =
+ [&](ArraySlice<int64> output_window_index,
+ ArraySlice<int64> input_gather_index,
+ ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
+ TF_ASSIGN_OR_RETURN(
+ ArraySlice<int64> input_window_index,
+ output_window_index_to_input_index(output_window_index));
+ for (int i = 0, e = output_index.size(); i < e; i++) {
+ output_index[i] = output_gather_index[i] + output_window_index[i];
+ }
+ for (int i = 0, e = input_index.size(); i < e; i++) {
+ input_index[i] = input_gather_index[i] + input_window_index[i];
+ }
+ TF_RETURN_IF_ERROR(
+ result->CopyElementFrom(operand, input_index, output_index));
+ return true;
+ };
+
+ auto gather_outer_loop_body =
+ [&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
+ TF_ASSIGN_OR_RETURN(
+ ArraySlice<int64> input_gather_index,
+ output_gather_index_to_input_index(output_gather_index));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ shape, window_indices_iteration_space,
+ std::bind(gather_inner_loop_body, std::placeholders::_1,
+ input_gather_index, output_gather_index)));
+ return true;
+ };
+
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ shape, gather_indices_iteration_space, gather_outer_loop_body));
+ evaluated_[gather] = std::move(result);
+ return Status::OK();
+}
+
Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const auto result_shape = get_tuple_element->shape();
const int64 index = get_tuple_element->tuple_index();
Status HandleTuple(HloInstruction* tuple) override;
+ Status HandleGather(HloInstruction* gather) override;
+
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleCopy(HloInstruction* copy) override;
*result.ValueOrDie());
}
+TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
+ const char* hlo_text = R"(
+HloModule TensorFlowGatherV1
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ ROOT gather = s32[2,3] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1, 3}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+ std::unique_ptr<Literal> operand =
+ Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralTestUtil::ExpectEqual(
+ *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
+ const char* hlo_text = R"(
+HloModule TensorFlowGatherV2
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ ROOT gather = s32[3,2] gather(operand, indices),
+ output_window_dims={0},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3, 1}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+ std::unique_ptr<Literal> operand =
+ Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralTestUtil::ExpectEqual(
+ *Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+ *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
+ const char* hlo_text = R"(
+HloModule TensorFlowGatherMultipleBatchDims
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2,2] parameter(1)
+ ROOT gather = s32[2,3,2] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=2,
+ window_bounds={3, 1}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+ std::unique_ptr<Literal> operand =
+ Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralTestUtil::ExpectEqual(
+ *Literal::CreateR3<int32>(
+ {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
+ *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
+ const char* hlo_text = R"(
+HloModule TensorFlowGatherNd
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[2,2] parameter(1)
+ ROOT gather = s32[2,2] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0,1},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1,2}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+ std::unique_ptr<Literal> operand =
+ Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ std::unique_ptr<Literal> gather_indices =
+ Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralTestUtil::ExpectEqual(
+ *Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest,
+ EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
+ const char* hlo_text = R"(
+HloModule TensorFlowGatherNd
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[2,2] parameter(1)
+ ROOT gather = s32[2,2] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0,1},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=0,
+ window_bounds={1,1,2}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+ std::unique_ptr<Literal> operand =
+ Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ std::unique_ptr<Literal> gather_indices =
+ Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralTestUtil::ExpectEqual(
+ *Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
+ const char* hlo_text = R"(
+HloModule DynamicSlice
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ ROOT gather = s32[1,1] gather(operand, indices),
+ output_window_dims={0,1},
+ elided_window_dims={},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=0,
+ window_bounds={1,1}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+ std::unique_ptr<Literal> operand =
+ Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralTestUtil::ExpectEqual(
+ *Literal::CreateR2<int32>({{5}}),
+ *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
+ const char* hlo_text = R"(
+HloModule BatchDynamicSlice
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2,2] parameter(1)
+ ROOT gather = s32[2,1,1] gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=0,
+ window_bounds={1,1}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+ std::unique_ptr<Literal> operand =
+ Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralTestUtil::ExpectEqual(
+ *Literal::CreateR3<int32>({{{8}}, {{5}}}),
+ *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
+ const char* hlo_text = R"(
+HloModule TensorFlowGatherV1
+
+ENTRY main {
+ operand = s32[3,0] parameter(0)
+ indices = s32[2] parameter(1)
+ ROOT gather = s32[2,0] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1, 0}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+ std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralTestUtil::ExpectEqual(
+ *Literal::CreateR2<int32>({{}, {}}),
+ *Evaluate({operand.get(), gather_indices.get()}));
+}
+
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
::testing::ValuesIn(use_bf16_params));
return Status::OK();
}
+  // Bundles the base, count and increment vectors describing an index
+  // iteration space; used by the simple ergonomic ForEachIndexWithStatus
+  // wrapper below.
+ struct IndexIterationSpace {
+ std::vector<int64> index_base;
+ std::vector<int64> index_count;
+ std::vector<int64> index_incr;
+ };
+
+ template <typename FnTy>
+ static Status ForEachIndexWithStatus(
+ const Shape& shape, const IndexIterationSpace& iteration_space,
+ FnTy&& function) {
+ return ShapeUtil::ForEachIndexWithStatus(
+ shape, iteration_space.index_base, iteration_space.index_count,
+ iteration_space.index_incr, std::forward<FnTy>(function));
+ }
+
template <typename FnType>
static void ForEachIndex(const Shape& shape,
tensorflow::gtl::ArraySlice<int64> base,
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
<< "TearDown called more than once; it should be called exactly once.";
tear_down_called_ = true;
if (module_) {
- HloVerifier verifier;
- xla::StatusOr<bool> mutated = verifier.Run(module_.get());
- if (!mutated.ok()) {
- ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
- } else {
- EXPECT_FALSE(mutated.ValueOrDie())
- << "HloVerifier should never mutate the HloModule";
- }
+ VerifyModule();
}
HloTestBase::TearDown();
}
+void HloVerifiedTestBase::VerifyModule() {
+ HloVerifier verifier;
+ xla::StatusOr<bool> mutated = verifier.Run(module_.get());
+ if (!mutated.ok()) {
+ ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
+ } else {
+ EXPECT_FALSE(mutated.ValueOrDie())
+ << "HloVerifier should never mutate the HloModule";
+ }
+}
+
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
module_ = CreateNewModule();
return *module_;
}
+void HloVerifiedTestBase::ParseAndVerifyModule(const char* hlo_text) {
+  CHECK(!module_)
+      << "Called ParseAndVerifyModule when test already has a module.";
+ TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text));
+ VerifyModule();
+}
} // namespace xla
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
HloModule& module();
+ void ParseAndVerifyModule(const char* hlo_text);
// Sets the shape-size function used during hlo verification. If this isn't
// called, a default ShapeVerifier is used instead.
std::unique_ptr<HloModule> module_; // Lazily populated. Access via module().
std::unique_ptr<ShapeVerifier> shape_verifier_;
bool tear_down_called_ = false;
+ void VerifyModule();
};
} // namespace xla
auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) {
return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
}
+
+template <typename C, typename Value>
+auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) {
+ return std::find(std::begin(c), std::end(c), std::forward<Value>(value));
+}
} // namespace xla
#define XLA_LOG_LINES(SEV, STRING) \