Add support for unary and binary ops to indexed tensor analysis
author    Sanjoy Das <sanjoy@google.com>
          Sat, 26 May 2018 02:21:57 +0000 (19:21 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Sat, 26 May 2018 02:24:34 +0000 (19:24 -0700)
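
The new folds are, schematically:

  UnaryOp(ScalarIndexed(Const, Indices))
    => ScalarIndexed(UnaryOp(Const), Indices)

  BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
    => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
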
I've added a TODO to clean up the uses of ValueOrDie, which I will address in
an immediately following CL.

PiperOrigin-RevId: 198134579

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/hlo_evaluator.h
tensorflow/compiler/xla/service/indexed_array_analysis.cc
tensorflow/compiler/xla/service/indexed_array_analysis.h
tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
tensorflow/compiler/xla/util.h

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2976bdb..5472f9a 100644
@@ -2927,6 +2927,7 @@ cc_library(
     hdrs = ["indexed_array_analysis.h"],
     deps = [
         ":hlo",
+        ":hlo_evaluator",
         ":hlo_pass",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 2a8de02..e90eb06 100644
@@ -309,6 +309,35 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
   return result;
 }
 
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
+    HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
+  std::unique_ptr<HloInstruction> lhs_instr =
+      HloInstruction::CreateConstant(lhs.CloneToUnique());
+  std::unique_ptr<HloInstruction> rhs_instr =
+      HloInstruction::CreateConstant(rhs.CloneToUnique());
+
+  std::unique_ptr<HloInstruction> cloned_instruction =
+      HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
+                                   rhs_instr.get());
+  auto result = Evaluate(cloned_instruction.get());
+
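+  // The cloned instruction is not owned by a module, so detach it from its
+  // constant operands (which are owned by the locals above) before they are
+  // destroyed.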
+  cloned_instruction->DetachFromOperands();
+  return result;
+}
+
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
+    HloOpcode opcode, const Literal& operand) {
+  std::unique_ptr<HloInstruction> operand_instr =
+      HloInstruction::CreateConstant(operand.CloneToUnique());
+
+  std::unique_ptr<HloInstruction> cloned_instruction =
+      HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
+  auto result = Evaluate(cloned_instruction.get());
+
+  cloned_instruction->DetachFromOperands();
+  return result;
+}
+
 Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
   CHECK_LT(parameter->parameter_number(), arg_literals_.size());
   const Literal* input_literal = arg_literals_[parameter->parameter_number()];
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 2b72ff1..b53d564 100644
@@ -109,6 +109,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
       const std::unordered_map<const HloInstruction*, const Literal*>&
           substitutions);
 
+  // Evaluates an elementwise binary `opcode` applied to the literals `lhs`
+  // and `rhs`.
+  StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
+      HloOpcode opcode, const Literal& lhs, const Literal& rhs);
+
+  // Evaluates an elementwise unary `opcode` applied to the literal `operand`.
+  StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
+      HloOpcode opcode, const Literal& operand);
+
  protected:
   // Make HloEvaluatorTypedVisitor a friend because it is logically part of this
   // class.
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index b74f05e..5d870f9 100644
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
 #include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -32,7 +33,9 @@ using tensorflow::gtl::ArraySlice;
 using tensorflow::str_util::Join;
 }  // namespace
 
-string IndexedArrayAnalysis::ToString(Array* root) {
+// TODO(sanjoy): Make this pass StatusOr safe.
+
+string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
   switch (root->kind()) {
     case Array::kUnknown: {
       auto* unknown_tensor = root->as<UnknownArray>();
@@ -41,6 +44,12 @@ string IndexedArrayAnalysis::ToString(Array* root) {
     }
 
     case Array::kConstant: {
+      if (print_constants) {
+        string contents = root->as<ConstantArray>()->literal()->ToString();
+        return tensorflow::strings::StrCat(
+            "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents,
+            ")");
+      }
       return tensorflow::strings::StrCat(
           "(constant ", ShapeUtil::HumanString(root->shape()), ")");
     }
@@ -52,9 +61,10 @@ string IndexedArrayAnalysis::ToString(Array* root) {
                         ? "scalar-indexed-const"
                         : "scalar-indexed";
       return tensorflow::strings::StrCat(
-          "(", name, " ", ToString(indexed_array->source()), " ",
-          ToString(indexed_array->indices()), " ", indexed_array->source_dim(),
-          "->[", Join(indexed_array->output_dims(), ","), "])");
+          "(", name, " ", ToString(indexed_array->source(), print_constants),
+          " ", ToString(indexed_array->indices(), print_constants), " ",
+          indexed_array->source_dim(), "->[",
+          Join(indexed_array->output_dims(), ","), "])");
     }
   }
 }
@@ -115,23 +125,25 @@ void IndexedArrayAnalysis::TraverseAndPopulateCache(
 Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
     const HloInstruction* instr) {
   Array* computed_array;
-  switch (instr->opcode()) {
-    default:
-      computed_array = nullptr;
-      break;
-    case HloOpcode::kConstant:
-      computed_array = ComputeArrayForConstant(instr->literal());
-      break;
-    case HloOpcode::kGather:
-      computed_array = ComputeArrayForGather(
-          instr->shape(), instr->gather_dimension_numbers(),
-          instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)),
-          FindOrDie(cache_, instr->operand(1)));
-      break;
-    case HloOpcode::kReshape:
-      computed_array = ComputeArrayForReshape(
-          instr->shape(), FindOrDie(cache_, instr->operand(0)));
-      break;
+  if (instr->IsElementwise() && instr->operand_count() == 1) {
+    computed_array = ComputeArrayForElementwiseUnaryOp(
+        instr, FindOrDie(cache_, instr->operand(0)));
+  } else if (instr->IsElementwise() && instr->operand_count() == 2) {
+    computed_array = ComputeArrayForElementwiseBinaryOp(
+        instr, FindOrDie(cache_, instr->operand(0)),
+        FindOrDie(cache_, instr->operand(1)));
+  } else if (instr->opcode() == HloOpcode::kConstant) {
+    computed_array = ComputeArrayForConstant(instr->literal());
+  } else if (instr->opcode() == HloOpcode::kGather) {
+    computed_array = ComputeArrayForGather(
+        instr->shape(), instr->gather_dimension_numbers(),
+        instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)),
+        FindOrDie(cache_, instr->operand(1)));
+  } else if (instr->opcode() == HloOpcode::kReshape) {
+    computed_array = ComputeArrayForReshape(
+        instr->shape(), FindOrDie(cache_, instr->operand(0)));
+  } else {
+    computed_array = nullptr;
   }
 
   if (!computed_array) {
@@ -166,14 +178,14 @@ ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
                                               IndexComponent::Ungathered);
 
   // Simulate the first gather.
-  simulated_index.erase(simulated_index.begin() + source->source_dim());
+  EraseAt(&simulated_index, source->source_dim());
   for (int64 gather_dim : source->output_dims()) {
     simulated_index.insert(simulated_index.begin() + gather_dim,
                            IndexComponent::GatheredFirst);
   }
 
   // Simulate the second gather.
-  simulated_index.erase(simulated_index.begin() + source_dim);
+  EraseAt(&simulated_index, source_dim);
   for (int64 output_dim : output_dims) {
     simulated_index.insert(simulated_index.begin() + output_dim,
                            IndexComponent::GatheredSecond);
@@ -463,8 +475,7 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape(
     int64 output_dim = scalar_indexed->output_dims()[i];
     int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim(
         reshape_passthrough_dims, output_dim);
-    new_scalar_indexed_source_shape.erase(
-        new_scalar_indexed_source_shape.begin() + output_dim_after_reshape);
+    EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape);
   }
 
   // After this, we need to add in the dimension that will be the source
@@ -541,6 +552,154 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape(
       output_dims_for_new_scalar_indexed_node, shape);
 }
 
+Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
+    const HloInstruction* instr, Array* lhs, Array* rhs) {
+  // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
+  //          => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
+  //
+  // We can do this if every output dimension from the scalar-indexed node is a
+  // broadcasted dimension for the broadcast node.  Informally, the precondition
+  // means Broadcast(Const0)[IDX] is solely a function of the components of IDX
+  // that are not output-dims for the scalar-indexed node. In other words, for
+  // every assignment to the non-output dims in IDX we have a "constant" LHS to
+  // the BinaryOp.  This transform propagates this "constant" to the source for
+  // the scalar-indexed node.
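+  //
+  // For example (cf. the AddBroadcastedVectorWithGather test):
+  //
+  //   Add(Broadcast(Const0 : s32[4], dims={1}) : s32[5,4],
+  //       ScalarIndexed(Const1 : s32[3,4], Indices : s32[5]))
+  //   => ScalarIndexed(Add(Broadcast'(Const0) : s32[3,4], Const1), Indices)
+  //
+  // where Broadcast' broadcasts Const0 along dimension 1 of Const1's shape.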
+
+  ScalarIndexedConstantArray* lhs_scalar_indexed_const =
+      dynamic_cast<ScalarIndexedConstantArray*>(lhs);
+  ScalarIndexedConstantArray* rhs_scalar_indexed_const =
+      dynamic_cast<ScalarIndexedConstantArray*>(rhs);
+
+  bool lhs_is_indexed;
+
+  // One of the operands must be scalar-indexed and the other must be a
+  // broadcast of a constant.
+  if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) {
+    lhs_is_indexed = true;
+  } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) {
+    lhs_is_indexed = false;
+  } else {
+    return nullptr;
+  }
+
+  ScalarIndexedConstantArray* scalar_indexed_const =
+      lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const;
+  UnknownArray* candidate_broadcast_array =
+      dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs);
+  if (!candidate_broadcast_array ||
+      candidate_broadcast_array->instruction().opcode() !=
+          HloOpcode::kBroadcast) {
+    return nullptr;
+  }
+
+  const HloInstruction* broadcast_instr =
+      &candidate_broadcast_array->instruction();
+  const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0);
+  if (broadcast_const_operand->opcode() != HloOpcode::kConstant) {
+    return nullptr;
+  }
+
+  ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions();
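+  // A broadcast's `dimensions()` lists the output dimension that each operand
+  // dimension maps to, so an output dimension is a newly broadcasted
+  // dimension iff it does not appear in `dimensions()`.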
+  auto is_broadcasted_dim = [&](int64 output_dim) {
+    return c_find(broadcast_dims, output_dim) == broadcast_dims.end();
+  };
+
+  // All of the output dims must be "broadcasted" dims for the other operand.
+  if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) {
+    return nullptr;
+  }
+
+  // To figure out the broadcast dimensions for the (constant) source for the
+  // scalar-indexed node, we "simulate" the index transformation done by the
+  // existing broadcast:
+  enum class IndexComponent { Broadcasted, NotBroadcasted };
+  std::vector<IndexComponent> simulated_index(
+      broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted);
+  for (int64 broadcast_dim : broadcast_dims) {
+    simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted;
+  }
+
+  // The scalar-indexed node "removes" the source dim and "inserts" the output
+  // dims.  We do the opposite here to undo the scalar-indexed operation.
+  ArraySlice<int64> output_dims = scalar_indexed_const->output_dims();
+  for (int64 i = output_dims.size() - 1; i >= 0; --i) {
+    CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
+    EraseAt(&simulated_index, output_dims[i]);
+  }
+
+  InsertAt(&simulated_index, scalar_indexed_const->source_dim(),
+           IndexComponent::Broadcasted);
+
+  // new_inner_broadcast_dims holds the broadcast dimensions for the inner
+  // BinaryOp(Broadcast'(Const0), Const1).  We now translate simulated_index to
+  // new_inner_broadcast_dims.
+  std::vector<int64> new_inner_broadcast_dims;
+  for (int64 i = 0; i < simulated_index.size(); i++) {
+    if (simulated_index[i] == IndexComponent::NotBroadcasted) {
+      new_inner_broadcast_dims.push_back(i);
+    }
+  }
+
+  // inner_broadcast_result is the Broadcast'(Const0) bit in
+  // BinaryOp(Broadcast'(Const0), Const1).
+  std::unique_ptr<Literal> inner_broadcast_result =
+      broadcast_const_operand->literal()
+          .Broadcast(scalar_indexed_const->source()->shape(),
+                     new_inner_broadcast_dims)
+          .ConsumeValueOrDie();
+
+  // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1).
+  const Literal* literal_for_new_source;
+  if (lhs_is_indexed) {
+    literal_for_new_source =
+        TakeOwnership(HloEvaluator{}
+                          .EvaluateElementwiseBinaryOp(
+                              instr->opcode(), scalar_indexed_const->literal(),
+                              *inner_broadcast_result)
+                          .ConsumeValueOrDie());
+  } else {
+    literal_for_new_source =
+        TakeOwnership(HloEvaluator{}
+                          .EvaluateElementwiseBinaryOp(
+                              instr->opcode(), *inner_broadcast_result,
+                              scalar_indexed_const->literal())
+                          .ConsumeValueOrDie());
+  }
+
+  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+  return Construct<ScalarIndexedConstantArray>(
+      new_source, scalar_indexed_const->indices(),
+      scalar_indexed_const->source_dim(),
+      std::vector<int64>(scalar_indexed_const->output_dims().begin(),
+                         scalar_indexed_const->output_dims().end()),
+      scalar_indexed_const->shape());
+}
+
+Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
+    const HloInstruction* instr, Array* operand) {
+  auto* scalar_indexed_const =
+      dynamic_cast<ScalarIndexedConstantArray*>(operand);
+  if (scalar_indexed_const == nullptr) {
+    return nullptr;
+  }
+
+  // Fold UnaryOp(ScalarIndexed(Const, Indices))
+  //   => ScalarIndexed(UnaryOp(Const), Indices)
+
+  Literal* literal_for_new_source =
+      TakeOwnership(HloEvaluator{}
+                        .EvaluateElementwiseUnaryOp(
+                            instr->opcode(), scalar_indexed_const->literal())
+                        .ConsumeValueOrDie());
+  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+  return Construct<ScalarIndexedConstantArray>(
+      new_source, scalar_indexed_const->indices(),
+      scalar_indexed_const->source_dim(),
+      std::vector<int64>(scalar_indexed_const->output_dims().begin(),
+                         scalar_indexed_const->output_dims().end()),
+      scalar_indexed_const->shape());
+}
+
 tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
   return "indexed-array-analysis-printer-pass";
 }
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 35d454a..8c1f616 100644
@@ -223,7 +223,7 @@ class IndexedArrayAnalysis {
   Array* GetArrayFor(const HloInstruction* instr);
 
   // Pretty-prints the expression rooted at `root`.
-  string ToString(Array* root);
+  // If `print_constants` is true, the contents of constant arrays are also
+  // printed.
+  string ToString(Array* root, bool print_constants = false);
 
  private:
   // Helper function that ensures that every HLO instruction that is
@@ -268,6 +268,11 @@ class IndexedArrayAnalysis {
 
   Array* ComputeArrayForReshape(const Shape& shape, Array* operand);
 
+  Array* ComputeArrayForElementwiseBinaryOp(const HloInstruction* instr,
+                                            Array* lhs, Array* rhs);
+  Array* ComputeArrayForElementwiseUnaryOp(const HloInstruction* instr,
+                                           Array* operand);
+
   template <typename T, typename... Args>
   T* Construct(Args&&... args) {
     T* new_tensor = new T(std::forward<Args>(args)...);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index e1090df..76e7e70 100644
@@ -23,12 +23,27 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
  protected:
   void AssertArrayForRootExpressionIs(const string& hlo_text,
                                       const string& root_expression) {
+    AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
+                                       /*print_constants=*/false);
+  }
+
+  void AssertArrayWithConstantsForRootExpressionIs(
+      const string& hlo_text, const string& root_expression) {
+    AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
+                                       /*print_constants=*/true);
+  }
+
+ private:
+  void AssertArrayForRootExpressionIsImpl(const string& hlo_text,
+                                          const string& root_expression,
+                                          bool print_constants) {
     IndexedArrayAnalysis indexed_tensor_analysis;
     ParseAndVerifyModule(hlo_text);
 
-    string result =
-        indexed_tensor_analysis.ToString(indexed_tensor_analysis.GetArrayFor(
-            module().entry_computation()->root_instruction()));
+    string result = indexed_tensor_analysis.ToString(
+        indexed_tensor_analysis.GetArrayFor(
+            module().entry_computation()->root_instruction()),
+        print_constants);
     LOG(INFO) << result;
     ASSERT_EQ(result, root_expression);
   }
@@ -298,5 +313,162 @@ ENTRY main {
 
   AssertArrayForRootExpressionIs(hlo_text, "%reshape");
 }
+
+TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) {
+  string hlo_text = R"(
+HloModule UnaryOpOfGather
+
+ENTRY main {
+  operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+  indices = s32[5] parameter(0)
+  gather = f32[5,4] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,4}
+  ROOT tanh = f32[5,4] tanh(gather)
+}
+)";
+
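+  // The "1 +" skips the newline that immediately follows R"( so the expected
+  // string begins on the next line.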
+  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant f32[3,4] f32[3,4] {
+  { 0.761594176, 0.964027584, 0.995054781, 0.999329329 },
+  { 0.761594176, 0.995054781, 0.964027584, 0.999329329 },
+  { 0.999329329, 0.995054781, 0.964027584, 0.761594176 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) {
+  string hlo_text = R"(
+HloModule AddBroadcastedScalarWithGather
+
+ENTRY main {
+  gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+  constant = s32[] constant(5)
+  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+  indices = s32[5] parameter(0)
+  gather = s32[5,4] gather(gather_operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,4}
+  ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+  { 6, 7, 8, 9 },
+  { 6, 8, 7, 9 },
+  { 9, 8, 7, 6 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest,
+       SubtractBroadcastedScalarWithGather_GatherIsLhs) {
+  string hlo_text = R"(
+HloModule SubtractBroadcastedScalarWithGather
+
+ENTRY main {
+  gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+  constant = s32[] constant(5)
+  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+  indices = s32[5] parameter(0)
+  gather = s32[5,4] gather(gather_operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,4}
+  ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
+}
+)";
+
+  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+  { -4, -3, -2, -1 },
+  { -4, -2, -3, -1 },
+  { -1, -2, -3, -4 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest,
+       SubtractBroadcastedScalarWithGather_GatherIsRhs) {
+  string hlo_text = R"(
+HloModule SubtractBroadcastedScalarWithGather
+
+ENTRY main {
+  gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+  constant = s32[] constant(5)
+  constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+  indices = s32[5] parameter(0)
+  gather = s32[5,4] gather(gather_operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,4}
+  ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
+}
+)";
+
+  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+  { 4, 3, 2, 1 },
+  { 4, 2, 3, 1 },
+  { 1, 2, 3, 4 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) {
+  string hlo_text = R"(
+HloModule AddBroadcastedVectorWithGather
+
+ENTRY main {
+  gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+  constant_vect = s32[4] constant({10,11,12,13})
+  constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
+  indices = s32[5] parameter(0)
+  gather = s32[5,4] gather(gather_operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,4}
+  ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+  AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+  { 11, 13, 15, 17 },
+  { 11, 14, 14, 17 },
+  { 14, 14, 14, 14 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) {
+  string hlo_text = R"(
+HloModule AddBroadcastedVectorWithGather
+
+ENTRY main {
+  gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+  constant_vect = s32[5] constant({10,11,12,13,14})
+  constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
+  indices = s32[5] parameter(0)
+  gather = s32[5,4] gather(gather_operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,4}
+  ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+  AssertArrayForRootExpressionIs(hlo_text, "%add");
+}
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 6ca0c02..7303640 100644
@@ -537,6 +537,11 @@ void InsertAt(C* c, int64 index, Value&& value) {
   c->insert(c->begin() + index, std::forward<Value>(value));
 }
 
+// Erases the element at position `index` from the sequence container `c`.
+template <typename C>
+void EraseAt(C* c, int64 index) {
+  c->erase(c->begin() + index);
+}
+
 // Returns true if `x` fits in 32-bits.
 template <typename T>
 bool IsInt32(T x) {