From 2b932c28d60ec4f248d950cd2ef69a7eb98bb66d Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Fri, 2 Feb 2018 18:57:10 -0800
Subject: [PATCH] [XLA] Minor cleanups related to multi-output fusion.

- Add some comments about preexisting invariants, and add some CHECKs.

- In the LoopEmitter constructor, materialize the given ArraySlice to a vector,
  so we don't rely on the given ArraySlice having any particular lifetime.

- Add the invariant that the LoopEmitter constructor which takes a list of
  IrArrays is only for multi-output fusion.  Previously it said: If you only
  pass one array, then treat it as regular fusion.  But this results in an LLVM
  type mismatch, because the given target_element_generator should be passing a
  struct with one element.

PiperOrigin-RevId: 184365310
---
 .../compiler/xla/service/gpu/hlo_to_ir_bindings.cc |  6 ++-
 .../xla/service/gpu/ir_emitter_unnested.cc         | 22 ++++-----
 .../xla/service/gpu/parallel_loop_emitter.h        |  5 ++
 tensorflow/compiler/xla/service/hlo_instruction.h  |  4 +-
 .../xla/service/llvm_ir/fused_ir_emitter.h         |  6 +++
 .../compiler/xla/service/llvm_ir/loop_emitter.cc   | 57 ++++++++++++----------
 .../compiler/xla/service/llvm_ir/loop_emitter.h    |  8 ++-
 7 files changed, 66 insertions(+), 42 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index c2115c4..dd4426c 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -191,7 +191,11 @@ static bool BuffersInvariantWithinConsumer(
 llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                              const HloInstruction& consumer,
                                              const ShapeIndex& shape_index) {
-  llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index),
+  llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
+  CHECK_NE(base_ptr, nullptr)
+      << "Buffer not assigned for shape_index " << shape_index.ToString()
+      << " of " << hlo.ToString();
+  llvm_ir::IrArray ir_array(base_ptr,
                             ShapeUtil::GetSubshape(hlo.shape(), shape_index));
 
   alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 4072573..a4847f6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1657,24 +1657,24 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
 }
 
 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
-  tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
-  bool all_tuple_elements_have_buffer = std::all_of(
-      operands.begin(), operands.end(), [this](HloInstruction* tuple_element) {
+  bool all_tuple_elements_have_buffer =
+      c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
         return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation(
             tuple_element);
       });
-  // Tuples (especially output tuples) can take too many tuple elements,
-  // causing the kernel emitted exceeds the parameter space limit
-  // (b/31336476). As an optimization, if all tuple elements have a buffer, we
-  // collect their buffer addresses in a host array, and then copy that array
-  // to the tuple's buffer.
+  // Tuples (especially tuples that are the final result of a computation) can
+  // be so huge that if we were to emit a kernel that took each tuple element as
+  // a parameter, we would exceed the max allowable number of parameters to a
+  // GPU kernel, b/31336476. As an optimization, if all tuple elements have a
+  // buffer, we collect their buffer addresses in a host array, and then copy
+  // that array to the tuple's buffer.
   //
   // Some tuple elements (e.g. const or bitcast of const) might not have a
-  // buffer -- their contents are stored in code.  In that case, we fall back
-  // to emitting kernels which have access to their buffer addresses in code.
+  // buffer -- their contents are stored in code.  In that case, we fall back to
+  // emitting kernels which have access to their buffer addresses in code.
   if (all_tuple_elements_have_buffer) {
     std::vector<BufferAllocation::Slice> tuple_element_buffers;
-    for (const HloInstruction* tuple_element : operands) {
+    for (const HloInstruction* tuple_element : tuple->operands()) {
       tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
     }
     thunk_sequence_->emplace_back(MakeUnique<TupleThunk>(
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index 934e7e1..8ed63a8 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -42,6 +42,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
                       const LaunchDimensions& launch_dimensions,
                       llvm::IRBuilder<>* ir_builder);
 
+  // Constructs a loop emitter for a loop that generates one element of each of
+  // N arrays on each iteration.
+  //
+  // This is used in multi-output fusion.  target_element_generator should
+  // produce a struct with N elements, one for each of target_arrays.
   ParallelLoopEmitter(
       const llvm_ir::ElementGenerator& target_element_generator,
       tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 84b4696..bce9ebd 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -874,8 +874,8 @@ class HloInstruction {
   // Returns true if this instruction is a fusion instruction that generates
   // multiple outputs.
   const bool IsMultiOutputFusion() const {
-    return (opcode() == HloOpcode::kFusion &&
-            fused_expression_root()->opcode() == HloOpcode::kTuple);
+    return opcode() == HloOpcode::kFusion &&
+           fused_expression_root()->opcode() == HloOpcode::kTuple;
   }
 
   FusionKind fusion_kind() const {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index 9ad7cd8..242062e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -34,6 +34,12 @@ namespace xla {
 
 // Unlike IrEmitter, this creates host functions which emit IR to generate the
 // output element at the given index. It is used to generate fused operations.
+//
+// This class handles both vanilla fusion and multi-output fusion.  In the MOF
+// case, the fusion node ends with a kTuple instruction, and the root generator
+// returned by this emitter returns an LLVM struct with N elements, one for each
+// element of the arrays in the tuple.  It follows that the arrays in the tuple
+// must have the same length.
 class FusedIrEmitter : public DfsHloVisitorWithDefault {
  public:
   using Generator = llvm_ir::ElementGenerator;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index a5f7c85..b6b918e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -51,37 +51,40 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
       shape_(target_array.GetShape()),
       ir_builder_(ir_builder) {}
 
+static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
+    const ElementGenerator& target_element_generator,
+    const std::vector<IrArray>& target_arrays, llvm::IRBuilder<>* ir_builder) {
+  return [=](const llvm_ir::IrArray::Index array_index) {
+    TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
+                        target_element_generator(array_index));
+    CHECK(target_element->getType()->isStructTy())
+        << "This BodyEmitter is for multi-output fusion, but target element "
+           "generator does not produce values of struct type.";
+    CHECK_EQ(target_element->getType()->getStructNumElements(),
+             target_arrays.size());
+
+    for (int64 i = 0; i < target_arrays.size(); ++i) {
+      target_arrays[i].EmitWriteArrayElement(
+          array_index, ir_builder->CreateExtractValue(target_element, i),
+          ir_builder);
+    }
+    return Status::OK();
+  };
+}
+
 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
                          tensorflow::gtl::ArraySlice<IrArray> target_arrays,
                          llvm::IRBuilder<>* ir_builder)
-    : body_emitter_([=](const llvm_ir::IrArray::Index array_index)
-                        -> ::tensorflow::Status {
-        // Convert target_element_generator to a BodyEmitter.
-        TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
-                            target_element_generator(array_index));
-        if (target_arrays.size() == 1) {
-          target_arrays[0].EmitWriteArrayElement(array_index, target_element,
-                                                 ir_builder);
-          return tensorflow::Status::OK();
-        }
-
-        for (int64 i = 0; i < target_arrays.size(); ++i) {
-          target_arrays[i].EmitWriteArrayElement(
-              array_index, ir_builder_->CreateExtractValue(target_element, i),
-              ir_builder);
-        }
-        return tensorflow::Status::OK();
-      }),
+    : body_emitter_(MakeBodyEmitterForMultiOutputFusion(
+          target_element_generator,
+          std::vector<IrArray>(target_arrays.begin(), target_arrays.end()),
+          ir_builder)),
+      shape_(target_arrays[0].GetShape()),
       ir_builder_(ir_builder) {
-  if (target_arrays.size() > 1) {
-    // The sanity check for multiple outputs.
-    shape_ = target_arrays[0].GetShape();
-    for (int64 i = 1; i < target_arrays.size(); ++i) {
-      const Shape& element_shape = target_arrays[i].GetShape();
-      CHECK(ShapeUtil::SameDimensions(shape_, element_shape));
-    }
-  } else {
-    shape_ = target_arrays[0].GetShape();
+  // Sanity check: In multi-output fusion, all shapes produced must have the
+  // same dimensions.
+  for (const IrArray& array : target_arrays) {
+    CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape()));
   }
 }
 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index 1ef1dc2..0fc5284 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -47,10 +47,16 @@ class LoopEmitter {
   // element of the given target array.
   LoopEmitter(const ElementGenerator& target_element_generator,
               const IrArray& target_array, llvm::IRBuilder<>* ir_builder);
-  // Same as previous method except emits multiple targets in an array.
+
+  // Constructs a LoopEmitter that emits one element into each of N separate
+  // arrays on each iteration of the loop.
+  //
+  // This is used for multi-output fusion.  target_element_generator must
+  // produce an LLVM struct with N elements.
   LoopEmitter(const ElementGenerator& target_element_generator,
               tensorflow::gtl::ArraySlice<IrArray> target_arrays,
               llvm::IRBuilder<>* ir_builder);
+
   LoopEmitter(const LoopEmitter&) = delete;
   LoopEmitter& operator=(const LoopEmitter&) = delete;
   virtual ~LoopEmitter() = default;
-- 
2.7.4
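
Note (not part of the patch): the commit message's point about materializing the
ArraySlice into a vector is easy to miss. The standalone C++ sketch below uses
invented names (ArraySliceLike, MakeReaderUnsafe, MakeReaderSafe) to illustrate
the hazard: a closure that captures only a non-owning view keeps pointing at the
caller's storage, while copying into an owning vector, as the patched LoopEmitter
constructor does, keeps the data alive for as long as the closure lives.

```cpp
// Self-contained sketch of the lifetime hazard described in the commit message.
// ArraySliceLike stands in for tensorflow::gtl::ArraySlice (a non-owning view).
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

struct ArraySliceLike {
  const int* data;
  std::size_t size;
};

// Unsafe: the returned closure holds only the view, which dangles once the
// caller's backing storage is destroyed.
std::function<int()> MakeReaderUnsafe(ArraySliceLike slice) {
  return [slice] { return slice.data[0]; };
}

// Safe: materialize the view into an owning vector first, mirroring
// std::vector<IrArray>(target_arrays.begin(), target_arrays.end()) in the patch.
std::function<int()> MakeReaderSafe(ArraySliceLike slice) {
  std::vector<int> copy(slice.data, slice.data + slice.size);
  return [copy] { return copy[0]; };
}

int main() {
  std::function<int()> reader;
  {
    std::vector<int> backing = {42, 7};
    reader = MakeReaderSafe({backing.data(), backing.size()});
    // MakeReaderUnsafe({backing.data(), backing.size()}) would leave the
    // closure with a dangling pointer after this scope ends.
  }  // backing destroyed here
  std::cout << reader() << "\n";  // Safe: prints 42.
  return 0;
}
```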
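Also not part of the patch: a simplified analog of the multi-output-fusion
contract that MakeBodyEmitterForMultiOutputFusion enforces. All names here
(MultiOutputGenerator, MakeMofBodyEmitter) are invented for illustration, and
plain std::vector<float> stands in for llvm::Value* structs and llvm_ir::IrArray.
The generator must produce exactly one value per output array, and the body
emitter unpacks that N-element result and writes element i into array i.

```cpp
// Standalone analog of the multi-output-fusion body emitter.
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Analog of an ElementGenerator that returns an LLVM struct with N elements:
// produces one value per output array for a given loop index.
using MultiOutputGenerator = std::function<std::vector<float>(std::size_t)>;

// Analog of a BodyEmitter: runs once per index and writes every output array.
using BodyEmitter = std::function<void(std::size_t)>;

BodyEmitter MakeMofBodyEmitter(MultiOutputGenerator generator,
                               std::vector<std::vector<float>*> outputs) {
  // Capture `outputs` by value, mirroring how the patch materializes the
  // ArraySlice into a std::vector before capturing it in the lambda.
  return [generator, outputs](std::size_t index) {
    std::vector<float> element = generator(index);
    // Invariant from the patch: exactly one value per output array.
    assert(element.size() == outputs.size());
    for (std::size_t i = 0; i < outputs.size(); ++i) {
      (*outputs[i])[index] = element[i];
    }
  };
}

int main() {
  std::vector<float> sums(4), products(4);
  BodyEmitter body = MakeMofBodyEmitter(
      [](std::size_t i) {
        float x = static_cast<float>(i);
        return std::vector<float>{x + 1.0f, x * 2.0f};  // two fused outputs
      },
      {&sums, &products});
  for (std::size_t i = 0; i < 4; ++i) body(i);  // the "loop" around the body
  return 0;
}
```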