From 68efa500c0f8ec9c42072b25a5d1b5bf4f0afb21 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Fri, 27 Apr 2018 18:41:27 -0700 Subject: [PATCH] Split up ElementaIrEmitter::MakeElementGenerator into smaller functions; NFC PiperOrigin-RevId: 194622198 --- .../compiler/xla/service/elemental_ir_emitter.cc | 1028 ++++++++++---------- .../compiler/xla/service/elemental_ir_emitter.h | 40 + 2 files changed, 572 insertions(+), 496 deletions(-) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4b01c87..ae32d33 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1344,6 +1344,525 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( }; } +StatusOr ElementalIrEmitter::EmitElementalSelect( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + return ir_builder_->CreateSelect( + ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), + on_true_value, on_false_value); +} + +StatusOr ElementalIrEmitter::EmitElementalClamp( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + TF_ASSIGN_OR_RETURN(llvm::Value * min_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * max_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + PrimitiveType prim_type = hlo->shape().element_type(); + if (primitive_util::IsFloatingPointType(prim_type)) { + return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + } else if (primitive_util::IsIntegralType(prim_type)) { + bool is_signed = primitive_util::IsSignedIntegralType(prim_type); + return EmitIntegralMin( + max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); + } else { + return Unimplemented("Clamp unimplemented for %s", + PrimitiveType_Name(prim_type).c_str()); + } +} + +StatusOr ElementalIrEmitter::EmitElementalConcatenate( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& target_index) const { + const int64 concat_dim = hlo->dimensions(0); + auto source_index = target_index; + + llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. + CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + init_block->getTerminator() == nullptr); + + llvm::BasicBlock* exit_block; + if (ir_builder_->GetInsertPoint() == init_block->end()) { + exit_block = llvm_ir::CreateBasicBlock( + /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + } else { + exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(), + AsStringRef(IrName(hlo, "merge"))); + init_block->getTerminator()->eraseFromParent(); + } + + llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); + llvm::PHINode* output = ir_builder_->CreatePHI( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); + auto prior_insert_point = ir_builder_->GetInsertPoint(); + + ir_builder_->SetInsertPoint(init_block); + + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + ++operand_idx) { + const HloInstruction* operand = hlo->operand(operand_idx); + auto true_block = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_from_operand", operand_idx), + ir_builder_); + auto false_block = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_not_from_operand", operand_idx), + ir_builder_); + auto concat_dim_size = + llvm::ConstantInt::get(source_index[concat_dim]->getType(), + operand->shape().dimensions(concat_dim)); + ir_builder_->CreateCondBr( + ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size), + true_block, false_block); + + // Create the terminator of the true block before calling operand + // generators, because they require non-degenerate basic blocks. + ir_builder_->SetInsertPoint( + llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(source_index)); + output->addIncoming(value, ir_builder_->GetInsertBlock()); + + // Subtract the size of the concat dimension of the current operand + // from the source index. + ir_builder_->SetInsertPoint(false_block); + source_index[concat_dim] = + ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); + } + + ir_builder_->CreateUnreachable(); + ir_builder_->SetInsertPoint(exit_block, prior_insert_point); + return output; +} + +StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + // Emit IR to read dynamic start indices from hlo->operand(1). + const HloInstruction* input_hlo = hlo->operand(0); + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1))(dim_index)); + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; + } + + llvm_ir::IrArray::Index input_index(rank); + for (int64 i = 0; i < rank; ++i) { + // Emit IR which computes: + // input_index = (start_index + offset_index) % dim_size + // Security note: this is the code that keeps the indices in-bounds. + llvm::Value* dim_size = llvm::ConstantInt::get( + index[i]->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( + slice_start_index[i], index[i]->getType()); + input_index[i] = ir_builder_->CreateURem( + ir_builder_->CreateAdd(start_index, index[i]), dim_size); + } + return operand_to_generator.at(input_hlo)(input_index); +} + +StatusOr ElementalIrEmitter::EmitElementalGather( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + const Shape& operand_shape = hlo->operand(0)->shape(); + const Shape& indices_shape = hlo->operand(1)->shape(); + const Shape& output_shape = hlo->shape(); + + const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers(); + + const llvm_ir::ElementGenerator& operand_generator = + operand_to_generator.at(hlo->operand(0)); + const llvm_ir::ElementGenerator& indices_generator = + operand_to_generator.at(hlo->operand(1)); + + // This is the index into `operand` that holds the element we want to + // generate. This index "unsafe" as in the components in here may be + // out of bounds. + IrArray::Index unsafe_operand_index; + + // First copy in the window indices to unsafe_operand_index. + for (int64 i = 0, e = operand_shape.dimensions_size(), + unsafe_operand_index_dim = 0; + i < e; i++) { + if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + } else { + unsafe_operand_index.push_back( + index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); + } + } + + // This is the index of the index vector in the gather_indices tensor. + IrArray::Index gather_index_index; + { + std::vector gather_index_index_components; + for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + gather_index_index.push_back(index[i]); + } + } + + if (gather_index_index.size() != indices_shape.dimensions_size()) { + gather_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + } + } + + auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, + int64 dim) { + llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc( + index_component, ir_builder_->getInt64Ty()); + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = + ir_builder_->CreateAdd( + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], + gather_dim_component_extended); + }; + + if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, 0); + } else { + int64 index_vector_size = + indices_shape.dimensions(dim_numbers.index_vector_dim()); + for (int64 i = 0; i < index_vector_size; i++) { + gather_index_index[dim_numbers.index_vector_dim()] = + ir_builder_->getInt64(i); + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, i); + } + } + + IrArray::Index safe_operand_index; + for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { + safe_operand_index.push_back(ir_builder_->CreateURem( + unsafe_operand_index[i], + ir_builder_->getInt64(operand_shape.dimensions(i)))); + } + + return operand_generator(safe_operand_index); +} + +StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + const HloInstruction* input_hlo = hlo->operand(0); + const HloInstruction* update_hlo = hlo->operand(1); + const HloInstruction* start_hlo = hlo->operand(2); + // Calculate slice start/end indices. + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + llvm_ir::IrArray::Index slice_limit_index(rank); + // Slice starts at update[index - slice_start_index_adjusted], + // where adjusted value = slice_start_index when in bounds, and + // adjusted value = slice_start_index - input_dim, when wrapping. + llvm_ir::IrArray::Index slice_start_index_adjusted(rank); + + // Slice intersection gathers (ANDs) conditions on all ranks for which + // 'input' is set to 'update' + llvm::Value* slice_intersection = ir_builder_->getTrue(); + + for (int64 i = 0; i < rank; ++i) { + // Emit IR to read dynamic start indices from 'start_hlo'. + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(start_hlo)(dim_index)); + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( + start_index_value, index[i]->getType()); + + llvm::Value* input_dim_size = llvm::ConstantInt::get( + index[i]->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + index[i]->getType(), update_hlo->shape().dimensions(i)); + + // Generate code to handle wrapping semantics: + // slice_start_index[i] = slice_start_index[i] % input_dim_size; + // slice_limit_index[i] = slice_start_index[i] + update_dim_size. + // slice_start_index[i] is updated in place and it will now be in + // range. slice_limit_index[i] may be out of range, and it's being + // URem-ed below if so. + slice_start_index[i] = + ir_builder_->CreateURem(slice_start_index[i], input_dim_size); + slice_limit_index[i] = + ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); + + // Test if slice_limit_index[i] is in bounds + llvm::Value* in_bounds = + ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); + llvm_ir::LlvmIfData if_in_bounds = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + + // Handle true BB (slice_limit_index[i] <= input_dim_size). + SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); + // Check that index[i] >= slice_start_index[i] && + // index[i] < slice_limit_index[i] + llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + "slice_intersection_in"); + slice_intersection_in_bounds = ir_builder_->CreateAnd( + slice_intersection_in_bounds, + ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection_in"); + + // Handle false BB (slice_limit_index[i] > input_dim_size). + SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); + // Check that index[i] >= slice_start_index[i] || + // index[i] < slice_limit_index[i]%input_dim_size. + llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( + index[i], + ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); + llvm::Value* slice_intersection_or = ir_builder_->CreateOr( + ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps, + "slice_intersection_out"); + llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd( + slice_intersection, slice_intersection_or, "slice_intersection_out"); + // Create value for slice_start_index_adjusted[i] when out of bounds. + // If within out-of-bounds if. + llvm_ir::LlvmIfData if_start_needs_adjustment = + llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); + SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_); + llvm::Value* slice_start_index_adjusted_oob = + ir_builder_->CreateSub(slice_start_index[i], input_dim_size); + SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_); + llvm::PHINode* slice_start_index_adjusted_phi = + ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2); + slice_start_index_adjusted_phi->addIncoming( + slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block); + slice_start_index_adjusted_phi->addIncoming( + slice_start_index[i], if_start_needs_adjustment.false_block); + // End of if within if. + + // After checking in/out of bounds. + SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); + llvm::PHINode* phi_slice_intersection = + ir_builder_->CreatePHI(slice_intersection->getType(), 2); + phi_slice_intersection->addIncoming(slice_intersection_in_bounds, + if_in_bounds.true_block); + phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds, + if_start_needs_adjustment.after_block); + slice_intersection = phi_slice_intersection; + + llvm::PHINode* phi_index = + ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); + phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); + phi_index->addIncoming(slice_start_index_adjusted_phi, + if_start_needs_adjustment.after_block); + slice_start_index_adjusted[i] = phi_index; + } + + // Emit: + // if (slice_intersection) -> return data from 'update'. + // else -> return data from 'input'. + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + "ret_value_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + slice_intersection, "slice_intersection", ir_builder_); + + // Handle true BB (return data from 'update') + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + // Compute update index for intersection case. + llvm_ir::IrArray::Index update_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm::Value* update_dim_size = llvm::ConstantInt::get( + index[i]->getType(), update_hlo->shape().dimensions(i)); + // NOTE: Subtraction will be positive due to bounds checking above. + update_index[i] = ir_builder_->CreateURem( + ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), + update_dim_size); + } + TF_ASSIGN_OR_RETURN(llvm::Value * true_value, + operand_to_generator.at(update_hlo)(update_index)); + ir_builder_->CreateStore(true_value, ret_value_addr); + + // Handle false BB (return data from 'input') + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * false_value, + operand_to_generator.at(input_hlo)(index)); + ir_builder_->CreateStore(false_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + return ir_builder_->CreateLoad(ret_value_addr); +} + +StatusOr ElementalIrEmitter::EmitElementalPad( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& padded_index) const { + auto index = padded_index; + llvm::Value* in_bounds = ir_builder_->getTrue(); + for (size_t i = 0; i < index.size(); ++i) { + auto index_typed_const = [=](int64 n) { + return llvm::ConstantInt::get(index[i]->getType(), n); + }; + const auto& pad_dim = hlo->padding_config().dimensions(i); + index[i] = ir_builder_->CreateSub( + index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = ir_builder_->CreateAnd( + in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpEQ( + index_typed_const(0), + ir_builder_->CreateURem( + index[i], index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + index[i] = ir_builder_->CreateSDiv( + index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSLT( + index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); + } + + // if (in_bounds) { + // ret_value = operand0[index]; // source + // } else { + // ret_value = *operand1; // padding + // } + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + "pad_result_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); + ir_builder_->CreateStore(operand_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder_->CreateStore(padding_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + // Don't create phi(operand_value, padding_value) here, because invoking + // operand_to_generator may create new basic blocks, making the parent + // of operand_value or padding_value no longer a predecessor of + // if_data.after_block. + return ir_builder_->CreateLoad(ret_value_addr); +} + +StatusOr ElementalIrEmitter::EmitElementalDot( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& dot_result_index) const { + auto lhs_generator = operand_to_generator.at(hlo->operand(0)); + auto rhs_generator = operand_to_generator.at(hlo->operand(1)); + int64 contracted_dim_size = hlo->operand(0)->shape().dimensions( + hlo->operand(0)->shape().dimensions_size() - 1); + int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); + int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); + + std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( + IrName(hlo, "inner"), ir_builder_->getInt64(0), + ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), + ir_builder_); + + SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); + PrimitiveType primitive_type = hlo->shape().element_type(); + llvm::Type* primitive_type_llvm = + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); + llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( + primitive_type_llvm, "dot_acc", ir_builder_); + ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), + accumulator_alloca); + + SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); + + // This is the inner reduction loop for a dot operation that produces + // one element in the output. If the operands to the dot operation have + // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E]. + // Given an output index [a,b,c,d,e] in the result, we compute: + // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) + + IrArray::Index lhs_index, rhs_index; + + for (int64 i = 0; i < lhs_dims - 1; i++) { + lhs_index.push_back(dot_result_index[i]); + } + lhs_index.push_back(inner_loop->GetIndVarValue()); + + for (int64 i = 0; i < rhs_dims - 2; i++) { + rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); + } + rhs_index.push_back(inner_loop->GetIndVarValue()); + rhs_index.push_back(dot_result_index.back()); + + llvm::Value* current_accumulator = + ir_builder_->CreateLoad(accumulator_alloca); + TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); + TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); + llvm::Value* next_accumulator; + if (primitive_util::IsComplexType(primitive_type)) { + llvm::Value* product_real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); + llvm::Value* product_imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value))); + next_accumulator = ir_builder_->CreateInsertValue( + current_accumulator, + ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), + product_real), + {0}); + next_accumulator = ir_builder_->CreateInsertValue( + next_accumulator, + ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), + product_imag), + {1}); + } else if (primitive_util::IsFloatingPointType(primitive_type)) { + next_accumulator = ir_builder_->CreateFAdd( + current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value)); + } else { + next_accumulator = ir_builder_->CreateAdd( + current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value)); + } + ir_builder_->CreateStore(next_accumulator, accumulator_alloca); + + SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); + return ir_builder_->CreateLoad(accumulator_alloca); +} + llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) @@ -1411,43 +1930,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSelect: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); - TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); - TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); - return ir_builder_->CreateSelect( - ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), - on_true_value, on_false_value); + return EmitElementalSelect(hlo, operand_to_generator, index); }; case HloOpcode::kClamp: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * min_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); - TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); - TF_ASSIGN_OR_RETURN(llvm::Value * max_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); - PrimitiveType prim_type = hlo->shape().element_type(); - if (primitive_util::IsFloatingPointType(prim_type)) { - return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); - } else if (primitive_util::IsIntegralType(prim_type)) { - bool is_signed = primitive_util::IsSignedIntegralType(prim_type); - return EmitIntegralMin( - max_value, EmitIntegralMax(min_value, arg_value, is_signed), - is_signed); - } else { - return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); - } + return EmitElementalClamp(hlo, operand_to_generator, index); }; case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( @@ -1460,70 +1948,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kConcatenate: return [this, hlo, &operand_to_generator]( const IrArray::Index target_index) -> StatusOr { - const int64 concat_dim = hlo->dimensions(0); - auto source_index = target_index; - - llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); - - // A terminator should be present iff we're emitting code - // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), - init_block->getTerminator() == nullptr); - - llvm::BasicBlock* exit_block; - if (ir_builder_->GetInsertPoint() == init_block->end()) { - exit_block = llvm_ir::CreateBasicBlock( - /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); - } else { - exit_block = init_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); - init_block->getTerminator()->eraseFromParent(); - } - - llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); - llvm::PHINode* output = - ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), module_), - hlo->operands().size()); - auto prior_insert_point = ir_builder_->GetInsertPoint(); - - ir_builder_->SetInsertPoint(init_block); - - for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); - ++operand_idx) { - const HloInstruction* operand = hlo->operand(operand_idx); - auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), - ir_builder_); - auto false_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_not_from_operand", operand_idx), - ir_builder_); - auto concat_dim_size = - llvm::ConstantInt::get(source_index[concat_dim]->getType(), - operand->shape().dimensions(concat_dim)); - ir_builder_->CreateCondBr( - ir_builder_->CreateICmpULT(source_index[concat_dim], - concat_dim_size), - true_block, false_block); - - // Create the terminator of the true block before calling operand - // generators, because they require non-degenerate basic blocks. - ir_builder_->SetInsertPoint( - llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, ir_builder_->GetInsertBlock()); - - // Subtract the size of the concat dimension of the current operand - // from the source index. - ir_builder_->SetInsertPoint(false_block); - source_index[concat_dim] = - ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); - } - - ir_builder_->CreateUnreachable(); - ir_builder_->SetInsertPoint(exit_block, prior_insert_point); - return output; + return EmitElementalConcatenate(hlo, operand_to_generator, + target_index); }; case HloOpcode::kReverse: return [this, hlo, &operand_to_generator]( @@ -1559,270 +1985,19 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kDynamicSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - // Emit IR to read dynamic start indices from hlo->operand(1). - const HloInstruction* input_hlo = hlo->operand(0); - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = start_index_value; - } - - llvm_ir::IrArray::Index input_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // input_index = (start_index + offset_index) % dim_size - // Security note: this is the code that keeps the indices in-bounds. - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( - slice_start_index[i], index[i]->getType()); - input_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateAdd(start_index, index[i]), dim_size); - } - return operand_to_generator.at(input_hlo)(input_index); + return EmitElementalDynamicSlice(hlo, operand_to_generator, index); }; case HloOpcode::kGather: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - const Shape& operand_shape = hlo->operand(0)->shape(); - const Shape& indices_shape = hlo->operand(1)->shape(); - const Shape& output_shape = hlo->shape(); - - const GatherDimensionNumbers& dim_numbers = - hlo->gather_dimension_numbers(); - - const llvm_ir::ElementGenerator& operand_generator = - operand_to_generator.at(hlo->operand(0)); - const llvm_ir::ElementGenerator& indices_generator = - operand_to_generator.at(hlo->operand(1)); - - // This is the index into `operand` that holds the element we want to - // generate. This index "unsafe" as in the components in here may be - // out of bounds. - IrArray::Index unsafe_operand_index; - - // First copy in the window indices to unsafe_operand_index. - for (int64 i = 0, e = operand_shape.dimensions_size(), - unsafe_operand_index_dim = 0; - i < e; i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { - unsafe_operand_index.push_back(ir_builder_->getInt64(0)); - } else { - unsafe_operand_index.push_back(index[dim_numbers.output_window_dims( - unsafe_operand_index_dim++)]); - } - } - - // This is the index of the index vector in the gather_indices tensor. - IrArray::Index gather_index_index; - { - std::vector gather_index_index_components; - for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { - gather_index_index.push_back(index[i]); - } - } - - if (gather_index_index.size() != indices_shape.dimensions_size()) { - gather_index_index.InsertAt(dim_numbers.index_vector_dim(), - nullptr); - } - } - - auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, - int64 dim) { - llvm::Value* gather_dim_component_extended = - ir_builder_->CreateSExtOrTrunc(index_component, - ir_builder_->getInt64Ty()); - unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = - ir_builder_->CreateAdd( - unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims( - dim)], - gather_dim_component_extended); - }; - - if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { - TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, - indices_generator(gather_index_index)); - add_to_unsafe_operand_index(gather_dim_component, 0); - } else { - int64 index_vector_size = - indices_shape.dimensions(dim_numbers.index_vector_dim()); - for (int64 i = 0; i < index_vector_size; i++) { - gather_index_index[dim_numbers.index_vector_dim()] = - ir_builder_->getInt64(i); - TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, - indices_generator(gather_index_index)); - add_to_unsafe_operand_index(gather_dim_component, i); - } - } - - IrArray::Index safe_operand_index; - for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { - safe_operand_index.push_back(ir_builder_->CreateURem( - unsafe_operand_index[i], - ir_builder_->getInt64(operand_shape.dimensions(i)))); - } - - return operand_generator(safe_operand_index); + return EmitElementalGather(hlo, operand_to_generator, index); }; case HloOpcode::kDynamicUpdateSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - const HloInstruction* input_hlo = hlo->operand(0); - const HloInstruction* update_hlo = hlo->operand(1); - const HloInstruction* start_hlo = hlo->operand(2); - // Calculate slice start/end indices. - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - llvm_ir::IrArray::Index slice_limit_index(rank); - // Slice starts at update[index - slice_start_index_adjusted], - // where adjusted value = slice_start_index when in bounds, and - // adjusted value = slice_start_index - input_dim, when wrapping. - llvm_ir::IrArray::Index slice_start_index_adjusted(rank); - - // Slice intersection gathers (ANDs) conditions on all ranks for which - // 'input' is set to 'update' - llvm::Value* slice_intersection = ir_builder_->getTrue(); - - for (int64 i = 0; i < rank; ++i) { - // Emit IR to read dynamic start indices from 'start_hlo'. - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(start_hlo)(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( - start_index_value, index[i]->getType()); - - llvm::Value* input_dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - - // Generate code to handle wrapping semantics: - // slice_start_index[i] = slice_start_index[i] % input_dim_size; - // slice_limit_index[i] = slice_start_index[i] + update_dim_size. - // slice_start_index[i] is updated in place and it will now be in - // range. slice_limit_index[i] may be out of range, and it's being - // URem-ed below if so. - slice_start_index[i] = - ir_builder_->CreateURem(slice_start_index[i], input_dim_size); - slice_limit_index[i] = - ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - - // Test if slice_limit_index[i] is in bounds - llvm::Value* in_bounds = - ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); - llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - - // Handle true BB (slice_limit_index[i] <= input_dim_size). - SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] && - // index[i] < slice_limit_index[i] - llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection, - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection_in"); - slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection_in_bounds, - ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection_in"); - - // Handle false BB (slice_limit_index[i] > input_dim_size). - SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] || - // index[i] < slice_limit_index[i]%input_dim_size. - llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( - index[i], - ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); - llvm::Value* slice_intersection_or = ir_builder_->CreateOr( - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - index_wraps, "slice_intersection_out"); - llvm::Value* slice_intersection_out_of_bounds = - ir_builder_->CreateAnd(slice_intersection, slice_intersection_or, - "slice_intersection_out"); - // Create value for slice_start_index_adjusted[i] when out of bounds. - // If within out-of-bounds if. - llvm_ir::LlvmIfData if_start_needs_adjustment = - llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); - SetToFirstInsertPoint(if_start_needs_adjustment.true_block, - ir_builder_); - llvm::Value* slice_start_index_adjusted_oob = - ir_builder_->CreateSub(slice_start_index[i], input_dim_size); - SetToFirstInsertPoint(if_start_needs_adjustment.after_block, - ir_builder_); - llvm::PHINode* slice_start_index_adjusted_phi = - ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), - 2); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index_adjusted_oob, - if_start_needs_adjustment.true_block); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index[i], if_start_needs_adjustment.false_block); - // End of if within if. - - // After checking in/out of bounds. - SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); - llvm::PHINode* phi_slice_intersection = - ir_builder_->CreatePHI(slice_intersection->getType(), 2); - phi_slice_intersection->addIncoming(slice_intersection_in_bounds, - if_in_bounds.true_block); - phi_slice_intersection->addIncoming( - slice_intersection_out_of_bounds, - if_start_needs_adjustment.after_block); - slice_intersection = phi_slice_intersection; - - llvm::PHINode* phi_index = - ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); - phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); - phi_index->addIncoming(slice_start_index_adjusted_phi, - if_start_needs_adjustment.after_block); - slice_start_index_adjusted[i] = phi_index; - } - - // Emit: - // if (slice_intersection) -> return data from 'update'. - // else -> return data from 'input'. - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_), - "ret_value_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - slice_intersection, "slice_intersection", ir_builder_); - - // Handle true BB (return data from 'update') - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - // NOTE: Subtraction will be positive due to bounds checking above. - update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), - update_dim_size); - } - TF_ASSIGN_OR_RETURN(llvm::Value * true_value, - operand_to_generator.at(update_hlo)(update_index)); - ir_builder_->CreateStore(true_value, ret_value_addr); - - // Handle false BB (return data from 'input') - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * false_value, - operand_to_generator.at(input_hlo)(index)); - ir_builder_->CreateStore(false_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - return ir_builder_->CreateLoad(ret_value_addr); + return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator, + index); }; case HloOpcode::kBitcast: CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), @@ -1851,155 +2026,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kRng: return MakeRngElementGenerator(hlo, operand_to_generator); case HloOpcode::kPad: - return [=, &operand_to_generator]( + return [this, hlo, &operand_to_generator]( const IrArray::Index& padded_index) -> StatusOr { - auto index = padded_index; - llvm::Value* in_bounds = ir_builder_->getTrue(); - for (size_t i = 0; i < index.size(); ++i) { - auto index_typed_const = [=](int64 n) { - return llvm::ConstantInt::get(index[i]->getType(), n); - }; - const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = ir_builder_->CreateSub( - index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpEQ( - index_typed_const(0), - ir_builder_->CreateURem( - index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = ir_builder_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), - "in_bounds"); - } - - // if (in_bounds) { - // ret_value = operand0[index]; // source - // } else { - // ret_value = *operand1; // padding - // } - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_), - "pad_result_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))(index)); - ir_builder_->CreateStore(operand_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); - ir_builder_->CreateStore(padding_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - // Don't create phi(operand_value, padding_value) here, because invoking - // operand_to_generator may create new basic blocks, making the parent - // of operand_value or padding_value no longer a predecessor of - // if_data.after_block. - return ir_builder_->CreateLoad(ret_value_addr); + return EmitElementalPad(hlo, operand_to_generator, padded_index); }; case HloOpcode::kDot: - return [=, &operand_to_generator](const IrArray::Index& dot_result_index) + return [this, hlo, + &operand_to_generator](const IrArray::Index& dot_result_index) -> StatusOr { - auto lhs_generator = operand_to_generator.at(hlo->operand(0)); - auto rhs_generator = operand_to_generator.at(hlo->operand(1)); - int64 contracted_dim_size = hlo->operand(0)->shape().dimensions( - hlo->operand(0)->shape().dimensions_size() - 1); - int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); - int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); - - std::unique_ptr inner_loop = - llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), ir_builder_->getInt64(0), - ir_builder_->getInt64(contracted_dim_size), - ir_builder_->getInt64(1), ir_builder_); - - SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), - ir_builder_); - PrimitiveType primitive_type = hlo->shape().element_type(); - llvm::Type* primitive_type_llvm = - llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); - llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( - primitive_type_llvm, "dot_acc", ir_builder_); - ir_builder_->CreateStore( - llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); - - SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); - - // This is the inner reduction loop for a dot operation that produces - // one element in the output. If the operands to the dot operation have - // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E]. - // Given an output index [a,b,c,d,e] in the result, we compute: - // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - - IrArray::Index lhs_index, rhs_index; - - for (int64 i = 0; i < lhs_dims - 1; i++) { - lhs_index.push_back(dot_result_index[i]); - } - lhs_index.push_back(inner_loop->GetIndVarValue()); - - for (int64 i = 0; i < rhs_dims - 2; i++) { - rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); - } - rhs_index.push_back(inner_loop->GetIndVarValue()); - rhs_index.push_back(dot_result_index.back()); - - llvm::Value* current_accumulator = - ir_builder_->CreateLoad(accumulator_alloca); - TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); - TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); - llvm::Value* next_accumulator; - if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); - llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value))); - next_accumulator = ir_builder_->CreateInsertValue( - current_accumulator, - ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), - product_real), - {0}); - next_accumulator = ir_builder_->CreateInsertValue( - next_accumulator, - ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), - product_imag), - {1}); - } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = ir_builder_->CreateFAdd( - current_accumulator, - ir_builder_->CreateFMul(lhs_value, rhs_value)); - } else { - next_accumulator = ir_builder_->CreateAdd( - current_accumulator, - ir_builder_->CreateMul(lhs_value, rhs_value)); - } - ir_builder_->CreateStore(next_accumulator, accumulator_alloca); - - SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accumulator_alloca); + return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; default: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index c516a82..26dff0d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -142,6 +142,46 @@ class ElementalIrEmitter { return ir_builder_->getIntN(128, 0); } + StatusOr EmitElementalSelect( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalClamp( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalConcatenate( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& target_index) const; + + StatusOr EmitElementalDynamicSlice( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalGather( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalDynamicUpdateSlice( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalPad( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& padded_index) const; + + StatusOr EmitElementalDot( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& dot_result_index) const; + llvm::IRBuilder<>* const ir_builder_; llvm::Module* module_; -- 2.7.4