Split up ElementaIrEmitter::MakeElementGenerator into smaller functions; NFC
authorSanjoy Das <sanjoy@google.com>
Sat, 28 Apr 2018 01:41:27 +0000 (18:41 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 28 Apr 2018 01:44:25 +0000 (18:44 -0700)
PiperOrigin-RevId: 194622198

tensorflow/compiler/xla/service/elemental_ir_emitter.cc
tensorflow/compiler/xla/service/elemental_ir_emitter.h

index 4b01c87..ae32d33 100644 (file)
@@ -1344,6 +1344,525 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
   };
 }
 
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& index) const {
+  TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
+                      operand_to_generator.at(hlo->operand(0))(
+                          ElementwiseSourceIndex(index, *hlo, 0)));
+  TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
+                      operand_to_generator.at(hlo->operand(1))(
+                          ElementwiseSourceIndex(index, *hlo, 1)));
+  TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
+                      operand_to_generator.at(hlo->operand(2))(
+                          ElementwiseSourceIndex(index, *hlo, 2)));
+  return ir_builder_->CreateSelect(
+      ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
+      on_true_value, on_false_value);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& index) const {
+  TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
+                      operand_to_generator.at(hlo->operand(0))(
+                          ElementwiseSourceIndex(index, *hlo, 0)));
+  TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
+                      operand_to_generator.at(hlo->operand(1))(
+                          ElementwiseSourceIndex(index, *hlo, 1)));
+  TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
+                      operand_to_generator.at(hlo->operand(2))(
+                          ElementwiseSourceIndex(index, *hlo, 2)));
+  PrimitiveType prim_type = hlo->shape().element_type();
+  if (primitive_util::IsFloatingPointType(prim_type)) {
+    return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
+  } else if (primitive_util::IsIntegralType(prim_type)) {
+    bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
+    return EmitIntegralMin(
+        max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
+  } else {
+    return Unimplemented("Clamp unimplemented for %s",
+                         PrimitiveType_Name(prim_type).c_str());
+  }
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& target_index) const {
+  const int64 concat_dim = hlo->dimensions(0);
+  auto source_index = target_index;
+
+  llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
+
+  // A terminator should be present iff we're emitting code
+  // into the middle (as opposed to the end) of a basic block.
+  CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
+           init_block->getTerminator() == nullptr);
+
+  llvm::BasicBlock* exit_block;
+  if (ir_builder_->GetInsertPoint() == init_block->end()) {
+    exit_block = llvm_ir::CreateBasicBlock(
+        /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
+  } else {
+    exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(),
+                                             AsStringRef(IrName(hlo, "merge")));
+    init_block->getTerminator()->eraseFromParent();
+  }
+
+  llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
+  llvm::PHINode* output = ir_builder_->CreatePHI(
+      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+      hlo->operands().size());
+  auto prior_insert_point = ir_builder_->GetInsertPoint();
+
+  ir_builder_->SetInsertPoint(init_block);
+
+  for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
+       ++operand_idx) {
+    const HloInstruction* operand = hlo->operand(operand_idx);
+    auto true_block = llvm_ir::CreateBasicBlock(
+        exit_block, StrCat("concat_index_from_operand", operand_idx),
+        ir_builder_);
+    auto false_block = llvm_ir::CreateBasicBlock(
+        exit_block, StrCat("concat_index_not_from_operand", operand_idx),
+        ir_builder_);
+    auto concat_dim_size =
+        llvm::ConstantInt::get(source_index[concat_dim]->getType(),
+                               operand->shape().dimensions(concat_dim));
+    ir_builder_->CreateCondBr(
+        ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size),
+        true_block, false_block);
+
+    // Create the terminator of the true block before calling operand
+    // generators, because they require non-degenerate basic blocks.
+    ir_builder_->SetInsertPoint(
+        llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
+    TF_ASSIGN_OR_RETURN(llvm::Value * value,
+                        operand_to_generator.at(operand)(source_index));
+    output->addIncoming(value, ir_builder_->GetInsertBlock());
+
+    // Subtract the size of the concat dimension of the current operand
+    // from the source index.
+    ir_builder_->SetInsertPoint(false_block);
+    source_index[concat_dim] =
+        ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
+  }
+
+  ir_builder_->CreateUnreachable();
+  ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
+  return output;
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& index) const {
+  // Emit IR to read dynamic start indices from hlo->operand(1).
+  const HloInstruction* input_hlo = hlo->operand(0);
+  const int64 rank = ShapeUtil::Rank(input_hlo->shape());
+  llvm_ir::IrArray::Index slice_start_index(rank);
+  for (int64 i = 0; i < rank; ++i) {
+    llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+    TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
+                        operand_to_generator.at(hlo->operand(1))(dim_index));
+    start_index_value->setName(
+        AsStringRef(IrName(hlo, StrCat("start_idx", i))));
+    slice_start_index[i] = start_index_value;
+  }
+
+  llvm_ir::IrArray::Index input_index(rank);
+  for (int64 i = 0; i < rank; ++i) {
+    // Emit IR which computes:
+    //   input_index = (start_index + offset_index) % dim_size
+    // Security note: this is the code that keeps the indices in-bounds.
+    llvm::Value* dim_size = llvm::ConstantInt::get(
+        index[i]->getType(), input_hlo->shape().dimensions(i));
+    llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
+        slice_start_index[i], index[i]->getType());
+    input_index[i] = ir_builder_->CreateURem(
+        ir_builder_->CreateAdd(start_index, index[i]), dim_size);
+  }
+  return operand_to_generator.at(input_hlo)(input_index);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& index) const {
+  const Shape& operand_shape = hlo->operand(0)->shape();
+  const Shape& indices_shape = hlo->operand(1)->shape();
+  const Shape& output_shape = hlo->shape();
+
+  const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();
+
+  const llvm_ir::ElementGenerator& operand_generator =
+      operand_to_generator.at(hlo->operand(0));
+  const llvm_ir::ElementGenerator& indices_generator =
+      operand_to_generator.at(hlo->operand(1));
+
+  // This is the index into `operand` that holds the element we want to
+  // generate.  This index "unsafe" as in the components in here may be
+  // out of bounds.
+  IrArray::Index unsafe_operand_index;
+
+  // First copy in the window indices to unsafe_operand_index.
+  for (int64 i = 0, e = operand_shape.dimensions_size(),
+             unsafe_operand_index_dim = 0;
+       i < e; i++) {
+    if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+      unsafe_operand_index.push_back(ir_builder_->getInt64(0));
+    } else {
+      unsafe_operand_index.push_back(
+          index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]);
+    }
+  }
+
+  // This is the index of the index vector in the gather_indices tensor.
+  IrArray::Index gather_index_index;
+  {
+    std::vector<llvm::Value*> gather_index_index_components;
+    for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
+      if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+        gather_index_index.push_back(index[i]);
+      }
+    }
+
+    if (gather_index_index.size() != indices_shape.dimensions_size()) {
+      gather_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr);
+    }
+  }
+
+  auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
+                                         int64 dim) {
+    llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc(
+        index_component, ir_builder_->getInt64Ty());
+    unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
+        ir_builder_->CreateAdd(
+            unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)],
+            gather_dim_component_extended);
+  };
+
+  if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
+    TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
+                        indices_generator(gather_index_index));
+    add_to_unsafe_operand_index(gather_dim_component, 0);
+  } else {
+    int64 index_vector_size =
+        indices_shape.dimensions(dim_numbers.index_vector_dim());
+    for (int64 i = 0; i < index_vector_size; i++) {
+      gather_index_index[dim_numbers.index_vector_dim()] =
+          ir_builder_->getInt64(i);
+      TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
+                          indices_generator(gather_index_index));
+      add_to_unsafe_operand_index(gather_dim_component, i);
+    }
+  }
+
+  IrArray::Index safe_operand_index;
+  for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
+    safe_operand_index.push_back(ir_builder_->CreateURem(
+        unsafe_operand_index[i],
+        ir_builder_->getInt64(operand_shape.dimensions(i))));
+  }
+
+  return operand_generator(safe_operand_index);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& index) const {
+  const HloInstruction* input_hlo = hlo->operand(0);
+  const HloInstruction* update_hlo = hlo->operand(1);
+  const HloInstruction* start_hlo = hlo->operand(2);
+  // Calculate slice start/end indices.
+  const int64 rank = ShapeUtil::Rank(input_hlo->shape());
+  llvm_ir::IrArray::Index slice_start_index(rank);
+  llvm_ir::IrArray::Index slice_limit_index(rank);
+  // Slice starts at update[index - slice_start_index_adjusted],
+  // where adjusted value = slice_start_index when in bounds, and
+  // adjusted value = slice_start_index - input_dim, when wrapping.
+  llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
+
+  // Slice intersection gathers (ANDs) conditions on all ranks for which
+  // 'input' is set to 'update'
+  llvm::Value* slice_intersection = ir_builder_->getTrue();
+
+  for (int64 i = 0; i < rank; ++i) {
+    // Emit IR to read dynamic start indices from 'start_hlo'.
+    llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+    TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
+                        operand_to_generator.at(start_hlo)(dim_index));
+    start_index_value->setName(
+        AsStringRef(IrName(hlo, StrCat("start_idx", i))));
+    slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
+        start_index_value, index[i]->getType());
+
+    llvm::Value* input_dim_size = llvm::ConstantInt::get(
+        index[i]->getType(), input_hlo->shape().dimensions(i));
+    llvm::Value* update_dim_size = llvm::ConstantInt::get(
+        index[i]->getType(), update_hlo->shape().dimensions(i));
+
+    // Generate code to handle wrapping semantics:
+    // slice_start_index[i] = slice_start_index[i] % input_dim_size;
+    // slice_limit_index[i] = slice_start_index[i] + update_dim_size.
+    // slice_start_index[i] is updated in place and it will now be in
+    // range. slice_limit_index[i] may be out of range, and it's being
+    // URem-ed below if so.
+    slice_start_index[i] =
+        ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
+    slice_limit_index[i] =
+        ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
+
+    // Test if slice_limit_index[i] is in bounds
+    llvm::Value* in_bounds =
+        ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
+    llvm_ir::LlvmIfData if_in_bounds =
+        llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
+
+    // Handle true BB (slice_limit_index[i] <= input_dim_size).
+    SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
+    // Check that index[i] >= slice_start_index[i] &&
+    //            index[i] < slice_limit_index[i]
+    llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
+        slice_intersection,
+        ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
+        "slice_intersection_in");
+    slice_intersection_in_bounds = ir_builder_->CreateAnd(
+        slice_intersection_in_bounds,
+        ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
+        "slice_intersection_in");
+
+    // Handle false BB (slice_limit_index[i] > input_dim_size).
+    SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
+    // Check that index[i] >= slice_start_index[i] ||
+    //            index[i] < slice_limit_index[i]%input_dim_size.
+    llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
+        index[i],
+        ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
+    llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
+        ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps,
+        "slice_intersection_out");
+    llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd(
+        slice_intersection, slice_intersection_or, "slice_intersection_out");
+    // Create value for slice_start_index_adjusted[i] when out of bounds.
+    // If within out-of-bounds if.
+    llvm_ir::LlvmIfData if_start_needs_adjustment =
+        llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
+    SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_);
+    llvm::Value* slice_start_index_adjusted_oob =
+        ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
+    SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_);
+    llvm::PHINode* slice_start_index_adjusted_phi =
+        ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2);
+    slice_start_index_adjusted_phi->addIncoming(
+        slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block);
+    slice_start_index_adjusted_phi->addIncoming(
+        slice_start_index[i], if_start_needs_adjustment.false_block);
+    // End of if within if.
+
+    // After checking in/out of bounds.
+    SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
+    llvm::PHINode* phi_slice_intersection =
+        ir_builder_->CreatePHI(slice_intersection->getType(), 2);
+    phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
+                                        if_in_bounds.true_block);
+    phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds,
+                                        if_start_needs_adjustment.after_block);
+    slice_intersection = phi_slice_intersection;
+
+    llvm::PHINode* phi_index =
+        ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
+    phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
+    phi_index->addIncoming(slice_start_index_adjusted_phi,
+                           if_start_needs_adjustment.after_block);
+    slice_start_index_adjusted[i] = phi_index;
+  }
+
+  // Emit:
+  // if (slice_intersection) -> return data from 'update'.
+  // else                    -> return data from 'input'.
+  llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+      "ret_value_addr", ir_builder_);
+  llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+      slice_intersection, "slice_intersection", ir_builder_);
+
+  // Handle true BB (return data from 'update')
+  SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+  // Compute update index for intersection case.
+  llvm_ir::IrArray::Index update_index(rank);
+  for (int64 i = 0; i < rank; ++i) {
+    llvm::Value* update_dim_size = llvm::ConstantInt::get(
+        index[i]->getType(), update_hlo->shape().dimensions(i));
+    // NOTE: Subtraction will be positive due to bounds checking above.
+    update_index[i] = ir_builder_->CreateURem(
+        ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
+        update_dim_size);
+  }
+  TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
+                      operand_to_generator.at(update_hlo)(update_index));
+  ir_builder_->CreateStore(true_value, ret_value_addr);
+
+  // Handle false BB (return data from 'input')
+  SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+  TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
+                      operand_to_generator.at(input_hlo)(index));
+  ir_builder_->CreateStore(false_value, ret_value_addr);
+
+  SetToFirstInsertPoint(if_data.after_block, ir_builder_);
+  return ir_builder_->CreateLoad(ret_value_addr);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& padded_index) const {
+  auto index = padded_index;
+  llvm::Value* in_bounds = ir_builder_->getTrue();
+  for (size_t i = 0; i < index.size(); ++i) {
+    auto index_typed_const = [=](int64 n) {
+      return llvm::ConstantInt::get(index[i]->getType(), n);
+    };
+    const auto& pad_dim = hlo->padding_config().dimensions(i);
+    index[i] = ir_builder_->CreateSub(
+        index[i], index_typed_const(pad_dim.edge_padding_low()));
+    in_bounds = ir_builder_->CreateAnd(
+        in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
+        "in_bounds");
+    in_bounds = ir_builder_->CreateAnd(
+        in_bounds,
+        ir_builder_->CreateICmpEQ(
+            index_typed_const(0),
+            ir_builder_->CreateURem(
+                index[i], index_typed_const(pad_dim.interior_padding() + 1))),
+        "in_bounds");
+    index[i] = ir_builder_->CreateSDiv(
+        index[i], index_typed_const(pad_dim.interior_padding() + 1));
+    in_bounds = ir_builder_->CreateAnd(
+        in_bounds,
+        ir_builder_->CreateICmpSLT(
+            index[i],
+            index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+        "in_bounds");
+  }
+
+  // if (in_bounds) {
+  //   ret_value = operand0[index];  // source
+  // } else {
+  //   ret_value = *operand1;        // padding
+  // }
+  llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+      llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+      "pad_result_addr", ir_builder_);
+  llvm_ir::LlvmIfData if_data =
+      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
+  SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+  TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+                      operand_to_generator.at(hlo->operand(0))(index));
+  ir_builder_->CreateStore(operand_value, ret_value_addr);
+
+  SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+  TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
+                      operand_to_generator.at(hlo->operand(1))({}));
+  ir_builder_->CreateStore(padding_value, ret_value_addr);
+
+  SetToFirstInsertPoint(if_data.after_block, ir_builder_);
+  // Don't create phi(operand_value, padding_value) here, because invoking
+  // operand_to_generator may create new basic blocks, making the parent
+  // of operand_value or padding_value no longer a predecessor of
+  // if_data.after_block.
+  return ir_builder_->CreateLoad(ret_value_addr);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
+    const HloInstruction* hlo,
+    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+    const llvm_ir::IrArray::Index& dot_result_index) const {
+  auto lhs_generator = operand_to_generator.at(hlo->operand(0));
+  auto rhs_generator = operand_to_generator.at(hlo->operand(1));
+  int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
+      hlo->operand(0)->shape().dimensions_size() - 1);
+  int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
+  int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
+
+  std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
+      IrName(hlo, "inner"), ir_builder_->getInt64(0),
+      ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1),
+      ir_builder_);
+
+  SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_);
+  PrimitiveType primitive_type = hlo->shape().element_type();
+  llvm::Type* primitive_type_llvm =
+      llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
+  llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
+      primitive_type_llvm, "dot_acc", ir_builder_);
+  ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm),
+                           accumulator_alloca);
+
+  SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);
+
+  // This is the inner reduction loop for a dot operation that produces
+  // one element in the output.  If the operands to the dot operation have
+  // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
+  // Given an output index [a,b,c,d,e] in the result, we compute:
+  //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
+
+  IrArray::Index lhs_index, rhs_index;
+
+  for (int64 i = 0; i < lhs_dims - 1; i++) {
+    lhs_index.push_back(dot_result_index[i]);
+  }
+  lhs_index.push_back(inner_loop->GetIndVarValue());
+
+  for (int64 i = 0; i < rhs_dims - 2; i++) {
+    rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
+  }
+  rhs_index.push_back(inner_loop->GetIndVarValue());
+  rhs_index.push_back(dot_result_index.back());
+
+  llvm::Value* current_accumulator =
+      ir_builder_->CreateLoad(accumulator_alloca);
+  TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
+  TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
+  llvm::Value* next_accumulator;
+  if (primitive_util::IsComplexType(primitive_type)) {
+    llvm::Value* product_real = ir_builder_->CreateFSub(
+        ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+                                EmitExtractReal(rhs_value)),
+        ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+                                EmitExtractImag(rhs_value)));
+    llvm::Value* product_imag = ir_builder_->CreateFAdd(
+        ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+                                EmitExtractImag(rhs_value)),
+        ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+                                EmitExtractReal(rhs_value)));
+    next_accumulator = ir_builder_->CreateInsertValue(
+        current_accumulator,
+        ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
+                                product_real),
+        {0});
+    next_accumulator = ir_builder_->CreateInsertValue(
+        next_accumulator,
+        ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
+                                product_imag),
+        {1});
+  } else if (primitive_util::IsFloatingPointType(primitive_type)) {
+    next_accumulator = ir_builder_->CreateFAdd(
+        current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value));
+  } else {
+    next_accumulator = ir_builder_->CreateAdd(
+        current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value));
+  }
+  ir_builder_->CreateStore(next_accumulator, accumulator_alloca);
+
+  SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
+  return ir_builder_->CreateLoad(accumulator_alloca);
+}
+
 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
     const HloInstruction* hlo,
     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
