Add HLO evaluator support for Gather
author Sanjoy Das <sanjoy@google.com>
Tue, 6 Mar 2018 20:15:47 +0000 (12:15 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 20:19:43 +0000 (12:19 -0800)
This isn't optimal -- it copies element by element -- but I figured, at least
for bringup, it will be helpful to have the HLO evaluator follow the spec
closely.

PiperOrigin-RevId: 188061274

tensorflow/compiler/xla/literal_util.cc
tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/hlo_evaluator.h
tensorflow/compiler/xla/service/hlo_evaluator_test.cc
tensorflow/compiler/xla/shape_util.h
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
tensorflow/compiler/xla/tests/hlo_verified_test_base.h
tensorflow/compiler/xla/util.h

index 1d1418f..d247aeb 100644 (file)
@@ -248,6 +248,28 @@ Status Literal::CopySliceFromInternal(
   return Status::OK();
 }
 
+Status Literal::CopyElementFrom(const Literal& src_literal,
+                                tensorflow::gtl::ArraySlice<int64> src_index,
+                                tensorflow::gtl::ArraySlice<int64> dest_index) {
+  DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
+  const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+      src_literal.shape(), src_index);
+  const int64 dest_linear_index =
+      IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
+  const int64 primitive_size =
+      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+  char* dest_address =
+      static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
+  const char* source_address =
+      static_cast<const char*>(src_literal.untyped_data()) +
+      src_linear_index * primitive_size;
+  if (dest_address != source_address) {
+    memcpy(dest_address, source_address, primitive_size);
+  }
+  return Status::OK();
+}
+
 std::vector<Literal> Literal::DecomposeTuple() {
   CHECK(ShapeUtil::IsTuple(shape()));
   std::vector<Literal> elements;
@@ -811,9 +833,10 @@ std::unique_ptr<Literal> Literal::Slice(
   DimensionVector result_dimensions;
   for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
     CHECK_GE(start_indices[dnum], 0);
-    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum));
+    CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
+        << "dnum = " << dnum;
     int64 dimension = limit_indices[dnum] - start_indices[dnum];
-    CHECK_GE(dimension, 0);
+    CHECK_GE(dimension, 0) << "dnum = " << dnum;
     result_dimensions.push_back(dimension);
   }
   const auto result_shape =
index cdc5d80..d525487 100644 (file)
@@ -262,6 +262,11 @@ class Literal {
                        tensorflow::gtl::ArraySlice<int64> dest_base,
                        tensorflow::gtl::ArraySlice<int64> copy_size);
 
+  // Copies one element from src_literal[src_index] to (*this)[dest_index].
+  Status CopyElementFrom(const Literal& src_literal,
+                         tensorflow::gtl::ArraySlice<int64> src_index,
+                         tensorflow::gtl::ArraySlice<int64> dest_index);
+
   // Returns a vector containing the tuple elements of this Literal as separate
   // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
   // elements are moved into the new Literals; no data is copied. Upon return
index 534433b..a839f80 100644 (file)
@@ -2466,6 +2466,340 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
   return Status::OK();
 }
 
