};
}
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index) const {
+ TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
+ operand_to_generator.at(hlo->operand(0))(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
+ operand_to_generator.at(hlo->operand(1))(
+ ElementwiseSourceIndex(index, *hlo, 1)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
+ operand_to_generator.at(hlo->operand(2))(
+ ElementwiseSourceIndex(index, *hlo, 2)));
+ return ir_builder_->CreateSelect(
+ ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
+ on_true_value, on_false_value);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index) const {
+ TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
+ operand_to_generator.at(hlo->operand(0))(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
+ operand_to_generator.at(hlo->operand(1))(
+ ElementwiseSourceIndex(index, *hlo, 1)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
+ operand_to_generator.at(hlo->operand(2))(
+ ElementwiseSourceIndex(index, *hlo, 2)));
+ PrimitiveType prim_type = hlo->shape().element_type();
+ if (primitive_util::IsFloatingPointType(prim_type)) {
+ return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
+ } else if (primitive_util::IsIntegralType(prim_type)) {
+ bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
+ return EmitIntegralMin(
+ max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
+ } else {
+ return Unimplemented("Clamp unimplemented for %s",
+ PrimitiveType_Name(prim_type).c_str());
+ }
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& target_index) const {
+ const int64 concat_dim = hlo->dimensions(0);
+ auto source_index = target_index;
+
+ llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
+
+ // A terminator should be present iff we're emitting code
+ // into the middle (as opposed to the end) of a basic block.
+ CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
+ init_block->getTerminator() == nullptr);
+
+ llvm::BasicBlock* exit_block;
+ if (ir_builder_->GetInsertPoint() == init_block->end()) {
+ exit_block = llvm_ir::CreateBasicBlock(
+ /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
+ } else {
+ exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(),
+ AsStringRef(IrName(hlo, "merge")));
+ init_block->getTerminator()->eraseFromParent();
+ }
+
+ llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
+ llvm::PHINode* output = ir_builder_->CreatePHI(
+ llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+ hlo->operands().size());
+ auto prior_insert_point = ir_builder_->GetInsertPoint();
+
+ ir_builder_->SetInsertPoint(init_block);
+
+ for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
+ ++operand_idx) {
+ const HloInstruction* operand = hlo->operand(operand_idx);
+ auto true_block = llvm_ir::CreateBasicBlock(
+ exit_block, StrCat("concat_index_from_operand", operand_idx),
+ ir_builder_);
+ auto false_block = llvm_ir::CreateBasicBlock(
+ exit_block, StrCat("concat_index_not_from_operand", operand_idx),
+ ir_builder_);
+ auto concat_dim_size =
+ llvm::ConstantInt::get(source_index[concat_dim]->getType(),
+ operand->shape().dimensions(concat_dim));
+ ir_builder_->CreateCondBr(
+ ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size),
+ true_block, false_block);
+
+ // Create the terminator of the true block before calling operand
+ // generators, because they require non-degenerate basic blocks.
+ ir_builder_->SetInsertPoint(
+ llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
+ TF_ASSIGN_OR_RETURN(llvm::Value * value,
+ operand_to_generator.at(operand)(source_index));
+ output->addIncoming(value, ir_builder_->GetInsertBlock());
+
+ // Subtract the size of the concat dimension of the current operand
+ // from the source index.
+ ir_builder_->SetInsertPoint(false_block);
+ source_index[concat_dim] =
+ ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
+ }
+
+ ir_builder_->CreateUnreachable();
+ ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
+ return output;
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index) const {
+ // Emit IR to read dynamic start indices from hlo->operand(1).
+ const HloInstruction* input_hlo = hlo->operand(0);
+ const int64 rank = ShapeUtil::Rank(input_hlo->shape());
+ llvm_ir::IrArray::Index slice_start_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+ TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
+ operand_to_generator.at(hlo->operand(1))(dim_index));
+ start_index_value->setName(
+ AsStringRef(IrName(hlo, StrCat("start_idx", i))));
+ slice_start_index[i] = start_index_value;
+ }
+
+ llvm_ir::IrArray::Index input_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ // Emit IR which computes:
+ // input_index = (start_index + offset_index) % dim_size
+ // Security note: this is the code that keeps the indices in-bounds.
+ llvm::Value* dim_size = llvm::ConstantInt::get(
+ index[i]->getType(), input_hlo->shape().dimensions(i));
+ llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
+ slice_start_index[i], index[i]->getType());
+ input_index[i] = ir_builder_->CreateURem(
+ ir_builder_->CreateAdd(start_index, index[i]), dim_size);
+ }
+ return operand_to_generator.at(input_hlo)(input_index);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index) const {
+ const Shape& operand_shape = hlo->operand(0)->shape();
+ const Shape& indices_shape = hlo->operand(1)->shape();
+ const Shape& output_shape = hlo->shape();
+
+ const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();
+
+ const llvm_ir::ElementGenerator& operand_generator =
+ operand_to_generator.at(hlo->operand(0));
+ const llvm_ir::ElementGenerator& indices_generator =
+ operand_to_generator.at(hlo->operand(1));
+
+ // This is the index into `operand` that holds the element we want to
+ // generate. This index "unsafe" as in the components in here may be
+ // out of bounds.
+ IrArray::Index unsafe_operand_index;
+
+ // First copy in the window indices to unsafe_operand_index.
+ for (int64 i = 0, e = operand_shape.dimensions_size(),
+ unsafe_operand_index_dim = 0;
+ i < e; i++) {
+ if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ unsafe_operand_index.push_back(ir_builder_->getInt64(0));
+ } else {
+ unsafe_operand_index.push_back(
+ index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]);
+ }
+ }
+
+ // This is the index of the index vector in the gather_indices tensor.
+ IrArray::Index gather_index_index;
+ {
+ std::vector<llvm::Value*> gather_index_index_components;
+ for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
+ if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ gather_index_index.push_back(index[i]);
+ }
+ }
+
+ if (gather_index_index.size() != indices_shape.dimensions_size()) {
+ gather_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr);
+ }
+ }
+
+ auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
+ int64 dim) {
+ llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc(
+ index_component, ir_builder_->getInt64Ty());
+ unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
+ ir_builder_->CreateAdd(
+ unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)],
+ gather_dim_component_extended);
+ };
+
+ if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
+ TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
+ indices_generator(gather_index_index));
+ add_to_unsafe_operand_index(gather_dim_component, 0);
+ } else {
+ int64 index_vector_size =
+ indices_shape.dimensions(dim_numbers.index_vector_dim());
+ for (int64 i = 0; i < index_vector_size; i++) {
+ gather_index_index[dim_numbers.index_vector_dim()] =
+ ir_builder_->getInt64(i);
+ TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
+ indices_generator(gather_index_index));
+ add_to_unsafe_operand_index(gather_dim_component, i);
+ }
+ }
+
+ IrArray::Index safe_operand_index;
+ for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
+ safe_operand_index.push_back(ir_builder_->CreateURem(
+ unsafe_operand_index[i],
+ ir_builder_->getInt64(operand_shape.dimensions(i))));
+ }
+
+ return operand_generator(safe_operand_index);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& index) const {
+ const HloInstruction* input_hlo = hlo->operand(0);
+ const HloInstruction* update_hlo = hlo->operand(1);
+ const HloInstruction* start_hlo = hlo->operand(2);
+ // Calculate slice start/end indices.
+ const int64 rank = ShapeUtil::Rank(input_hlo->shape());
+ llvm_ir::IrArray::Index slice_start_index(rank);
+ llvm_ir::IrArray::Index slice_limit_index(rank);
+ // Slice starts at update[index - slice_start_index_adjusted],
+ // where adjusted value = slice_start_index when in bounds, and
+ // adjusted value = slice_start_index - input_dim, when wrapping.
+ llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
+
+ // Slice intersection gathers (ANDs) conditions on all ranks for which
+ // 'input' is set to 'update'
+ llvm::Value* slice_intersection = ir_builder_->getTrue();
+
+ for (int64 i = 0; i < rank; ++i) {
+ // Emit IR to read dynamic start indices from 'start_hlo'.
+ llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+ TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
+ operand_to_generator.at(start_hlo)(dim_index));
+ start_index_value->setName(
+ AsStringRef(IrName(hlo, StrCat("start_idx", i))));
+ slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
+ start_index_value, index[i]->getType());
+
+ llvm::Value* input_dim_size = llvm::ConstantInt::get(
+ index[i]->getType(), input_hlo->shape().dimensions(i));
+ llvm::Value* update_dim_size = llvm::ConstantInt::get(
+ index[i]->getType(), update_hlo->shape().dimensions(i));
+
+ // Generate code to handle wrapping semantics:
+ // slice_start_index[i] = slice_start_index[i] % input_dim_size;
+ // slice_limit_index[i] = slice_start_index[i] + update_dim_size.
+ // slice_start_index[i] is updated in place and it will now be in
+ // range. slice_limit_index[i] may be out of range, and it's being
+ // URem-ed below if so.
+ slice_start_index[i] =
+ ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
+ slice_limit_index[i] =
+ ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
+
+ // Test if slice_limit_index[i] is in bounds
+ llvm::Value* in_bounds =
+ ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
+ llvm_ir::LlvmIfData if_in_bounds =
+ llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
+
+ // Handle true BB (slice_limit_index[i] <= input_dim_size).
+ SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
+ // Check that index[i] >= slice_start_index[i] &&
+ // index[i] < slice_limit_index[i]
+ llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
+ slice_intersection,
+ ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
+ "slice_intersection_in");
+ slice_intersection_in_bounds = ir_builder_->CreateAnd(
+ slice_intersection_in_bounds,
+ ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
+ "slice_intersection_in");
+
+ // Handle false BB (slice_limit_index[i] > input_dim_size).
+ SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
+ // Check that index[i] >= slice_start_index[i] ||
+ // index[i] < slice_limit_index[i]%input_dim_size.
+ llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
+ index[i],
+ ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
+ llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
+ ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps,
+ "slice_intersection_out");
+ llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd(
+ slice_intersection, slice_intersection_or, "slice_intersection_out");
+ // Create value for slice_start_index_adjusted[i] when out of bounds.
+ // If within out-of-bounds if.
+ llvm_ir::LlvmIfData if_start_needs_adjustment =
+ llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
+ SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_);
+ llvm::Value* slice_start_index_adjusted_oob =
+ ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
+ SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_);
+ llvm::PHINode* slice_start_index_adjusted_phi =
+ ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2);
+ slice_start_index_adjusted_phi->addIncoming(
+ slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block);
+ slice_start_index_adjusted_phi->addIncoming(
+ slice_start_index[i], if_start_needs_adjustment.false_block);
+ // End of if within if.
+
+ // After checking in/out of bounds.
+ SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
+ llvm::PHINode* phi_slice_intersection =
+ ir_builder_->CreatePHI(slice_intersection->getType(), 2);
+ phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
+ if_in_bounds.true_block);
+ phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds,
+ if_start_needs_adjustment.after_block);
+ slice_intersection = phi_slice_intersection;
+
+ llvm::PHINode* phi_index =
+ ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
+ phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
+ phi_index->addIncoming(slice_start_index_adjusted_phi,
+ if_start_needs_adjustment.after_block);
+ slice_start_index_adjusted[i] = phi_index;
+ }
+
+ // Emit:
+ // if (slice_intersection) -> return data from 'update'.
+ // else -> return data from 'input'.
+ llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+ "ret_value_addr", ir_builder_);
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ slice_intersection, "slice_intersection", ir_builder_);
+
+ // Handle true BB (return data from 'update')
+ SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ // Compute update index for intersection case.
+ llvm_ir::IrArray::Index update_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* update_dim_size = llvm::ConstantInt::get(
+ index[i]->getType(), update_hlo->shape().dimensions(i));
+ // NOTE: Subtraction will be positive due to bounds checking above.
+ update_index[i] = ir_builder_->CreateURem(
+ ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
+ update_dim_size);
+ }
+ TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
+ operand_to_generator.at(update_hlo)(update_index));
+ ir_builder_->CreateStore(true_value, ret_value_addr);
+
+ // Handle false BB (return data from 'input')
+ SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+ TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
+ operand_to_generator.at(input_hlo)(index));
+ ir_builder_->CreateStore(false_value, ret_value_addr);
+
+ SetToFirstInsertPoint(if_data.after_block, ir_builder_);
+ return ir_builder_->CreateLoad(ret_value_addr);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& padded_index) const {
+ auto index = padded_index;
+ llvm::Value* in_bounds = ir_builder_->getTrue();
+ for (size_t i = 0; i < index.size(); ++i) {
+ auto index_typed_const = [=](int64 n) {
+ return llvm::ConstantInt::get(index[i]->getType(), n);
+ };
+ const auto& pad_dim = hlo->padding_config().dimensions(i);
+ index[i] = ir_builder_->CreateSub(
+ index[i], index_typed_const(pad_dim.edge_padding_low()));
+ in_bounds = ir_builder_->CreateAnd(
+ in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
+ "in_bounds");
+ in_bounds = ir_builder_->CreateAnd(
+ in_bounds,
+ ir_builder_->CreateICmpEQ(
+ index_typed_const(0),
+ ir_builder_->CreateURem(
+ index[i], index_typed_const(pad_dim.interior_padding() + 1))),
+ "in_bounds");
+ index[i] = ir_builder_->CreateSDiv(
+ index[i], index_typed_const(pad_dim.interior_padding() + 1));
+ in_bounds = ir_builder_->CreateAnd(
+ in_bounds,
+ ir_builder_->CreateICmpSLT(
+ index[i],
+ index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+ "in_bounds");
+ }
+
+ // if (in_bounds) {
+ // ret_value = operand0[index]; // source
+ // } else {
+ // ret_value = *operand1; // padding
+ // }
+ llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+ "pad_result_addr", ir_builder_);
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+ operand_to_generator.at(hlo->operand(0))(index));
+ ir_builder_->CreateStore(operand_value, ret_value_addr);
+
+ SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+ TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
+ operand_to_generator.at(hlo->operand(1))({}));
+ ir_builder_->CreateStore(padding_value, ret_value_addr);
+
+ SetToFirstInsertPoint(if_data.after_block, ir_builder_);
+ // Don't create phi(operand_value, padding_value) here, because invoking
+ // operand_to_generator may create new basic blocks, making the parent
+ // of operand_value or padding_value no longer a predecessor of
+ // if_data.after_block.
+ return ir_builder_->CreateLoad(ret_value_addr);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
+ const llvm_ir::IrArray::Index& dot_result_index) const {
+ auto lhs_generator = operand_to_generator.at(hlo->operand(0));
+ auto rhs_generator = operand_to_generator.at(hlo->operand(1));
+ int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
+ hlo->operand(0)->shape().dimensions_size() - 1);
+ int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
+ int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
+
+ std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
+ IrName(hlo, "inner"), ir_builder_->getInt64(0),
+ ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1),
+ ir_builder_);
+
+ SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_);
+ PrimitiveType primitive_type = hlo->shape().element_type();
+ llvm::Type* primitive_type_llvm =
+ llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
+ llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
+ primitive_type_llvm, "dot_acc", ir_builder_);
+ ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm),
+ accumulator_alloca);
+
+ SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);
+
+ // This is the inner reduction loop for a dot operation that produces
+ // one element in the output. If the operands to the dot operation have
+ // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
+ // Given an output index [a,b,c,d,e] in the result, we compute:
+ // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
+
+ IrArray::Index lhs_index, rhs_index;
+
+ for (int64 i = 0; i < lhs_dims - 1; i++) {
+ lhs_index.push_back(dot_result_index[i]);
+ }
+ lhs_index.push_back(inner_loop->GetIndVarValue());
+
+ for (int64 i = 0; i < rhs_dims - 2; i++) {
+ rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
+ }
+ rhs_index.push_back(inner_loop->GetIndVarValue());
+ rhs_index.push_back(dot_result_index.back());
+
+ llvm::Value* current_accumulator =
+ ir_builder_->CreateLoad(accumulator_alloca);
+ TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
+ TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
+ llvm::Value* next_accumulator;
+ if (primitive_util::IsComplexType(primitive_type)) {
+ llvm::Value* product_real = ir_builder_->CreateFSub(
+ ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value)));
+ llvm::Value* product_imag = ir_builder_->CreateFAdd(
+ ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value)),
+ ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value)));
+ next_accumulator = ir_builder_->CreateInsertValue(
+ current_accumulator,
+ ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
+ product_real),
+ {0});
+ next_accumulator = ir_builder_->CreateInsertValue(
+ next_accumulator,
+ ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
+ product_imag),
+ {1});
+ } else if (primitive_util::IsFloatingPointType(primitive_type)) {
+ next_accumulator = ir_builder_->CreateFAdd(
+ current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value));
+ } else {
+ next_accumulator = ir_builder_->CreateAdd(
+ current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value));
+ }
+ ir_builder_->CreateStore(next_accumulator, accumulator_alloca);
+
+ SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
+ return ir_builder_->CreateLoad(accumulator_alloca);
+}
+
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
case HloOpcode::kSelect:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
- operand_to_generator.at(hlo->operand(0))(
- ElementwiseSourceIndex(index, *hlo, 0)));
- TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
- operand_to_generator.at(hlo->operand(1))(
- ElementwiseSourceIndex(index, *hlo, 1)));
- TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
- operand_to_generator.at(hlo->operand(2))(
- ElementwiseSourceIndex(index, *hlo, 2)));
- return ir_builder_->CreateSelect(
- ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
- on_true_value, on_false_value);
+ return EmitElementalSelect(hlo, operand_to_generator, index);
};
case HloOpcode::kClamp:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
- operand_to_generator.at(hlo->operand(0))(
- ElementwiseSourceIndex(index, *hlo, 0)));
- TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
- operand_to_generator.at(hlo->operand(1))(
- ElementwiseSourceIndex(index, *hlo, 1)));
- TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
- operand_to_generator.at(hlo->operand(2))(
- ElementwiseSourceIndex(index, *hlo, 2)));
- PrimitiveType prim_type = hlo->shape().element_type();
- if (primitive_util::IsFloatingPointType(prim_type)) {
- return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
- } else if (primitive_util::IsIntegralType(prim_type)) {
- bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
- return EmitIntegralMin(
- max_value, EmitIntegralMax(min_value, arg_value, is_signed),
- is_signed);
- } else {
- return Unimplemented("Clamp unimplemented for %s",
- PrimitiveType_Name(prim_type).c_str());
- }
+ return EmitElementalClamp(hlo, operand_to_generator, index);
};
case HloOpcode::kReducePrecision:
return [this, hlo, &operand_to_generator](
case HloOpcode::kConcatenate:
return [this, hlo, &operand_to_generator](
const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
- const int64 concat_dim = hlo->dimensions(0);
- auto source_index = target_index;
-
- llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
-
- // A terminator should be present iff we're emitting code
- // into the middle (as opposed to the end) of a basic block.
- CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
- init_block->getTerminator() == nullptr);
-
- llvm::BasicBlock* exit_block;
- if (ir_builder_->GetInsertPoint() == init_block->end()) {
- exit_block = llvm_ir::CreateBasicBlock(
- /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
- } else {
- exit_block = init_block->splitBasicBlock(
- ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge")));
- init_block->getTerminator()->eraseFromParent();
- }
-
- llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
- llvm::PHINode* output =
- ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
- hlo->shape().element_type(), module_),
- hlo->operands().size());
- auto prior_insert_point = ir_builder_->GetInsertPoint();
-
- ir_builder_->SetInsertPoint(init_block);
-
- for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
- ++operand_idx) {
- const HloInstruction* operand = hlo->operand(operand_idx);
- auto true_block = llvm_ir::CreateBasicBlock(
- exit_block, StrCat("concat_index_from_operand", operand_idx),
- ir_builder_);
- auto false_block = llvm_ir::CreateBasicBlock(
- exit_block, StrCat("concat_index_not_from_operand", operand_idx),
- ir_builder_);
- auto concat_dim_size =
- llvm::ConstantInt::get(source_index[concat_dim]->getType(),
- operand->shape().dimensions(concat_dim));
- ir_builder_->CreateCondBr(
- ir_builder_->CreateICmpULT(source_index[concat_dim],
- concat_dim_size),
- true_block, false_block);
-
- // Create the terminator of the true block before calling operand
- // generators, because they require non-degenerate basic blocks.
- ir_builder_->SetInsertPoint(
- llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
- TF_ASSIGN_OR_RETURN(llvm::Value * value,
- operand_to_generator.at(operand)(source_index));
- output->addIncoming(value, ir_builder_->GetInsertBlock());
-
- // Subtract the size of the concat dimension of the current operand
- // from the source index.
- ir_builder_->SetInsertPoint(false_block);
- source_index[concat_dim] =
- ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
- }
-
- ir_builder_->CreateUnreachable();
- ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
- return output;
+ return EmitElementalConcatenate(hlo, operand_to_generator,
+ target_index);
};
case HloOpcode::kReverse:
return [this, hlo, &operand_to_generator](
case HloOpcode::kDynamicSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- // Emit IR to read dynamic start indices from hlo->operand(1).
- const HloInstruction* input_hlo = hlo->operand(0);
- const int64 rank = ShapeUtil::Rank(input_hlo->shape());
- llvm_ir::IrArray::Index slice_start_index(rank);
- for (int64 i = 0; i < rank; ++i) {
- llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value * start_index_value,
- operand_to_generator.at(hlo->operand(1))(dim_index));
- start_index_value->setName(
- AsStringRef(IrName(hlo, StrCat("start_idx", i))));
- slice_start_index[i] = start_index_value;
- }
-
- llvm_ir::IrArray::Index input_index(rank);
- for (int64 i = 0; i < rank; ++i) {
- // Emit IR which computes:
- // input_index = (start_index + offset_index) % dim_size
- // Security note: this is the code that keeps the indices in-bounds.
- llvm::Value* dim_size = llvm::ConstantInt::get(
- index[i]->getType(), input_hlo->shape().dimensions(i));
- llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
- slice_start_index[i], index[i]->getType());
- input_index[i] = ir_builder_->CreateURem(
- ir_builder_->CreateAdd(start_index, index[i]), dim_size);
- }
- return operand_to_generator.at(input_hlo)(input_index);
+ return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
};
case HloOpcode::kGather:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- const Shape& operand_shape = hlo->operand(0)->shape();
- const Shape& indices_shape = hlo->operand(1)->shape();
- const Shape& output_shape = hlo->shape();
-
- const GatherDimensionNumbers& dim_numbers =
- hlo->gather_dimension_numbers();
-
- const llvm_ir::ElementGenerator& operand_generator =
- operand_to_generator.at(hlo->operand(0));
- const llvm_ir::ElementGenerator& indices_generator =
- operand_to_generator.at(hlo->operand(1));
-
- // This is the index into `operand` that holds the element we want to
- // generate. This index "unsafe" as in the components in here may be
- // out of bounds.
- IrArray::Index unsafe_operand_index;
-
- // First copy in the window indices to unsafe_operand_index.
- for (int64 i = 0, e = operand_shape.dimensions_size(),
- unsafe_operand_index_dim = 0;
- i < e; i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
- unsafe_operand_index.push_back(ir_builder_->getInt64(0));
- } else {
- unsafe_operand_index.push_back(index[dim_numbers.output_window_dims(
- unsafe_operand_index_dim++)]);
- }
- }
-
- // This is the index of the index vector in the gather_indices tensor.
- IrArray::Index gather_index_index;
- {
- std::vector<llvm::Value*> gather_index_index_components;
- for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
- gather_index_index.push_back(index[i]);
- }
- }
-
- if (gather_index_index.size() != indices_shape.dimensions_size()) {
- gather_index_index.InsertAt(dim_numbers.index_vector_dim(),
- nullptr);
- }
- }
-
- auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
- int64 dim) {
- llvm::Value* gather_dim_component_extended =
- ir_builder_->CreateSExtOrTrunc(index_component,
- ir_builder_->getInt64Ty());
- unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
- ir_builder_->CreateAdd(
- unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(
- dim)],
- gather_dim_component_extended);
- };
-
- if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
- TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
- indices_generator(gather_index_index));
- add_to_unsafe_operand_index(gather_dim_component, 0);
- } else {
- int64 index_vector_size =
- indices_shape.dimensions(dim_numbers.index_vector_dim());
- for (int64 i = 0; i < index_vector_size; i++) {
- gather_index_index[dim_numbers.index_vector_dim()] =
- ir_builder_->getInt64(i);
- TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
- indices_generator(gather_index_index));
- add_to_unsafe_operand_index(gather_dim_component, i);
- }
- }
-
- IrArray::Index safe_operand_index;
- for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
- safe_operand_index.push_back(ir_builder_->CreateURem(
- unsafe_operand_index[i],
- ir_builder_->getInt64(operand_shape.dimensions(i))));
- }
-
- return operand_generator(safe_operand_index);
+ return EmitElementalGather(hlo, operand_to_generator, index);
};
case HloOpcode::kDynamicUpdateSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- const HloInstruction* input_hlo = hlo->operand(0);
- const HloInstruction* update_hlo = hlo->operand(1);
- const HloInstruction* start_hlo = hlo->operand(2);
- // Calculate slice start/end indices.
- const int64 rank = ShapeUtil::Rank(input_hlo->shape());
- llvm_ir::IrArray::Index slice_start_index(rank);
- llvm_ir::IrArray::Index slice_limit_index(rank);
- // Slice starts at update[index - slice_start_index_adjusted],
- // where adjusted value = slice_start_index when in bounds, and
- // adjusted value = slice_start_index - input_dim, when wrapping.
- llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
-
- // Slice intersection gathers (ANDs) conditions on all ranks for which
- // 'input' is set to 'update'
- llvm::Value* slice_intersection = ir_builder_->getTrue();
-
- for (int64 i = 0; i < rank; ++i) {
- // Emit IR to read dynamic start indices from 'start_hlo'.
- llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
- TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
- operand_to_generator.at(start_hlo)(dim_index));
- start_index_value->setName(
- AsStringRef(IrName(hlo, StrCat("start_idx", i))));
- slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
- start_index_value, index[i]->getType());
-
- llvm::Value* input_dim_size = llvm::ConstantInt::get(
- index[i]->getType(), input_hlo->shape().dimensions(i));
- llvm::Value* update_dim_size = llvm::ConstantInt::get(
- index[i]->getType(), update_hlo->shape().dimensions(i));
-
- // Generate code to handle wrapping semantics:
- // slice_start_index[i] = slice_start_index[i] % input_dim_size;
- // slice_limit_index[i] = slice_start_index[i] + update_dim_size.
- // slice_start_index[i] is updated in place and it will now be in
- // range. slice_limit_index[i] may be out of range, and it's being
- // URem-ed below if so.
- slice_start_index[i] =
- ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
- slice_limit_index[i] =
- ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
-
- // Test if slice_limit_index[i] is in bounds
- llvm::Value* in_bounds =
- ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
- llvm_ir::LlvmIfData if_in_bounds =
- llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
-
- // Handle true BB (slice_limit_index[i] <= input_dim_size).
- SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
- // Check that index[i] >= slice_start_index[i] &&
- // index[i] < slice_limit_index[i]
- llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
- slice_intersection,
- ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
- "slice_intersection_in");
- slice_intersection_in_bounds = ir_builder_->CreateAnd(
- slice_intersection_in_bounds,
- ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
- "slice_intersection_in");
-
- // Handle false BB (slice_limit_index[i] > input_dim_size).
- SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
- // Check that index[i] >= slice_start_index[i] ||
- // index[i] < slice_limit_index[i]%input_dim_size.
- llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
- index[i],
- ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
- llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
- ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
- index_wraps, "slice_intersection_out");
- llvm::Value* slice_intersection_out_of_bounds =
- ir_builder_->CreateAnd(slice_intersection, slice_intersection_or,
- "slice_intersection_out");
- // Create value for slice_start_index_adjusted[i] when out of bounds.
- // If within out-of-bounds if.
- llvm_ir::LlvmIfData if_start_needs_adjustment =
- llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
- SetToFirstInsertPoint(if_start_needs_adjustment.true_block,
- ir_builder_);
- llvm::Value* slice_start_index_adjusted_oob =
- ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
- SetToFirstInsertPoint(if_start_needs_adjustment.after_block,
- ir_builder_);
- llvm::PHINode* slice_start_index_adjusted_phi =
- ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(),
- 2);
- slice_start_index_adjusted_phi->addIncoming(
- slice_start_index_adjusted_oob,
- if_start_needs_adjustment.true_block);
- slice_start_index_adjusted_phi->addIncoming(
- slice_start_index[i], if_start_needs_adjustment.false_block);
- // End of if within if.
-
- // After checking in/out of bounds.
- SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
- llvm::PHINode* phi_slice_intersection =
- ir_builder_->CreatePHI(slice_intersection->getType(), 2);
- phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
- if_in_bounds.true_block);
- phi_slice_intersection->addIncoming(
- slice_intersection_out_of_bounds,
- if_start_needs_adjustment.after_block);
- slice_intersection = phi_slice_intersection;
-
- llvm::PHINode* phi_index =
- ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
- phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
- phi_index->addIncoming(slice_start_index_adjusted_phi,
- if_start_needs_adjustment.after_block);
- slice_start_index_adjusted[i] = phi_index;
- }
-
- // Emit:
- // if (slice_intersection) -> return data from 'update'.
- // else -> return data from 'input'.
- llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
- module_),
- "ret_value_addr", ir_builder_);
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- slice_intersection, "slice_intersection", ir_builder_);
-
- // Handle true BB (return data from 'update')
- SetToFirstInsertPoint(if_data.true_block, ir_builder_);
- // Compute update index for intersection case.
- llvm_ir::IrArray::Index update_index(rank);
- for (int64 i = 0; i < rank; ++i) {
- llvm::Value* update_dim_size = llvm::ConstantInt::get(
- index[i]->getType(), update_hlo->shape().dimensions(i));
- // NOTE: Subtraction will be positive due to bounds checking above.
- update_index[i] = ir_builder_->CreateURem(
- ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
- update_dim_size);
- }
- TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
- operand_to_generator.at(update_hlo)(update_index));
- ir_builder_->CreateStore(true_value, ret_value_addr);
-
- // Handle false BB (return data from 'input')
- SetToFirstInsertPoint(if_data.false_block, ir_builder_);
- TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
- operand_to_generator.at(input_hlo)(index));
- ir_builder_->CreateStore(false_value, ret_value_addr);
-
- SetToFirstInsertPoint(if_data.after_block, ir_builder_);
- return ir_builder_->CreateLoad(ret_value_addr);
+ return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
+ index);
};
case HloOpcode::kBitcast:
CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
case HloOpcode::kRng:
return MakeRngElementGenerator(hlo, operand_to_generator);
case HloOpcode::kPad:
- return [=, &operand_to_generator](
+ return [this, hlo, &operand_to_generator](
const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
- auto index = padded_index;
- llvm::Value* in_bounds = ir_builder_->getTrue();
- for (size_t i = 0; i < index.size(); ++i) {
- auto index_typed_const = [=](int64 n) {
- return llvm::ConstantInt::get(index[i]->getType(), n);
- };
- const auto& pad_dim = hlo->padding_config().dimensions(i);
- index[i] = ir_builder_->CreateSub(
- index[i], index_typed_const(pad_dim.edge_padding_low()));
- in_bounds = ir_builder_->CreateAnd(
- in_bounds,
- ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
- "in_bounds");
- in_bounds = ir_builder_->CreateAnd(
- in_bounds,
- ir_builder_->CreateICmpEQ(
- index_typed_const(0),
- ir_builder_->CreateURem(
- index[i],
- index_typed_const(pad_dim.interior_padding() + 1))),
- "in_bounds");
- index[i] = ir_builder_->CreateSDiv(
- index[i], index_typed_const(pad_dim.interior_padding() + 1));
- in_bounds = ir_builder_->CreateAnd(
- in_bounds,
- ir_builder_->CreateICmpSLT(
- index[i],
- index_typed_const(hlo->operand(0)->shape().dimensions(i))),
- "in_bounds");
- }
-
- // if (in_bounds) {
- // ret_value = operand0[index]; // source
- // } else {
- // ret_value = *operand1; // padding
- // }
- llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
- module_),
- "pad_result_addr", ir_builder_);
- llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, ir_builder_);
- TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
- operand_to_generator.at(hlo->operand(0))(index));
- ir_builder_->CreateStore(operand_value, ret_value_addr);
-
- SetToFirstInsertPoint(if_data.false_block, ir_builder_);
- TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
- operand_to_generator.at(hlo->operand(1))({}));
- ir_builder_->CreateStore(padding_value, ret_value_addr);
-
- SetToFirstInsertPoint(if_data.after_block, ir_builder_);
- // Don't create phi(operand_value, padding_value) here, because invoking
- // operand_to_generator may create new basic blocks, making the parent
- // of operand_value or padding_value no longer a predecessor of
- // if_data.after_block.
- return ir_builder_->CreateLoad(ret_value_addr);
+ return EmitElementalPad(hlo, operand_to_generator, padded_index);
};
case HloOpcode::kDot:
- return [=, &operand_to_generator](const IrArray::Index& dot_result_index)
+ return [this, hlo,
+ &operand_to_generator](const IrArray::Index& dot_result_index)
-> StatusOr<llvm::Value*> {
- auto lhs_generator = operand_to_generator.at(hlo->operand(0));
- auto rhs_generator = operand_to_generator.at(hlo->operand(1));
- int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
- hlo->operand(0)->shape().dimensions_size() - 1);
- int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
- int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
-
- std::unique_ptr<llvm_ir::ForLoop> inner_loop =
- llvm_ir::ForLoop::EmitForLoop(
- IrName(hlo, "inner"), ir_builder_->getInt64(0),
- ir_builder_->getInt64(contracted_dim_size),
- ir_builder_->getInt64(1), ir_builder_);
-
- SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(),
- ir_builder_);
- PrimitiveType primitive_type = hlo->shape().element_type();
- llvm::Type* primitive_type_llvm =
- llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
- llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
- primitive_type_llvm, "dot_acc", ir_builder_);
- ir_builder_->CreateStore(
- llvm::Constant::getNullValue(primitive_type_llvm),
- accumulator_alloca);
-
- SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);
-
- // This is the inner reduction loop for a dot operation that produces
- // one element in the output. If the operands to the dot operation have
- // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
- // Given an output index [a,b,c,d,e] in the result, we compute:
- // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
-
- IrArray::Index lhs_index, rhs_index;
-
- for (int64 i = 0; i < lhs_dims - 1; i++) {
- lhs_index.push_back(dot_result_index[i]);
- }
- lhs_index.push_back(inner_loop->GetIndVarValue());
-
- for (int64 i = 0; i < rhs_dims - 2; i++) {
- rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
- }
- rhs_index.push_back(inner_loop->GetIndVarValue());
- rhs_index.push_back(dot_result_index.back());
-
- llvm::Value* current_accumulator =
- ir_builder_->CreateLoad(accumulator_alloca);
- TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
- TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
- llvm::Value* next_accumulator;
- if (primitive_util::IsComplexType(primitive_type)) {
- llvm::Value* product_real = ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
- llvm::Value* product_imag = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value)));
- next_accumulator = ir_builder_->CreateInsertValue(
- current_accumulator,
- ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
- product_real),
- {0});
- next_accumulator = ir_builder_->CreateInsertValue(
- next_accumulator,
- ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
- product_imag),
- {1});
- } else if (primitive_util::IsFloatingPointType(primitive_type)) {
- next_accumulator = ir_builder_->CreateFAdd(
- current_accumulator,
- ir_builder_->CreateFMul(lhs_value, rhs_value));
- } else {
- next_accumulator = ir_builder_->CreateAdd(
- current_accumulator,
- ir_builder_->CreateMul(lhs_value, rhs_value));
- }
- ir_builder_->CreateStore(next_accumulator, accumulator_alloca);
-
- SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
- return ir_builder_->CreateLoad(accumulator_alloca);
+ return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
};
default:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {