From e8e1f1083d2c6ba4c2fb4e7e804d36624775e971 Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Wed, 24 Jan 2018 11:42:22 -0800
Subject: [PATCH] [BatchNorm] Remove CPU implementation

We now use the batchnorm rewriter
(tensorflow/compiler/xla/service/batchnorm_rewriter.h) to expand batch norm
into smaller ops, so a CPU-specific implementation is no longer needed.

RELNOTES: n/a
PiperOrigin-RevId: 183117252
---
 .../compiler/xla/service/cpu/ir_emitter.cc | 199 ------------------
 .../compiler/xla/service/cpu/ir_emitter.h  |   2 -
 2 files changed, 201 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index ac8d5d33fa..b03a9f9aa5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1226,205 +1226,6 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex(
   return index_with_free_var;
 }
 
-Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
-  // The output of BatchNormTraining is a tuple of three element:
-  //  - An N-dimensional array containing normalized values.
-  //  - A 1 dimensional array containing the mean value for each feature.
-  //  - A 1 dimensional array containing the variance value for each feature.
-  HloInstruction* operand = batch_norm_training->operands()[0];
-  HloInstruction* scale = batch_norm_training->operands()[1];
-  HloInstruction* offset = batch_norm_training->operands()[2];
-  float epsilon = batch_norm_training->epsilon();
-  int64 feature_index = batch_norm_training->feature_index();
-  TF_RET_CHECK(ShapeUtil::IsTuple(batch_norm_training->shape()) &&
-               ShapeUtil::TupleElementCount(batch_norm_training->shape()) == 3);
-
-  const Shape& output_shape =
-      ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 0);
-  const Shape& feature_shape =
-      ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 1);
-
-  // Reduce vector of the non-feature dimensions.
-  std::vector<int64> dimensions_to_reduce;
-
-  for (int64 i = 0; i < operand->shape().dimensions_size(); ++i) {
-    if (i != feature_index) {
-      dimensions_to_reduce.push_back(i);
-    }
-  }
-
-  // Get the second and third allocations in the output tuple, which should be
-  // used to store the result of mean and variance value calculation.
-  TF_ASSIGN_OR_RETURN(
-      const BufferAllocation::Slice slice_mean,
-      assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{1}));
-  TF_ASSIGN_OR_RETURN(
-      const BufferAllocation::Slice slice_var,
-      assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{2}));
-  const int feature_count = output_shape.dimensions(feature_index);
-  const int size_in_elements = ShapeUtil::ElementsIn(output_shape);
-  TF_RET_CHECK(ShapeUtil::ElementsIn(operand->shape()) == size_in_elements);
-  const int elements_per_feature = size_in_elements / feature_count;
-
-  llvm::Value* mean = EmitTempBufferPointer(slice_mean, feature_shape);
-  llvm_ir::IrArray mean_array(mean, feature_shape);
-
-  llvm::Value* var = EmitTempBufferPointer(slice_var, feature_shape);
-  llvm_ir::IrArray var_array(var, feature_shape);
-
-  // This loop calculates mean and variance for each feature.
-  //
-  // In theory this could be swapped by multi-output fusion. We will evaluate
-  // this when it's ready.
-  //
-  // For variance calculation, we use a simplified formula so we can fuse the
-  // computation into the same loop to calculate mean: Var=E(X^2) - E(X)^2.
-  TF_RETURN_IF_ERROR(
-      llvm_ir::LoopEmitter(
-          [&](const llvm_ir::IrArray::Index& index) {
-            PrimitiveType element_type = operand->shape().element_type();
-            // Used to calculate E(X).
-            llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
-                llvm_ir::PrimitiveTypeToIrType(element_type, module_),
-                "sum_address", &ir_builder_,
-                MinimumAlignmentForPrimitiveType(element_type));
-
-            // Used to calculate E(X^2).
-            llvm::Value* sum_square_address =
-                llvm_ir::EmitAllocaAtFunctionEntry(
-                    llvm_ir::PrimitiveTypeToIrType(element_type, module_),
-                    "sum_square_address", &ir_builder_,
-                    MinimumAlignmentForPrimitiveType(element_type));
-
-            ir_builder_.CreateStore(
-                llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0),
-                sum_address);
-
-            ir_builder_.CreateStore(
-                llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0),
-                sum_square_address);
-
-            llvm_ir::ForLoopNest loops(IrName(batch_norm_training, "inner"),
-                                       &ir_builder_);
-
-            const llvm_ir::IrArray::Index reduced_dims_index =
-                loops.AddLoopsForShapeOnDimensions(
-                    operand->shape(), dimensions_to_reduce, "reduction_dim");
-
-            SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(),
-                                  &ir_builder_);
-
-            llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
-            llvm_ir::IrArray::Index input_index =
-                FillReducedDimensionIndex(reduced_dims_index, index);
-            llvm::Value* new_value =
-                operand_array.EmitReadArrayElement(input_index, &ir_builder_);
-
-            llvm::Value* new_value_square =
-                ir_builder_.CreateFMul(new_value, new_value);
-
-            llvm::Value* current_sum = ir_builder_.CreateLoad(sum_address);
-            llvm::Value* current_sum_square =
-                ir_builder_.CreateLoad(sum_square_address);
-            // Update sum.
-            ir_builder_.CreateStore(
-                ir_builder_.CreateFAdd(current_sum, new_value), sum_address);
-
-            // Update sum square.
-            ir_builder_.CreateStore(
-                ir_builder_.CreateFAdd(current_sum_square, new_value_square),
-                sum_square_address);
-
-            SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(),
-                                  &ir_builder_);
-
-            llvm::Value* sum = ir_builder_.CreateLoad(sum_address);
-            llvm::Value* elements_per_feature_value = llvm::ConstantFP::get(
-                ir_builder_.getFloatTy(), elements_per_feature);
-            llvm::Value* mean =
-                ir_builder_.CreateFDiv(sum, elements_per_feature_value);
-            llvm::Value* mean_square = ir_builder_.CreateFMul(mean, mean);
-            llvm::Value* sum_square =
-                ir_builder_.CreateLoad(sum_square_address);
-
-            // Var=E(X^2) - E(X)^2.
-            llvm::Value* var = ir_builder_.CreateFSub(
-                ir_builder_.CreateFDiv(sum_square, elements_per_feature_value),
-                mean_square);
-
-            var_array.EmitWriteArrayElement(index, var, &ir_builder_);
-            return mean;
-          },
-          mean_array, &ir_builder_)
-          .EmitLoop(IrName(batch_norm_training, "mean_var")));
-
-  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training));
-  TF_ASSIGN_OR_RETURN(
-      const BufferAllocation::Slice slice,
-      assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0}));
-
-  llvm::Value* normalized = EmitTempBufferPointer(slice, output_shape);
-
-  llvm_ir::IrArray target_array(normalized, output_shape);
-
-  AddAliasingInformationToIrArray(*batch_norm_training, &target_array);
-
-  TF_RETURN_IF_ERROR(
-      llvm_ir::LoopEmitter(
-          [this, mean_array, var_array, epsilon, operand, dimensions_to_reduce,
-           feature_index, offset, scale](const llvm_ir::IrArray::Index& index) {
-            // The following logic normalizes the input value, scales and shifts
-            // it:
-            //
-            // normalized = (input - mean) / sqrt(variance + epsilon)
-            // result = normalized * scale + offset
-
-            // Current index in the feature dimension.
-            llvm_ir::IrArray::Index feature_index_value(1,
-                                                        index[feature_index]);
-
-            llvm::Value* mean = mean_array.EmitReadArrayElement(
-                feature_index_value, &ir_builder_);
-            llvm::Value* var = var_array.EmitReadArrayElement(
-                feature_index_value, &ir_builder_);
-
-            llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
-            llvm::Value* input =
-                operand_array.EmitReadArrayElement(index, &ir_builder_);
-
-            llvm::Value* variance_with_epsilon = ir_builder_.CreateFAdd(
-                var, llvm::ConstantFP::get(ir_builder_.getFloatTy(), epsilon));
-            llvm::Function* func_llvm_sqrt = llvm::Intrinsic::getDeclaration(
-                module_, llvm::Intrinsic::sqrt, {ir_builder_.getFloatTy()});
-            llvm::Value* variance_sqrt =
-                ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon});
-            llvm::Value* normalized = ir_builder_.CreateFDiv(
-                ir_builder_.CreateFSub(input, mean), variance_sqrt);
-            llvm_ir::IrArray offset_array(GetIrArrayFor(offset));
-            llvm::Value* offset = offset_array.EmitReadArrayElement(
-                feature_index_value, &ir_builder_);
-            llvm_ir::IrArray scale_array(GetIrArrayFor(scale));
-            llvm::Value* scale = scale_array.EmitReadArrayElement(
-                feature_index_value, &ir_builder_);
-            llvm::Value* result = ir_builder_.CreateFAdd(
-                ir_builder_.CreateFMul(normalized, scale), offset);
-
-            return result;
-          },
-          target_array, &ir_builder_)
-          .EmitLoop(IrName(batch_norm_training, "normalize")));
-
-  llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training),
-                     {normalized, mean, var}, &ir_builder_, module_);
-  return Status::OK();
-}
-
-Status IrEmitter::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
-  // TODO(b/62843645) Implement BatchNormGrad on CPU backend.
-  return Unimplemented(
-      "BatchNormGrad is not implemented on CPU. See b/62843645.");
-}
-
 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
   VLOG(2) << "HandleParameter: " << parameter->ToString();
   auto param_number = parameter->parameter_number();
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 66f2aeeab3..5094402514 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -125,8 +125,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
   Status HandleDot(HloInstruction* dot) override;
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleFft(HloInstruction* fft) override;
-  Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
-  Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
   Status HandleCrossReplicaSum(HloInstruction* crs) override;
   Status HandleInfeed(HloInstruction* infeed) override;
   Status HandleOutfeed(HloInstruction* outfeed) override;
--
2.34.1
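
For reference, the deleted HandleBatchNormTraining emitted two loops. The first
computed the per-feature mean and variance in a single pass over the
non-feature dimensions, using the simplified formula Var = E(X^2) - E(X)^2 so
that both statistics could be accumulated in the same reduction loop. The C++
below is a minimal standalone sketch of that computation, not code from the
patch or from the batchnorm rewriter; the flat layout with the feature index
innermost and the names FeatureStats and ComputeBatchNormStats are
illustrative assumptions.

#include <cstdint>
#include <vector>

struct FeatureStats {
  std::vector<float> mean;  // E(X) per feature.
  std::vector<float> var;   // E(X^2) - E(X)^2 per feature.
};

// `input` holds elements_per_feature * feature_count values, laid out with
// the feature index as the innermost (fastest-varying) dimension.
FeatureStats ComputeBatchNormStats(const std::vector<float>& input,
                                   int64_t feature_count) {
  const int64_t elements_per_feature =
      static_cast<int64_t>(input.size()) / feature_count;
  std::vector<float> sum(feature_count, 0.0f);         // Accumulates X.
  std::vector<float> sum_square(feature_count, 0.0f);  // Accumulates X^2.
  for (int64_t i = 0; i < elements_per_feature; ++i) {
    for (int64_t f = 0; f < feature_count; ++f) {
      const float x = input[i * feature_count + f];
      sum[f] += x;
      sum_square[f] += x * x;
    }
  }
  FeatureStats stats;
  stats.mean.resize(feature_count);
  stats.var.resize(feature_count);
  for (int64_t f = 0; f < feature_count; ++f) {
    const float mean = sum[f] / elements_per_feature;
    stats.mean[f] = mean;
    // Var = E(X^2) - E(X)^2, matching the comment in the deleted emitter.
    stats.var[f] = sum_square[f] / elements_per_feature - mean * mean;
  }
  return stats;
}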
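
The second loop then normalized each element and applied the scale and offset:
normalized = (input - mean) / sqrt(variance + epsilon), followed by
result = normalized * scale + offset. Again a minimal standalone sketch under
the same layout assumption, with NormalizeScaleShift as a hypothetical name;
it is not the rewriter's actual output.

#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> NormalizeScaleShift(const std::vector<float>& input,
                                       const std::vector<float>& mean,
                                       const std::vector<float>& var,
                                       const std::vector<float>& scale,
                                       const std::vector<float>& offset,
                                       float epsilon) {
  const int64_t feature_count = static_cast<int64_t>(mean.size());
  const int64_t elements_per_feature =
      static_cast<int64_t>(input.size()) / feature_count;
  std::vector<float> result(input.size());
  for (int64_t i = 0; i < elements_per_feature; ++i) {
    for (int64_t f = 0; f < feature_count; ++f) {
      const float x = input[i * feature_count + f];
      // normalized = (input - mean) / sqrt(variance + epsilon)
      const float normalized = (x - mean[f]) / std::sqrt(var[f] + epsilon);
      // result = normalized * scale + offset
      result[i * feature_count + f] = normalized * scale[f] + offset[f];
    }
  }
  return result;
}

Feeding the output of ComputeBatchNormStats into NormalizeScaleShift
reproduces the (normalized, mean, variance) tuple that BatchNormTraining
returns; the batchnorm rewriter obtains the same values by expanding the op
into reductions and elementwise HLO instructions that the generic CPU emitter
already handles, which is why the hand-written emitter above became redundant.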