+// Returns a ShapeUtil::IndexIterationSpace that iterates over the output
+// gather dimensions while keeping the rest of the output dimensions clamped to
+// 0.
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices(
+    const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
+  int64 output_rank = output_shape.dimensions_size();
+  std::vector<int64> index_base(output_rank, 0);
+  std::vector<int64> index_count;
+  index_count.reserve(output_rank);
+  for (int64 i = 0; i < output_rank; i++) {
+    bool is_output_gather_dim =
+        !c_binary_search(dim_numbers.output_window_dims(), i);
+    index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i)
+                                               : 1);
+  }
+
+  return {std::move(index_base), std::move(index_count),
+          std::vector<int64>(output_rank, 1)};
+}
+
+// Returns a ShapeUtil::IndexIterationSpace that iterates over the output window
+// dimensions while keeping the rest of the output dimensions clamped to 0.
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
+    int64 output_rank, ArraySlice<int64> window_bounds,
+    const GatherDimensionNumbers& dim_numbers) {
+  std::vector<int64> index_base(output_rank, 0);
+  std::vector<int64> index_count(output_rank, 1);
+  int64 window_bounds_idx = 0;
+  for (int64 i = 0; i < output_rank; i++) {
+    bool is_output_window_dim =
+        c_binary_search(dim_numbers.output_window_dims(), i);
+    if (is_output_window_dim) {
+      while (c_binary_search(dim_numbers.elided_window_dims(),
+                             window_bounds_idx)) {
+        window_bounds_idx++;
+      }
+      index_count[i] = window_bounds[window_bounds_idx++];
+    }
+  }
+
+  return {std::move(index_base), std::move(index_count),
+          std::vector<int64>(output_rank, 1)};
+}
+
+// This functor computes the contribution of gather_indices to an input index
+// corresponding to an output index.  That is, given an output index I, it picks
+// out the gather output indices in I and uses them to look up a gather index,
+// G, from the gather indices tensor, and expands G into the input space
+// according to gather_dims_to_operand_dims.
+class OutputGatherIndexToInputIndex {
+ public:
+  // The constructor does some setup work that is amortized across all
+  // iterations.
+  explicit OutputGatherIndexToInputIndex(
+      const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
+      const Shape& output_shape, const Literal* gather_indices)
+      : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) {
+    for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
+      output_dim_is_gather_dims_.push_back(
+          !c_binary_search(dim_numbers_.output_window_dims(), i));
+    }
+
+    for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
+      int64 index_of_input_dim_in_index_vector =
+          std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(),
+                        c_find(dim_numbers_.gather_dims_to_operand_dims(), i));
+      if (index_of_input_dim_in_index_vector ==
+          dim_numbers_.gather_dims_to_operand_dims_size()) {
+        input_dim_value_to_index_vector_.push_back(-1);
+      } else {
+        input_dim_value_to_index_vector_.push_back(
+            index_of_input_dim_in_index_vector);
+      }
+    }
+
+    index_vector_index_.resize(gather_indices_.shape().dimensions_size());
+    input_index_.resize(input_shape.dimensions_size());
+    int64 index_vector_size =
+        gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
+    index_vector_.resize(index_vector_size);
+  }
+
+  // Returns the contribution of gather_indices to the input index corresponding
+  // to output_index.  See gather_inner_loop_body.
+  //
+  // This is conceptually a stateless transformation from output_index to the
+  // gather input index, but:
+  //
+  //  - Instead of allocating memory to represent the gather input index on
+  //    every invocation we reuse the same storage for the result
+  //    (input_index_), mutating it in place.
+  //  - Instead of allocating buffers for temporary values like
+  //    index_vector_index_ and index_vector_ on every invocation, we reuse the
+  //    same storage for all invocations.
+  //
+  // This returns an arrayslice into memory owned by the class.
+  StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+    PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
+    TF_RETURN_IF_ERROR(FetchIndexVector());
+    PropagateIndexVectorToInputIndex();
+    return ArraySlice<int64>(input_index_);
+  }
+
+ private:
+  // Propagates the gather index dimensions from the output index into
+  // index_vector_index_ by mutating index_vector_index_ in place.  Does not
+  // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
+  // we iterate over in FetchIndexVector.
+  void PropagateOutputIndexGatherDimsToIndexVectorIndex(
+      ArraySlice<int64> output_index) {
+    int64 index_vector_index_i = 0;
+    for (int64 i = 0, e = output_index.size(); i < e; i++) {
+      if (!output_dim_is_gather_dims_[i]) {
+        continue;
+      }
+
+      if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
+        index_vector_index_i++;
+      }
+
+      index_vector_index_[index_vector_index_i++] = output_index[i];
+    }
+  }
+
+  // Populates index_vector_ by iterating over gather_indices_ according to
+  // index_vector_index_.
+  Status FetchIndexVector() {
+    int64 index_vector_dim = dim_numbers_.index_vector_dim();
+    for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
+      index_vector_index_[index_vector_dim] = i;
+      TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64(
+                                                index_vector_index_));
+    }
+    return Status::OK();
+  }
+
+  // Populates input_index_.
+  void PropagateIndexVectorToInputIndex() {
+    for (int64 i = 0, e = input_index_.size(); i < e; i++) {
+      if (input_dim_value_to_index_vector_[i] != -1) {
+        input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
+      }
+
+      // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
+      // remains 0, as set by the constructor.
+    }
+  }
+
+  // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
+  // the input index from the index vector.  See
+  // PropagateIndexVectorToInputIndex.
+  std::vector<int64> input_dim_value_to_index_vector_;
+
+  // output_dim_is_gather_dims_[i] is true iff the output index i is a gather
+  // dimension.
+  std::vector<bool> output_dim_is_gather_dims_;
+
+  // The buffer into which we construct an index into gather_indices_ to fetch
+  // the index vector.
+  std::vector<int64> index_vector_index_;
+
+  // The index vector fetched from gather_indices_.
+  std::vector<int64> index_vector_;
+
+  // The result computed by this functor.  operator() returns an ArraySlice into
+  // this vector.
+  std::vector<int64> input_index_;
+
+  const GatherDimensionNumbers& dim_numbers_;
+  const Literal& gather_indices_;
+};
+
+// This functor computes the contribution of the window indices in an output
+// index to an input index.  That is, given an output index I it picks out the
+// output window indices in I and expands it into a window index into the input
+// shape.
+class OutputWindowIndexToInputIndex {
+ public:
+  // The constructor does some setup work that is amortized across all
+  // iterations.
+  explicit OutputWindowIndexToInputIndex(
+      const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
+      const Shape& output_shape) {
+    std::vector<int64> window_index_to_output_index;
+    int64 output_index_count = 0;
+    for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
+      if (c_binary_search(dim_numbers.output_window_dims(), i)) {
+        window_index_to_output_index.push_back(output_index_count++);
+      } else {
+        output_index_count++;
+      }
+    }
+
+    int64 window_dim_count = 0;
+    for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
+      if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+        input_dim_value_to_output_index_.push_back(-1);
+      } else {
+        input_dim_value_to_output_index_.push_back(
+            window_index_to_output_index[window_dim_count++]);
+      }
+    }
+
+    input_index_.resize(input_shape.dimensions_size());
+  }
+
+  // Returns the contribution of the window indices to the input index
+  // corresponding to output_index.  See gather_inner_loop_body.
+  //
+  // This is conceptually a stateless transformation from output_index to the
+  // window input index, but instead of allocating memory to represent the
+  // window input index on every invocation we reuse the same storage for the
+  // result (input_index_), mutating it in place.
+  //
+  // This returns an arrayslice into memory owned by the class.
+  StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+    PropagateOutputIndexWindowDimsToInputIndex(output_index);
+    return ArraySlice<int64>(input_index_);
+  }
+
+ private:
+  // Propagates window dimensions from the output index to input_index_ by
+  // mutating input_index_ in place.
+  void PropagateOutputIndexWindowDimsToInputIndex(
+      ArraySlice<int64> output_index) {
+    for (int64 i = 0, e = input_index_.size(); i < e; i++) {
+      if (input_dim_value_to_output_index_[i] != -1) {
+        input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
+      }
+
+      // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
+      // remains 0, as set by the constructor.
+    }
+  }
+
+  // input_dim_value_to_output_index_[i] tells us how to compute dimension i of
+  // the input index from the output index. See
+  // PropagateOutputIndexWindowDimsToInputIndex.
+  std::vector<int64> input_dim_value_to_output_index_;
+
+  // The result computed by this functor.  operator() returns an ArraySlice into
+  // this vector.
+  std::vector<int64> input_index_;
+};
+
+// Reshapes the gather indices input to have a trailing degenerate `1` dimension
+// if necessary.  Hands over the ownership of the newly created literal (if
+// there is one) to `reshaped_gather_indices`.
+static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
+    int64 index_vector_dim, const Literal& gather_indices,
+    std::unique_ptr<Literal>* reshaped_gather_indices) {
+  if (gather_indices.shape().dimensions_size() != index_vector_dim) {
+    return std::cref(gather_indices);
+  }
+
+  std::vector<int64> new_shape(gather_indices.shape().dimensions().begin(),
+                               gather_indices.shape().dimensions().end());
+  new_shape.push_back(1);
+  TF_ASSIGN_OR_RETURN(*reshaped_gather_indices,
+                      gather_indices.Reshape(new_shape));
+  return std::cref(**reshaped_gather_indices);
+}
+
+Status HloEvaluator::HandleGather(HloInstruction* gather) {
+  std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
+  const Shape& shape = gather->shape();
+  const GatherDimensionNumbers& dim_numbers =
+      gather->gather_dimension_numbers();
+  const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
+  std::unique_ptr<Literal> reshaped_gather_indices;
+  TF_ASSIGN_OR_RETURN(
+      const Literal& gather_indices,
+      ReshapedGatherIndices(dim_numbers.index_vector_dim(),
+                            GetEvaluatedLiteralFor(gather->operand(1)),
+                            &reshaped_gather_indices));
+
+  // We iterate over the gather dimensions in the output shape in an outer loop
+  // nest, and iterate over the window dimensions in the output shape in an
+  // inner loop nest.
+
+  ShapeUtil::IndexIterationSpace gather_indices_iteration_space =
+      IterationSpaceForOutputGatherIndices(shape, dim_numbers);
+  ShapeUtil::IndexIterationSpace window_indices_iteration_space =
+      IterationSpaceForOutputWindowIndices(
+          shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers);
+
+  // Scratch buffers that hold an index in the output shape and the
+  // corresponding index in the input shape.
+  std::vector<int64> input_index(operand.shape().dimensions_size());
+  std::vector<int64> output_index(gather->shape().dimensions_size());
+
+  OutputGatherIndexToInputIndex output_gather_index_to_input_index(
+      &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
+      /*output_shape=*/shape, &gather_indices);
+  OutputWindowIndexToInputIndex output_window_index_to_input_index(
+      gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
+      /*output_shape=*/shape);
+
+  auto gather_inner_loop_body =
+      [&](ArraySlice<int64> output_window_index,
+          ArraySlice<int64> input_gather_index,
+          ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
+    TF_ASSIGN_OR_RETURN(
+        ArraySlice<int64> input_window_index,
+        output_window_index_to_input_index(output_window_index));
+    for (int i = 0, e = output_index.size(); i < e; i++) {
+      output_index[i] = output_gather_index[i] + output_window_index[i];
+    }
+    for (int i = 0, e = input_index.size(); i < e; i++) {
+      input_index[i] = input_gather_index[i] + input_window_index[i];
+    }
+    TF_RETURN_IF_ERROR(
+        result->CopyElementFrom(operand, input_index, output_index));
+    return true;
+  };
+
+  auto gather_outer_loop_body =
+      [&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
+    TF_ASSIGN_OR_RETURN(
+        ArraySlice<int64> input_gather_index,
+        output_gather_index_to_input_index(output_gather_index));
+    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+        shape, window_indices_iteration_space,
+        std::bind(gather_inner_loop_body, std::placeholders::_1,
+                  input_gather_index, output_gather_index)));
+    return true;
+  };
+
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+      shape, gather_indices_iteration_space, gather_outer_loop_body));
+  evaluated_[gather] = std::move(result);
+  return Status::OK();
+}
+
 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
   const auto result_shape = get_tuple_element->shape();
   const int64 index = get_tuple_element->tuple_index();