@@ -1411,43 +1930,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
     case HloOpcode::kSelect:
       return [this, hlo, &operand_to_generator](
                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
-                            operand_to_generator.at(hlo->operand(0))(
-                                ElementwiseSourceIndex(index, *hlo, 0)));
-        TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
-                            operand_to_generator.at(hlo->operand(1))(
-                                ElementwiseSourceIndex(index, *hlo, 1)));
-        TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
-                            operand_to_generator.at(hlo->operand(2))(
-                                ElementwiseSourceIndex(index, *hlo, 2)));
-        return ir_builder_->CreateSelect(
-            ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
-            on_true_value, on_false_value);
+        return EmitElementalSelect(hlo, operand_to_generator, index);
       };
     case HloOpcode::kClamp:
       return [this, hlo, &operand_to_generator](
                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
-                            operand_to_generator.at(hlo->operand(0))(
-                                ElementwiseSourceIndex(index, *hlo, 0)));
-        TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
-                            operand_to_generator.at(hlo->operand(1))(
-                                ElementwiseSourceIndex(index, *hlo, 1)));
-        TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
-                            operand_to_generator.at(hlo->operand(2))(
-                                ElementwiseSourceIndex(index, *hlo, 2)));
-        PrimitiveType prim_type = hlo->shape().element_type();
-        if (primitive_util::IsFloatingPointType(prim_type)) {
-          return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
-        } else if (primitive_util::IsIntegralType(prim_type)) {
-          bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
-          return EmitIntegralMin(
-              max_value, EmitIntegralMax(min_value, arg_value, is_signed),
-              is_signed);
-        } else {
-          return Unimplemented("Clamp unimplemented for %s",
-                               PrimitiveType_Name(prim_type).c_str());
-        }
+        return EmitElementalClamp(hlo, operand_to_generator, index);
       };
     case HloOpcode::kReducePrecision:
       return [this, hlo, &operand_to_generator](
@@ -1460,70 +1948,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
     case HloOpcode::kConcatenate:
       return [this, hlo, &operand_to_generator](
                  const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
-        const int64 concat_dim = hlo->dimensions(0);
-        auto source_index = target_index;
-
-        llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
-
-        // A terminator should be present iff we're emitting code
-        // into the middle (as opposed to the end) of a basic block.
-        CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
-                 init_block->getTerminator() == nullptr);
-
-        llvm::BasicBlock* exit_block;
-        if (ir_builder_->GetInsertPoint() == init_block->end()) {
-          exit_block = llvm_ir::CreateBasicBlock(
-              /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
-        } else {
-          exit_block = init_block->splitBasicBlock(
-              ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge")));
-          init_block->getTerminator()->eraseFromParent();
-        }
-
-        llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
-        llvm::PHINode* output =
-            ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
-                                       hlo->shape().element_type(), module_),
-                                   hlo->operands().size());
-        auto prior_insert_point = ir_builder_->GetInsertPoint();
-
-        ir_builder_->SetInsertPoint(init_block);
-
-        for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
-             ++operand_idx) {
-          const HloInstruction* operand = hlo->operand(operand_idx);
-          auto true_block = llvm_ir::CreateBasicBlock(
-              exit_block, StrCat("concat_index_from_operand", operand_idx),
-              ir_builder_);
-          auto false_block = llvm_ir::CreateBasicBlock(
-              exit_block, StrCat("concat_index_not_from_operand", operand_idx),
-              ir_builder_);
-          auto concat_dim_size =
-              llvm::ConstantInt::get(source_index[concat_dim]->getType(),
-                                     operand->shape().dimensions(concat_dim));
-          ir_builder_->CreateCondBr(
-              ir_builder_->CreateICmpULT(source_index[concat_dim],
-                                         concat_dim_size),
-              true_block, false_block);
-
-          // Create the terminator of the true block before calling operand
-          // generators, because they require non-degenerate basic blocks.
-          ir_builder_->SetInsertPoint(
-              llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
-          TF_ASSIGN_OR_RETURN(llvm::Value * value,
-                              operand_to_generator.at(operand)(source_index));
-          output->addIncoming(value, ir_builder_->GetInsertBlock());
-
-          // Subtract the size of the concat dimension of the current operand
-          // from the source index.
-          ir_builder_->SetInsertPoint(false_block);
-          source_index[concat_dim] =
-              ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
-        }
-
-        ir_builder_->CreateUnreachable();
-        ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
-        return output;
+        return EmitElementalConcatenate(hlo, operand_to_generator,
+                                        target_index);
       };
     case HloOpcode::kReverse:
       return [this, hlo, &operand_to_generator](
@@ -1559,270 +1985,19 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
     case HloOpcode::kDynamicSlice:
       return [this, hlo, &operand_to_generator](
                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        // Emit IR to read dynamic start indices from hlo->operand(1).
-        const HloInstruction* input_hlo = hlo->operand(0);
-        const int64 rank = ShapeUtil::Rank(input_hlo->shape());
-        llvm_ir::IrArray::Index slice_start_index(rank);
-        for (int64 i = 0; i < rank; ++i) {
-          llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
-          TF_ASSIGN_OR_RETURN(
-              llvm::Value * start_index_value,
-              operand_to_generator.at(hlo->operand(1))(dim_index));
-          start_index_value->setName(
-              AsStringRef(IrName(hlo, StrCat("start_idx", i))));
-          slice_start_index[i] = start_index_value;
-        }
-
-        llvm_ir::IrArray::Index input_index(rank);
-        for (int64 i = 0; i < rank; ++i) {
-          // Emit IR which computes:
-          //   input_index = (start_index + offset_index) % dim_size
-          // Security note: this is the code that keeps the indices in-bounds.
-          llvm::Value* dim_size = llvm::ConstantInt::get(
-              index[i]->getType(), input_hlo->shape().dimensions(i));
-          llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
-              slice_start_index[i], index[i]->getType());
-          input_index[i] = ir_builder_->CreateURem(
-              ir_builder_->CreateAdd(start_index, index[i]), dim_size);
-        }
-        return operand_to_generator.at(input_hlo)(input_index);
+        return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
       };
 
     case HloOpcode::kGather:
       return [this, hlo, &operand_to_generator](
                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        const Shape& operand_shape = hlo->operand(0)->shape();
-        const Shape& indices_shape = hlo->operand(1)->shape();
-        const Shape& output_shape = hlo->shape();
-
-        const GatherDimensionNumbers& dim_numbers =
-            hlo->gather_dimension_numbers();
-
-        const llvm_ir::ElementGenerator& operand_generator =
-            operand_to_generator.at(hlo->operand(0));
-        const llvm_ir::ElementGenerator& indices_generator =
-            operand_to_generator.at(hlo->operand(1));
-
-        // This is the index into `operand` that holds the element we want to
-        // generate.  This index "unsafe" as in the components in here may be
-        // out of bounds.
-        IrArray::Index unsafe_operand_index;
-
-        // First copy in the window indices to unsafe_operand_index.
-        for (int64 i = 0, e = operand_shape.dimensions_size(),
-                   unsafe_operand_index_dim = 0;
-             i < e; i++) {
-          if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
-            unsafe_operand_index.push_back(ir_builder_->getInt64(0));
-          } else {
-            unsafe_operand_index.push_back(index[dim_numbers.output_window_dims(
-                unsafe_operand_index_dim++)]);
-          }
-        }
-
-        // This is the index of the index vector in the gather_indices tensor.
-        IrArray::Index gather_index_index;
-        {
-          std::vector<llvm::Value*> gather_index_index_components;
-          for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
-            if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
-              gather_index_index.push_back(index[i]);
-            }
-          }
-
-          if (gather_index_index.size() != indices_shape.dimensions_size()) {
-            gather_index_index.InsertAt(dim_numbers.index_vector_dim(),
-                                        nullptr);
-          }
-        }
-
-        auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
-                                               int64 dim) {
-          llvm::Value* gather_dim_component_extended =
-              ir_builder_->CreateSExtOrTrunc(index_component,
-                                             ir_builder_->getInt64Ty());
-          unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
-              ir_builder_->CreateAdd(
-                  unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(
-                      dim)],
-                  gather_dim_component_extended);
-        };
-
-        if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
-          TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
-                              indices_generator(gather_index_index));
-          add_to_unsafe_operand_index(gather_dim_component, 0);
-        } else {
-          int64 index_vector_size =
-              indices_shape.dimensions(dim_numbers.index_vector_dim());
-          for (int64 i = 0; i < index_vector_size; i++) {
-            gather_index_index[dim_numbers.index_vector_dim()] =
-                ir_builder_->getInt64(i);
-            TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
-                                indices_generator(gather_index_index));
-            add_to_unsafe_operand_index(gather_dim_component, i);
-          }
-        }
-
-        IrArray::Index safe_operand_index;
-        for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
-          safe_operand_index.push_back(ir_builder_->CreateURem(
-              unsafe_operand_index[i],
-              ir_builder_->getInt64(operand_shape.dimensions(i))));
-        }
-
-        return operand_generator(safe_operand_index);
+        return EmitElementalGather(hlo, operand_to_generator, index);
       };
     case HloOpcode::kDynamicUpdateSlice:
       return [this, hlo, &operand_to_generator](
                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        const HloInstruction* input_hlo = hlo->operand(0);
-        const HloInstruction* update_hlo = hlo->operand(1);
-        const HloInstruction* start_hlo = hlo->operand(2);
-        // Calculate slice start/end indices.
-        const int64 rank = ShapeUtil::Rank(input_hlo->shape());
-        llvm_ir::IrArray::Index slice_start_index(rank);
-        llvm_ir::IrArray::Index slice_limit_index(rank);
-        // Slice starts at update[index - slice_start_index_adjusted],
-        // where adjusted value = slice_start_index when in bounds, and
-        // adjusted value = slice_start_index - input_dim, when wrapping.
-        llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
-
-        // Slice intersection gathers (ANDs) conditions on all ranks for which
-        // 'input' is set to 'update'
-        llvm::Value* slice_intersection = ir_builder_->getTrue();
-
-        for (int64 i = 0; i < rank; ++i) {
-          // Emit IR to read dynamic start indices from 'start_hlo'.
-          llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
-          TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
-                              operand_to_generator.at(start_hlo)(dim_index));
-          start_index_value->setName(
-              AsStringRef(IrName(hlo, StrCat("start_idx", i))));
-          slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
-              start_index_value, index[i]->getType());
-
-          llvm::Value* input_dim_size = llvm::ConstantInt::get(
-              index[i]->getType(), input_hlo->shape().dimensions(i));
-          llvm::Value* update_dim_size = llvm::ConstantInt::get(
-              index[i]->getType(), update_hlo->shape().dimensions(i));
-
-          // Generate code to handle wrapping semantics:
-          // slice_start_index[i] = slice_start_index[i] % input_dim_size;
-          // slice_limit_index[i] = slice_start_index[i] + update_dim_size.
-          // slice_start_index[i] is updated in place and it will now be in
-          // range. slice_limit_index[i] may be out of range, and it's being
-          // URem-ed below if so.
-          slice_start_index[i] =
-              ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
-          slice_limit_index[i] =
-              ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
-
-          // Test if slice_limit_index[i] is in bounds
-          llvm::Value* in_bounds =
-              ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
-          llvm_ir::LlvmIfData if_in_bounds =
-              llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
-
-          // Handle true BB (slice_limit_index[i] <= input_dim_size).
-          SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
-          // Check that index[i] >= slice_start_index[i] &&
-          //            index[i] < slice_limit_index[i]
-          llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
-              slice_intersection,
-              ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
-              "slice_intersection_in");
-          slice_intersection_in_bounds = ir_builder_->CreateAnd(
-              slice_intersection_in_bounds,
-              ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
-              "slice_intersection_in");
-
-          // Handle false BB (slice_limit_index[i] > input_dim_size).
-          SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
-          // Check that index[i] >= slice_start_index[i] ||
-          //            index[i] < slice_limit_index[i]%input_dim_size.
-          llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
-              index[i],
-              ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
-          llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
-              ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
-              index_wraps, "slice_intersection_out");
-          llvm::Value* slice_intersection_out_of_bounds =
-              ir_builder_->CreateAnd(slice_intersection, slice_intersection_or,
-                                     "slice_intersection_out");
-          // Create value for slice_start_index_adjusted[i] when out of bounds.
-          // If within out-of-bounds if.
-          llvm_ir::LlvmIfData if_start_needs_adjustment =
-              llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
-          SetToFirstInsertPoint(if_start_needs_adjustment.true_block,
-                                ir_builder_);
-          llvm::Value* slice_start_index_adjusted_oob =
-              ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
-          SetToFirstInsertPoint(if_start_needs_adjustment.after_block,
-                                ir_builder_);
-          llvm::PHINode* slice_start_index_adjusted_phi =
-              ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(),
-                                     2);
-          slice_start_index_adjusted_phi->addIncoming(
-              slice_start_index_adjusted_oob,
-              if_start_needs_adjustment.true_block);
-          slice_start_index_adjusted_phi->addIncoming(
-              slice_start_index[i], if_start_needs_adjustment.false_block);
-          // End of if within if.
-
-          // After checking in/out of bounds.
-          SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
-          llvm::PHINode* phi_slice_intersection =
-              ir_builder_->CreatePHI(slice_intersection->getType(), 2);
-          phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
-                                              if_in_bounds.true_block);
-          phi_slice_intersection->addIncoming(
-              slice_intersection_out_of_bounds,
-              if_start_needs_adjustment.after_block);
-          slice_intersection = phi_slice_intersection;
-
-          llvm::PHINode* phi_index =
-              ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
-          phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
-          phi_index->addIncoming(slice_start_index_adjusted_phi,
-                                 if_start_needs_adjustment.after_block);
-          slice_start_index_adjusted[i] = phi_index;
-        }
-
-        // Emit:
-        // if (slice_intersection) -> return data from 'update'.
-        // else                    -> return data from 'input'.
-        llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
-            llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
-                                           module_),
-            "ret_value_addr", ir_builder_);
-        llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
-            slice_intersection, "slice_intersection", ir_builder_);
-
-        // Handle true BB (return data from 'update')
-        SetToFirstInsertPoint(if_data.true_block, ir_builder_);
-        // Compute update index for intersection case.
-        llvm_ir::IrArray::Index update_index(rank);
-        for (int64 i = 0; i < rank; ++i) {
-          llvm::Value* update_dim_size = llvm::ConstantInt::get(
-              index[i]->getType(), update_hlo->shape().dimensions(i));
-          // NOTE: Subtraction will be positive due to bounds checking above.
-          update_index[i] = ir_builder_->CreateURem(
-              ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
-              update_dim_size);
-        }
-        TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
-                            operand_to_generator.at(update_hlo)(update_index));
-        ir_builder_->CreateStore(true_value, ret_value_addr);
-
-        // Handle false BB (return data from 'input')
-        SetToFirstInsertPoint(if_data.false_block, ir_builder_);
-        TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
-                            operand_to_generator.at(input_hlo)(index));
-        ir_builder_->CreateStore(false_value, ret_value_addr);
-
-        SetToFirstInsertPoint(if_data.after_block, ir_builder_);
-        return ir_builder_->CreateLoad(ret_value_addr);
+        return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
+                                               index);
       };
     case HloOpcode::kBitcast:
       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
@@ -1851,155 +2026,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
     case HloOpcode::kRng:
       return MakeRngElementGenerator(hlo, operand_to_generator);
     case HloOpcode::kPad:
-      return [=, &operand_to_generator](
+      return [this, hlo, &operand_to_generator](
                  const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
-        auto index = padded_index;
-        llvm::Value* in_bounds = ir_builder_->getTrue();
-        for (size_t i = 0; i < index.size(); ++i) {
-          auto index_typed_const = [=](int64 n) {
-            return llvm::ConstantInt::get(index[i]->getType(), n);
-          };
-          const auto& pad_dim = hlo->padding_config().dimensions(i);
-          index[i] = ir_builder_->CreateSub(
-              index[i], index_typed_const(pad_dim.edge_padding_low()));
-          in_bounds = ir_builder_->CreateAnd(
-              in_bounds,
-              ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
-              "in_bounds");
-          in_bounds = ir_builder_->CreateAnd(
-              in_bounds,
-              ir_builder_->CreateICmpEQ(
-                  index_typed_const(0),
-                  ir_builder_->CreateURem(
-                      index[i],
-                      index_typed_const(pad_dim.interior_padding() + 1))),
-              "in_bounds");
-          index[i] = ir_builder_->CreateSDiv(
-              index[i], index_typed_const(pad_dim.interior_padding() + 1));
-          in_bounds = ir_builder_->CreateAnd(
-              in_bounds,
-              ir_builder_->CreateICmpSLT(
-                  index[i],
-                  index_typed_const(hlo->operand(0)->shape().dimensions(i))),
-              "in_bounds");
-        }
-
-        // if (in_bounds) {
-        //   ret_value = operand0[index];  // source
-        // } else {
-        //   ret_value = *operand1;        // padding
-        // }
-        llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
-            llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
-                                           module_),
-            "pad_result_addr", ir_builder_);
-        llvm_ir::LlvmIfData if_data =
-            llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
-        SetToFirstInsertPoint(if_data.true_block, ir_builder_);
-        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
-                            operand_to_generator.at(hlo->operand(0))(index));
-        ir_builder_->CreateStore(operand_value, ret_value_addr);
-
-        SetToFirstInsertPoint(if_data.false_block, ir_builder_);
-        TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
-                            operand_to_generator.at(hlo->operand(1))({}));
-        ir_builder_->CreateStore(padding_value, ret_value_addr);
-
-        SetToFirstInsertPoint(if_data.after_block, ir_builder_);
-        // Don't create phi(operand_value, padding_value) here, because invoking
-        // operand_to_generator may create new basic blocks, making the parent
-        // of operand_value or padding_value no longer a predecessor of
-        // if_data.after_block.
-        return ir_builder_->CreateLoad(ret_value_addr);
+        return EmitElementalPad(hlo, operand_to_generator, padded_index);
       };
 
     case HloOpcode::kDot:
-      return [=, &operand_to_generator](const IrArray::Index& dot_result_index)
+      return [this, hlo,
+              &operand_to_generator](const IrArray::Index& dot_result_index)
                  -> StatusOr<llvm::Value*> {
-        auto lhs_generator = operand_to_generator.at(hlo->operand(0));
-        auto rhs_generator = operand_to_generator.at(hlo->operand(1));
-        int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
-            hlo->operand(0)->shape().dimensions_size() - 1);
-        int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
-        int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
-
-        std::unique_ptr<llvm_ir::ForLoop> inner_loop =
-            llvm_ir::ForLoop::EmitForLoop(
-                IrName(hlo, "inner"), ir_builder_->getInt64(0),
-                ir_builder_->getInt64(contracted_dim_size),
-                ir_builder_->getInt64(1), ir_builder_);
-
-        SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(),
-                              ir_builder_);
-        PrimitiveType primitive_type = hlo->shape().element_type();
-        llvm::Type* primitive_type_llvm =
-            llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
-        llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
-            primitive_type_llvm, "dot_acc", ir_builder_);
-        ir_builder_->CreateStore(
-            llvm::Constant::getNullValue(primitive_type_llvm),
-            accumulator_alloca);
-
-        SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);
-
-        // This is the inner reduction loop for a dot operation that produces
-        // one element in the output.  If the operands to the dot operation have
-        // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
-        // Given an output index [a,b,c,d,e] in the result, we compute:
-        //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
-
-        IrArray::Index lhs_index, rhs_index;
-
-        for (int64 i = 0; i < lhs_dims - 1; i++) {
-          lhs_index.push_back(dot_result_index[i]);
-        }
-        lhs_index.push_back(inner_loop->GetIndVarValue());
-
-        for (int64 i = 0; i < rhs_dims - 2; i++) {
-          rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
-        }
-        rhs_index.push_back(inner_loop->GetIndVarValue());
-        rhs_index.push_back(dot_result_index.back());
-
-        llvm::Value* current_accumulator =
-            ir_builder_->CreateLoad(accumulator_alloca);
-        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
-        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
-        llvm::Value* next_accumulator;
-        if (primitive_util::IsComplexType(primitive_type)) {
-          llvm::Value* product_real = ir_builder_->CreateFSub(
-              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
-                                      EmitExtractReal(rhs_value)),
-              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
-                                      EmitExtractImag(rhs_value)));
-          llvm::Value* product_imag = ir_builder_->CreateFAdd(
-              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
-                                      EmitExtractImag(rhs_value)),
-              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
-                                      EmitExtractReal(rhs_value)));
-          next_accumulator = ir_builder_->CreateInsertValue(
-              current_accumulator,
-              ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
-                                      product_real),
-              {0});
-          next_accumulator = ir_builder_->CreateInsertValue(
-              next_accumulator,
-              ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
-                                      product_imag),
-              {1});
-        } else if (primitive_util::IsFloatingPointType(primitive_type)) {
-          next_accumulator = ir_builder_->CreateFAdd(
-              current_accumulator,
-              ir_builder_->CreateFMul(lhs_value, rhs_value));
-        } else {
-          next_accumulator = ir_builder_->CreateAdd(
-              current_accumulator,
-              ir_builder_->CreateMul(lhs_value, rhs_value));
-        }
-        ir_builder_->CreateStore(next_accumulator, accumulator_alloca);
-
-        SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
-        return ir_builder_->CreateLoad(accumulator_alloca);
+        return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
       };
     default:
       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
index c516a82..26dff0d 100644 (file)
@@ -142,6 +142,46 @@ class ElementalIrEmitter {
     return ir_builder_->getIntN(128, 0);
   }
 
+  StatusOr<llvm::Value*> EmitElementalSelect(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& index) const;
+
+  StatusOr<llvm::Value*> EmitElementalClamp(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& index) const;
+
+  StatusOr<llvm::Value*> EmitElementalConcatenate(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& target_index) const;
+
+  StatusOr<llvm::Value*> EmitElementalDynamicSlice(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& index) const;
+
+  StatusOr<llvm::Value*> EmitElementalGather(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& index) const;
+
+  StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& index) const;
+
+  StatusOr<llvm::Value*> EmitElementalPad(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& padded_index) const;
+
+  StatusOr<llvm::Value*> EmitElementalDot(
+      const HloInstruction* hlo,
+      const HloToElementGeneratorMap& operand_to_generator,
+      const llvm_ir::IrArray::Index& dot_result_index) const;
+
   llvm::IRBuilder<>* const ir_builder_;
 
   llvm::Module* module_;