From 2c83cddab0cd0de78f863e47d81b4427d6519eb7 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Tue, 15 May 2018 10:23:27 -0700
Subject: [PATCH] [XLA] Cache computations when creating reduces in algebraic
 simplifier or batchnorm expander

Otherwise we create a lot of identical small computations. This shouldn't
have an effect except for cluttering the HLO, but it turns out HloCSE
doesn't look inside the computation of reduces, effectively never
eliminating reduces that were produced via this code path.

While there, clean up some YAGNI: this only worked for F32 anyway, so just
hardcode it.

PiperOrigin-RevId: 196689316
---
 tensorflow/compiler/xla/service/BUILD             |  2 -
 .../compiler/xla/service/algebraic_simplifier.cc  | 46 ++++++++++----------
 .../compiler/xla/service/batchnorm_expander.cc    | 49 +++++++++++++---------
 3 files changed, 54 insertions(+), 43 deletions(-)

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 1049083..04a9a4a 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1271,13 +1271,11 @@ cc_library(
     deps = [
         ":hlo",
         ":hlo_pass",
-        ":hlo_query",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
     ],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 3ce80bb..f732ed8 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -92,26 +92,6 @@ bool ReshapeIsBitcast(
          valid_bitcast_callback(operand->shape(), reshape->shape());
 }
 
-// Adds a scalar computation to the module to enable optimizations with dot
-// converting into reduction.
-HloComputation* CreateScalarBinaryComputation(HloModule* module,
-                                              PrimitiveType primitive_type,
-                                              HloOpcode opcode) {
-  HloComputation::Builder b("scalar_computation");
-  auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
-      0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs"));
-  auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
-      1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs"));
-  auto scalar_op = b.AddInstruction(
-      HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
-                                   opcode, scalar_lhs, scalar_rhs));
-  HloComputation* scalar_computation =
-      module->AddEmbeddedComputation(b.Build(scalar_op));
-  return scalar_computation;
-}
-
-}  // namespace
-
 // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
 // algebraic expressions to simplified forms. Note: This only supports
 // simplifications that simply look at the operands of an instruction. For the
@@ -220,8 +200,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
   HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
     HloInstruction* zero = computation_->AddInstruction(
         HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
-    HloComputation* AddReduce_computation = CreateScalarBinaryComputation(
-        computation_->parent(), F32, HloOpcode::kAdd);
+    HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
     Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
     return computation_->AddInstruction(HloInstruction::CreateReduce(
         shape, hlo, zero, {dim}, AddReduce_computation));
@@ -293,6 +272,24 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
 
   StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
 
+  HloComputation* GetOrCreateScalarAddComputation() {
+    if (scalar_add_computation_) {
+      return scalar_add_computation_;
+    }
+
+    HloComputation::Builder b("scalar_add_computation");
+    Shape shape = ShapeUtil::MakeShape(F32, {});
+    auto scalar_lhs = b.AddInstruction(
+        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
+    auto scalar_rhs = b.AddInstruction(
+        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
+    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
+        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
+    scalar_add_computation_ =
+        computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
+    return scalar_add_computation_;
+  }
+
   // Current HloComputation instance the AlgebraicSimplifierVisitor is
   // traversing.
   HloComputation* computation_;
@@ -311,8 +308,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
   // Disable convolution simplification on platforms where it causes a
   // slowdown.
   bool enable_conv_simplification_;
+
+  // Cached computation for adding two scalar F32.
+  HloComputation* scalar_add_computation_ = nullptr;
 };
 
+}  // namespace
+
 bool AlgebraicSimplifierVisitor::Run(
     HloComputation* computation, bool is_layout_sensitive,
     AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 38086bd..96e02b8 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -15,35 +15,32 @@ limitations under the License.
 
#include "tensorflow/compiler/xla/service/batchnorm_expander.h" -#include #include -#include -#include #include #include #include -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { @@ -80,17 +77,25 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} - HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + HloComputation* GetOrCreateScalarAddComputation( + PrimitiveType primitive_type) { + HloComputation** scalar_add_computation = + &scalar_add_computations_[primitive_type]; + if (*scalar_add_computation) { + return *scalar_add_computation; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + *scalar_add_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return *scalar_add_computation; } // Current HloComputation instance the BatchNormExpander is @@ -105,6 +110,10 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Whether rewrite has occurred. bool changed_ = false; + // Cached computations for adding two scalars. + tensorflow::gtl::FlatMap + scalar_add_computations_; + // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. // Returns the Status representing the result of the replace operation. 
@@ -129,6 +138,8 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
   }
 };
 
+}  // namespace
+
 bool BatchNormExpanderVisitor::Run(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
@@ -199,7 +210,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
       HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
 
   HloComputation* add_reduce_computation =
-      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
+      GetOrCreateScalarAddComputation(ptype);
 
   // X^2.
   auto operand_squared = add(HloInstruction::CreateBinary(
@@ -500,7 +511,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
       grad_output, activation_minus_mean));
 
   HloComputation* add_reduce_computation =
-      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
+      GetOrCreateScalarAddComputation(ptype);
 
   // sum(Grad[Y] * (X - E[X])).
   auto sum_grad_output_times_activiation_minus_mean =
-- 
2.7.4
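
Note (not part of the patch): for readers skimming the diff, the sketch below
distills the get-or-create pattern both passes now share. It is a minimal
standalone approximation, assuming C++14; Computation, Module, and Pass are
hypothetical stand-ins for HloComputation, HloModule, and the visitor classes,
and std::unordered_map stands in for tensorflow::gtl::FlatMap.

  // Sketch only (hypothetical names): memoize one scalar add computation per
  // element type instead of building a fresh one at every call site.
  #include <memory>
  #include <string>
  #include <unordered_map>
  #include <vector>

  enum class PrimitiveType { F16, F32, F64 };

  struct Computation {
    std::string name;
  };

  // Stands in for HloModule, which owns the embedded computations.
  struct Module {
    Computation* AddEmbeddedComputation(std::unique_ptr<Computation> c) {
      computations.push_back(std::move(c));
      return computations.back().get();
    }
    std::vector<std::unique_ptr<Computation>> computations;
  };

  class Pass {
   public:
    explicit Pass(Module* module) : module_(module) {}

    // Build the computation at most once per element type; later calls hit
    // the cache, so every reduce emitted by this pass shares one computation.
    Computation* GetOrCreateScalarAddComputation(PrimitiveType type) {
      Computation** cached = &scalar_add_computations_[type];  // nullptr if new
      if (*cached != nullptr) {
        return *cached;
      }
      auto built = std::make_unique<Computation>();
      built->name = "scalar_add_computation";
      *cached = module_->AddEmbeddedComputation(std::move(built));
      return *cached;
    }

   private:
    Module* module_;
    std::unordered_map<PrimitiveType, Computation*> scalar_add_computations_;
  };

The algebraic simplifier variant degenerates to a single cached pointer
because, per the commit message, it only ever needed F32; the batchnorm
expander keeps the per-type map since it reduces in whatever element type the
batch norm op uses. Sharing one embedded computation also sidesteps the HloCSE
limitation described above: identical reduces now point at the same
computation instead of at structurally identical copies.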