index 8a27cf9..410e5ce 100644 (file)
@@ -152,6 +152,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
 
   Status HandleTuple(HloInstruction* tuple) override;
 
+  Status HandleGather(HloInstruction* gather) override;
+
   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
 
   Status HandleCopy(HloInstruction* copy) override;
index 97765d6..685cacd 100644 (file)
@@ -1729,6 +1729,207 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
                                *result.ValueOrDie());
 }
 
+TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
+  const char* hlo_text = R"(
+HloModule TensorFlowGatherV1
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  ROOT gather = s32[2,3] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1, 3}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+  LiteralTestUtil::ExpectEqual(
+      *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+      *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
+  const char* hlo_text = R"(
+HloModule TensorFlowGatherV2
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  ROOT gather = s32[3,2] gather(operand, indices),
+      output_window_dims={0},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=1,
+      window_bounds={3, 1}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+  LiteralTestUtil::ExpectEqual(
+      *Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+      *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
+  const char* hlo_text = R"(
+HloModule TensorFlowGatherMultipleBatchDims
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2,2] parameter(1)
+  ROOT gather = s32[2,3,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=2,
+      window_bounds={3, 1}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+  LiteralTestUtil::ExpectEqual(
+      *Literal::CreateR3<int32>(
+          {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
+      *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
+  const char* hlo_text = R"(
+HloModule TensorFlowGatherNd
+
+ENTRY main {
+  operand = s32[3,3,2] parameter(0)
+  indices = s32[2,2] parameter(1)
+  ROOT gather = s32[2,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0,1},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=1,
+      window_bounds={1,1,2}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
+                                {{-4, 4}, {-5, 5}, {-6, 6}},  //
+                                {{-7, 7}, {-8, 8}, {-9, 9}}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+  LiteralTestUtil::ExpectEqual(
+      *Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+      *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest,
+       EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
+  const char* hlo_text = R"(
+HloModule TensorFlowGatherNd
+
+ENTRY main {
+  operand = s32[3,3,2] parameter(0)
+  indices = s32[2,2] parameter(1)
+  ROOT gather = s32[2,2] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0,1},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1,2}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
+                                {{-4, 4}, {-5, 5}, {-6, 6}},  //
+                                {{-7, 7}, {-8, 8}, {-9, 9}}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+  LiteralTestUtil::ExpectEqual(
+      *Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+      *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
+  const char* hlo_text = R"(
+HloModule DynamicSlice
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  ROOT gather = s32[1,1] gather(operand, indices),
+      output_window_dims={0,1},
+      elided_window_dims={},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+  LiteralTestUtil::ExpectEqual(
+      *Literal::CreateR2<int32>({{5}}),
+      *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
+  const char* hlo_text = R"(
+HloModule BatchDynamicSlice
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2,2] parameter(1)
+  ROOT gather = s32[2,1,1] gather(operand, indices),
+      output_window_dims={1,2},
+      elided_window_dims={},
+      gather_dims_to_operand_dims={0,1},
+      index_vector_dim=0,
+      window_bounds={1,1}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+  std::unique_ptr<Literal> operand =
+      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+  std::unique_ptr<Literal> gather_indices =
+      Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+  LiteralTestUtil::ExpectEqual(
+      *Literal::CreateR3<int32>({{{8}}, {{5}}}),
+      *Evaluate({operand.get(), gather_indices.get()}));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
+  const char* hlo_text = R"(
+HloModule TensorFlowGatherV1
+
+ENTRY main {
+  operand = s32[3,0] parameter(0)
+  indices = s32[2] parameter(1)
+  ROOT gather = s32[2,0] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1, 0}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+  std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
+  std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+  LiteralTestUtil::ExpectEqual(
+      *Literal::CreateR2<int32>({{}, {}}),
+      *Evaluate({operand.get(), gather_indices.get()}));
+}
+
 INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
                         ::testing::ValuesIn(use_bf16_params));
 
index fb66f69..92b365e 100644 (file)
@@ -612,6 +612,22 @@ class ShapeUtil {
     return Status::OK();
   }
 
+  // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
+  struct IndexIterationSpace {
+    std::vector<int64> index_base;
+    std::vector<int64> index_count;
+    std::vector<int64> index_incr;
+  };
+
+  template <typename FnTy>
+  static Status ForEachIndexWithStatus(
+      const Shape& shape, const IndexIterationSpace& iteration_space,
+      FnTy&& function) {
+    return ShapeUtil::ForEachIndexWithStatus(
+        shape, iteration_space.index_base, iteration_space.index_count,
+        iteration_space.index_incr, std::forward<FnTy>(function));
+  }
+
   template <typename FnType>
   static void ForEachIndex(const Shape& shape,
                            tensorflow::gtl::ArraySlice<int64> base,
index 1b2008a..5fb38d6 100644 (file)
@@ -139,6 +139,7 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_verifier",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
     ],
index 506091d..641907a 100644 (file)
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -40,18 +41,22 @@ void HloVerifiedTestBase::TearDown() {
       << "TearDown called more than once; it should be called exactly once.";
   tear_down_called_ = true;
   if (module_) {
-    HloVerifier verifier;
-    xla::StatusOr<bool> mutated = verifier.Run(module_.get());
-    if (!mutated.ok()) {
-      ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
-    } else {
-      EXPECT_FALSE(mutated.ValueOrDie())
-          << "HloVerifier should never mutate the HloModule";
-    }
+    VerifyModule();
   }
   HloTestBase::TearDown();
 }
 
+void HloVerifiedTestBase::VerifyModule() {
+  HloVerifier verifier;
+  xla::StatusOr<bool> mutated = verifier.Run(module_.get());
+  if (!mutated.ok()) {
+    ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
+  } else {
+    EXPECT_FALSE(mutated.ValueOrDie())
+        << "HloVerifier should never mutate the HloModule";
+  }
+}
+
 HloModule& HloVerifiedTestBase::module() {
   if (!module_) {
     module_ = CreateNewModule();
@@ -59,4 +64,9 @@ HloModule& HloVerifiedTestBase::module() {
   return *module_;
 }
 
+void HloVerifiedTestBase::ParseAndVerifyModule(const char* hlo_text) {
+  CHECK(!module_) << "ParseAndVerifyModule: test already has a module.";
+  TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text));
+  VerifyModule();  // Reports verifier errors as test failures.
+}
 }  // namespace xla
index 492688b..c0cb12b 100644 (file)
@@ -44,6 +44,7 @@ class HloVerifiedTestBase : public HloTestBase {
   // Returns the default HloModule, lazily creating it if necessary via
   // HloTestBase::CreateNewModule().
   HloModule& module();
+  void ParseAndVerifyModule(const char* hlo_text);
 
   // Sets the shape-size function used during hlo verification. If this isn't
   // called, a default ShapeVerifier is used instead.
@@ -55,6 +56,7 @@ class HloVerifiedTestBase : public HloTestBase {
   std::unique_ptr<HloModule> module_;  // Lazily populated. Access via module().
   std::unique_ptr<ShapeVerifier> shape_verifier_;
   bool tear_down_called_ = false;
+  void VerifyModule();
 };
 
 }  // namespace xla
index 82e5a59..98467cd 100644 (file)
@@ -494,6 +494,11 @@ template <typename C, typename Pred>
 auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) {
   return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
 }
+
+template <typename C, typename Value>
+auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) {
+  return std::find(std::begin(c), std::end(c), std::forward<Value>(value));
+}
 }  // namespace xla
 
 #define XLA_LOG_LINES(SEV, STRING) \