From: Justin Lebar
Date: Tue, 6 Feb 2018 01:11:39 +0000 (-0800)
Subject: [XLA:GPU] Split IrEmitter{Unn,N}ested out of ir_emitter.h.
X-Git-Tag: upstream/v1.7.0~31^2~985
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=fa99ec49976c92cd8ea15d2a487f52554de3659d;p=platform%2Fupstream%2Ftensorflow.git

[XLA:GPU] Split IrEmitter{Unn,N}ested out of ir_emitter.h.

Also add a bunch of clarifying comments.

PiperOrigin-RevId: 184610674
---

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 7df01f7..9da4fb9 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -129,6 +129,8 @@ cc_library(
     hdrs = [
         "ir_emitter.h",
         "ir_emitter_context.h",
+        "ir_emitter_nested.h",
+        "ir_emitter_unnested.h",
     ],
     deps = [
         ":cudnn_convolution_runner",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 12ec266..28ebd03 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -47,8 +47,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
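For orientation, this is roughly how gpu_compiler.cc drives the emitter it now
gets from ir_emitter_unnested.h. A minimal sketch assembled from the
declarations in this diff: the wrapper function and variable names are
invented, and the real compiler interleaves this with buffer assignment and
thunk scheduling.

// Sketch only: drives IrEmitterUnnested over an entry computation and takes
// ownership of the resulting thunk sequence. `EmitEntryComputation` is a
// hypothetical wrapper, not a function in this commit.
Status EmitEntryComputation(const HloModule& module,
                            IrEmitterContext* ir_emitter_context,
                            std::unique_ptr<ThunkSequence>* thunk_sequence) {
  const HloComputation* entry = module.entry_computation();
  IrEmitterUnnested ir_emitter(module.config(), entry, ir_emitter_context);
  // Visiting the graph emits one or more kernels per HloInstruction and
  // records a thunk that launches each kernel.
  TF_RETURN_IF_ERROR(entry->root_instruction()->Accept(&ir_emitter));
  *thunk_sequence = ir_emitter.ConsumeThunkSequence();
  return Status::OK();
}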
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index affd2ff..a3df67a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -27,6 +27,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
 #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 9031a83..b0accc0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -13,19 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// An XLA HLO graph may contain multiple computations. These computations
-// fall into two types, nested and unnested. We translate each nested
-// computation (e.g. the computation operand of a Map operator) to a device
-// function. For each unnested computation composed of top-level
-// HloInstructions, we generate a CUDA kernel for each HloInstruction.
-//
-// This file declares classes that translate an XLA HLO graph to LLVM IR for
-// GPUs. IrEmitterNested emits LLVM IR for nested computations, and
-// IrEmitterUnnested for unnested computations. The logic of emitting LLVM IR
-// for each individual HloInstruction is largely the same between these two
-// classes. Therefore, we implement the common logic in the Handle* functions
-// in the superclass IrEmitter.
-
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
 
@@ -60,19 +47,28 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-// This class is the top-level API for the XLA HLO --> LLVM IR compiler.
-// It implements the DfsHloVisitor interface and emits an LLVM IR program that
-// implements the input HLO graph.
+// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
+//
+// There are two concrete subclasses of IrEmitter: IrEmitterNested and
+// IrEmitterUnnested. In the unnested variety, each HLO gets its own kernel
+// function, whereas in the nested version the whole computation is emitted as
+// one *non-kernel* function.
+//
+// In XLA, kernel functions never call other kernel functions. This means that
+// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
+// an HLO computation as a "subroutine" -- e.g. the HLO computation that
+// specifies how to reduce two elements -- then the subroutine computation
+// must be emitted using IrEmitterNested.
 //
-// Note: if `T` is a subclass of `IrEmitter` and a handler is not overridden
-// in either `IrEmitter` or `T`, the handler in `DfsHloVisitorWithDefault`
-// calls `T::DefaultAction`.
+// Fusion nodes are a special case. A fusion node is emitted using
+// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
+// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
+// generator generator. See comments on that class.
 class IrEmitter : public DfsHloVisitorWithDefault {
  public:
   IrEmitter(const IrEmitter&) = delete;
   IrEmitter& operator=(const IrEmitter&) = delete;
 
-  // The following methods implement the DfsHloVisitorWithDefault interface.
   Status DefaultAction(HloInstruction* hlo) override;
   Status HandleConstant(HloInstruction* constant) override;
   Status HandleBitcast(HloInstruction* bitcast) override;
@@ -217,199 +213,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
   std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
 };
 
-// Emits LLVM IR for unnested computations. Each HloInstruction is translated
-// to a separate CUDA kernel. These kernels are inserted into the resultant
-// module sorted in reverse postorder of the XLA HLO graph.
-class IrEmitterUnnested : public IrEmitter {
- public:
-  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
-                    const HloComputation* hlo_computation,
-                    IrEmitterContext* ir_emitter_context);
-  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
-  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
-
-  // Transfers ownership of thunk_sequence_ out.
-  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
-    return std::move(thunk_sequence_);
-  }
-
-  Status DefaultAction(HloInstruction* hlo) override;
-
-  // IrEmitterUnnested handles the following instructions differently from
-  // IrEmitter.
-  Status HandleCopy(HloInstruction* copy) override;
-  Status HandleConditional(HloInstruction* conditional) override;
-  Status HandleConvolution(HloInstruction* convolution) override;
-  Status HandleCustomCall(HloInstruction* custom_call) override;
-  Status HandleDot(HloInstruction* dot) override;
-  Status HandleFft(HloInstruction* fft) override;
-  Status HandleFusion(HloInstruction* fusion) override;
-  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
-  Status HandleReduce(HloInstruction* reduce) override;
-  Status HandleSelectAndScatter(HloInstruction* instruction) override;
-  Status HandleTuple(HloInstruction* tuple) override;
-  Status HandleWhile(HloInstruction* xla_while) override;
-  Status HandleInfeed(HloInstruction* xla_infeed) override;
-  Status HandleRng(HloInstruction* random) override;
-  Status HandleSelect(HloInstruction* select) override;
-
-  Status EmitTargetElementLoop(
-      const HloInstruction& hlo,
-      const llvm_ir::ElementGenerator& body_emitter) override;
-
-  // Same as `EmitTargetElementLoop`, but in the given `thunk` rather than
-  // `LastThunk()`.
-  Status EmitTargetElementLoopInThunk(
-      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
-      KernelThunk* thunk);
-
- private:
-  // Builds the appropriate thunk for the instruction `hlo` and returns the
-  // owning pointer to it. The caller needs to make sure `hlo` outlives the
-  // lifetime of the returned Thunk object.
-  std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
-
-  // Builds the prototype of the IR kernel for `inst` and adds it to the
-  // module.
-  llvm::Function* BuildKernelPrototype(
-      const HloInstruction& inst,
-      tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos);
-
-  // Emits the base pointers for `hlo` and its operands. `io_hlos` will store
-  // all input/output HLOs among `hlo` and its operands.
-  llvm::Function* EmitBasePointersForHloAndItsOperands(
-      const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos);
-
-  // EmitColumnReduction and EmitRowReduction emit code for column and row
-  // reduction of a matrix and/or 3D tensor. Row and column reductions have
-  // different memory access patterns, so for performance their
-  // implementations are significantly different.
-  //
-  // Emits code that reduces a matrix of shape [height x width] to a vector of
-  // [width]. Other parameters have the same meaning as those of
-  // `EmitReductionToVector`. Note that input shape might not be
-  // [height x width], but can be bitcast to [height x width] with "height"
-  // being the major dimension.
-  Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
-                             const Shape& input_shape,
-                             const llvm_ir::ElementGenerator& input_gen,
-                             const llvm_ir::ElementGenerator& init_value_gen,
-                             HloComputation* reducer);
-
-  // Emits code that reduces a 3D tensor of shape [depth x height x width] to
-  // a vector of shape [height]. Other parameters have the same meaning as
-  // those of `EmitReductionToVector`. Note that input shape might not be
-  // [depth x height x width], but can be bitcast to [depth x height x width]
-  // with "depth" being the most major dimension.
-  Status EmitRowReduction(int64 depth, int64 height, int64 width,
-                          HloInstruction* reduce, const Shape& input_shape,
-                          const llvm_ir::ElementGenerator& input_gen,
-                          const llvm_ir::ElementGenerator& init_value_gen,
-                          HloComputation* reducer);
-
-  // Emits code that reduces a tensor of arbitrary rank to a scalar.
-  Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
-                               const llvm_ir::ElementGenerator& input_gen,
-                               const llvm_ir::ElementGenerator& init_value_gen,
-                               HloComputation* reducer);
-
-  // Figures out whether `reduce` is a row or column reduction, and which
-  // dimensions to reduce, and calls either `EmitRowReduction` or
-  // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
-  // input array, which is the operand of the Reduce instruction if unfused or
-  // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
-  // generate elements of the input and the initial value. Other parameters
-  // mean the same as for `HandleReduce`.
-  //
-  // Prerequisite: `IsReductionToVector(*reduce)`
-  Status EmitReductionToVector(
-      HloInstruction* reduce, const Shape& input_shape,
-      const llvm_ir::ElementGenerator& input_gen,
-      const llvm_ir::ElementGenerator& init_value_gen,
-      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
-      HloComputation* reducer);
-
-  // Emits code to initialize the buffer of `inst` in the given `thunk`.
-  Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
-
-  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
-  // caller needs to make sure `inst` outlives the lifetime of the returned
-  // Thunk object.
-  std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst);
-
-  // Returns a FftThunk that calls cuFFT to implement `inst`.
-  std::unique_ptr<FftThunk> BuildFftThunk(const HloInstruction* inst);
-
-  // Returns a GemmThunk that calls gemm to implement `inst`. The caller
-  // needs to make sure `inst` outlives the lifetime of the returned Thunk
-  // object.
-  std::unique_ptr<GemmThunk> BuildGemmThunk(const HloInstruction* inst);
-
-  // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
-  std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
-
-  // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
-  std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
-      const HloInstruction* inst);
-
-  // Returns an InfeedThunk that performs a device-to-device memcpy to
-  // implement `inst`.
-  std::unique_ptr<InfeedThunk> BuildInfeedThunk(const HloInstruction* inst);
-
-  // Returns a WhileThunk that invokes thunk sequences for the 'condition' and
-  // 'body' sub-computations of the while instruction 'hlo'.
-  std::unique_ptr<WhileThunk> BuildWhileThunk(const HloInstruction* hlo);
-
-  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
-  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
-  std::unique_ptr<ForThunk> BuildForThunk(const HloInstruction* hlo,
-                                          const int64 loop_limit);
-
-  // Returns a ConditionalThunk that executes the thunk sequence for
-  // 'true_computation' or 'false_computation' depending on the value of the
-  // predicate in the given conditional instruction.
-  std::unique_ptr<ConditionalThunk> BuildConditionalThunk(
-      const HloInstruction* hlo);
-
-  Status Postprocess(HloInstruction* hlo) override;
-
-  // Returns the last generated thunk.
-  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
-
-  // The thunk sequence this IrEmitter generates for the input computation.
-  std::unique_ptr<ThunkSequence> thunk_sequence_;
-
-  // The HloComputation that this IrEmitter emits code for.
-  const HloComputation* hlo_computation_;
-};
-
-// Emits LLVM IR for a nested computation to the resultant function.
-class IrEmitterNested : public IrEmitter {
- public:
-  // Constructs an LLVM IR emitter for a nested HLO computation.
-  // `function` is the containing IR function that this emitter emits IR
-  // into. See IrEmitter::IrEmitter for the meanings of the other arguments.
-  IrEmitterNested(const HloModuleConfig& hlo_module_config,
-                  const HloComputation& nested_computation,
-                  IrEmitterContext* ir_emitter_context);
-  IrEmitterNested(const IrEmitterNested&) = delete;
-  IrEmitterNested& operator=(const IrEmitterNested&) = delete;
-
-  // Overrides the default empty implementation. Binds the given instruction
-  // "parameter" with the parameter of the IR function.
-  Status HandleParameter(HloInstruction* parameter) override;
-
-  llvm::Function* GetEmittedFunction() const { return emitted_function_; }
-
-  Status EmitTargetElementLoop(
-      const HloInstruction& hlo,
-      const llvm_ir::ElementGenerator& body_emitter) override;
-
- private:
-  llvm::Function* EmitBasePointersForNestedComputation(
-      const HloComputation& nested_computation,
-      std::vector<const HloInstruction*>* io_hlos);
-
-  llvm::Function* emitted_function_;
-};
-
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
index 5225ff3..71aada0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -16,12 +16,13 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
+
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h
new file mode 100644
index 0000000..ca11cf2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h
@@ -0,0 +1,72 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
+
+#include "llvm/IR/Function.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits LLVM IR for a "nested computation" into a non-kernel device function.
+//
+// This is used to emit code for HloComputations that don't require a separate
+// kernel call. For example, IrEmitterNested is used to emit code for a
+// kReduce HLO's elementwise reduction computation. Notably, IrEmitterNested
+// is *not* used to emit code for fusion nodes -- fusion nodes use
+// FusedIrEmitter, which is a different beast altogether.
+//
+// IrEmitterNested generates a non-kernel function with the following
+// parameters:
+//
+//   - N pointers to the buffers of each of the N parameters to the
+//     computation,
+//   - a pointer to the output buffer of the computation, and
+//   - a pointer to the top-level temp buffer.
+//
+class IrEmitterNested : public IrEmitter {
+ public:
+  // Constructs an LLVM IR emitter for a nested HLO computation. `function`
+  // is the containing IR function that this emitter emits IR into. See
+  // IrEmitter::IrEmitter for the meanings of the other arguments.
+  IrEmitterNested(const HloModuleConfig& hlo_module_config,
+                  const HloComputation& nested_computation,
+                  IrEmitterContext* ir_emitter_context);
+  IrEmitterNested(const IrEmitterNested&) = delete;
+  IrEmitterNested& operator=(const IrEmitterNested&) = delete;
+
+  // Overrides the default empty implementation. Binds the given instruction
+  // "parameter" with the parameter of the IR function.
+  Status HandleParameter(HloInstruction* parameter) override;
+
+  llvm::Function* GetEmittedFunction() const { return emitted_function_; }
+
+  Status EmitTargetElementLoop(
+      const HloInstruction& hlo,
+      const llvm_ir::ElementGenerator& body_emitter) override;
+
+ private:
+  llvm::Function* EmitBasePointersForNestedComputation(
+      const HloComputation& nested_computation,
+      std::vector<const HloInstruction*>* io_hlos);
+
+  llvm::Function* emitted_function_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
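As a concrete illustration of the header above: roughly how a kernel emitter
(e.g. the reduce handler) obtains a callable device function for a nested
computation. This is a sketch assembled from the declarations in this diff,
not code from the commit; `reduce`, `hlo_module_config`, and
`ir_emitter_context` are assumed to be in scope.

// Emit the reducer computation as a callable, non-kernel device function.
HloComputation* reducer = reduce->to_apply();
IrEmitterNested nested_emitter(hlo_module_config, *reducer,
                               ir_emitter_context);
TF_RETURN_IF_ERROR(reducer->root_instruction()->Accept(&nested_emitter));

// Per the header comment, the emitted function takes one pointer per reducer
// parameter, a pointer to its output buffer, and the top-level temp buffer.
// A kernel can simply call it -- kernels never call other kernels.
llvm::Function* reducer_fn = nested_emitter.GetEmittedFunction();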
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index a4847f6..08fea34 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -17,6 +17,8 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
+
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
@@ -40,7 +42,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
new file mode 100644
index 0000000..56ab820
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -0,0 +1,209 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits LLVM IR for an "unnested computation".
+//
+// An unnested computation is an HloComputation which you run by executing one
+// or more kernels for each HloInstruction it contains. Examples of unnested
+// computations:
+//
+//  - An HloModule's root computation,
+//  - The body of an HLO while loop,
+//  - The true/false computation of an HLO conditional.
+//
+// Note the opportunity for confusion -- the while loop's computation is
+// nested within the root computation, but it's emitted using
+// IrEmitterUnnested! Don't think about it too hard.
+//
+// Examples of things that are not unnested computations:
+//
+//  - The reducer of a kReduce HLO. This is emitted using IrEmitterNested.
+//  - The body of a fusion node. IrEmitterUnnested emits the relevant code
+//    within a kernel function using FusedIrEmitter. (FusedIrEmitter is not
+//    really an IrEmitter, but is more an "IR generator generator".)
+//
+class IrEmitterUnnested : public IrEmitter {
+ public:
+  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
+                    const HloComputation* hlo_computation,
+                    IrEmitterContext* ir_emitter_context);
+  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
+  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
+
+  // Transfers ownership of thunk_sequence_ out.
+  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
+    return std::move(thunk_sequence_);
+  }
+
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  // IrEmitterUnnested handles the following instructions differently from
+  // IrEmitter.
+  Status HandleCopy(HloInstruction* copy) override;
+  Status HandleConditional(HloInstruction* conditional) override;
+  Status HandleConvolution(HloInstruction* convolution) override;
+  Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status HandleDot(HloInstruction* dot) override;
+  Status HandleFft(HloInstruction* fft) override;
+  Status HandleFusion(HloInstruction* fusion) override;
+  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
+  Status HandleReduce(HloInstruction* reduce) override;
+  Status HandleSelectAndScatter(HloInstruction* instruction) override;
+  Status HandleTuple(HloInstruction* tuple) override;
+  Status HandleWhile(HloInstruction* xla_while) override;
+  Status HandleInfeed(HloInstruction* xla_infeed) override;
+  Status HandleRng(HloInstruction* random) override;
+  Status HandleSelect(HloInstruction* select) override;
+
+  Status EmitTargetElementLoop(
+      const HloInstruction& hlo,
+      const llvm_ir::ElementGenerator& body_emitter) override;
+
+  // Same as `EmitTargetElementLoop`, but in the given `thunk` rather than
+  // `LastThunk()`.
+  Status EmitTargetElementLoopInThunk(
+      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
+      KernelThunk* thunk);
+
+ private:
+  // Builds the appropriate thunk for the instruction `hlo` and returns the
+  // owning pointer to it. The caller needs to make sure `hlo` outlives the
+  // lifetime of the returned Thunk object.
+  std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
+
+  // Builds the prototype of the IR kernel for `inst` and adds it to the
+  // module.
+  llvm::Function* BuildKernelPrototype(
+      const HloInstruction& inst,
+      tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos);
+
+  // Emits the base pointers for `hlo` and its operands. `io_hlos` will store
+  // all input/output HLOs among `hlo` and its operands.
+  llvm::Function* EmitBasePointersForHloAndItsOperands(
+      const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos);
+
+  // EmitColumnReduction and EmitRowReduction emit code for column and row
+  // reduction of a matrix and/or 3D tensor. Row and column reductions have
+  // different memory access patterns, so for performance their
+  // implementations are significantly different.
+  //
+  // Emits code that reduces a matrix of shape [height x width] to a vector of
+  // [width]. Other parameters have the same meaning as those of
+  // `EmitReductionToVector`. Note that input shape might not be
+  // [height x width], but can be bitcast to [height x width] with "height"
+  // being the major dimension.
+  Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
+                             const Shape& input_shape,
+                             const llvm_ir::ElementGenerator& input_gen,
+                             const llvm_ir::ElementGenerator& init_value_gen,
+                             HloComputation* reducer);
+
+  // Emits code that reduces a 3D tensor of shape [depth x height x width] to
+  // a vector of shape [height]. Other parameters have the same meaning as
+  // those of `EmitReductionToVector`. Note that input shape might not be
+  // [depth x height x width], but can be bitcast to [depth x height x width]
+  // with "depth" being the most major dimension.
+  Status EmitRowReduction(int64 depth, int64 height, int64 width,
+                          HloInstruction* reduce, const Shape& input_shape,
+                          const llvm_ir::ElementGenerator& input_gen,
+                          const llvm_ir::ElementGenerator& init_value_gen,
+                          HloComputation* reducer);
+
+  // Emits code that reduces a tensor of arbitrary rank to a scalar.
+  Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
+                               const llvm_ir::ElementGenerator& input_gen,
+                               const llvm_ir::ElementGenerator& init_value_gen,
+                               HloComputation* reducer);
+
+  // Figures out whether `reduce` is a row or column reduction, and which
+  // dimensions to reduce, and calls either `EmitRowReduction` or
+  // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
+  // input array, which is the operand of the Reduce instruction if unfused or
+  // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
+  // generate elements of the input and the initial value. Other parameters
+  // mean the same as for `HandleReduce`.
+  //
+  // Prerequisite: `IsReductionToVector(*reduce)`
+  Status EmitReductionToVector(
+      HloInstruction* reduce, const Shape& input_shape,
+      const llvm_ir::ElementGenerator& input_gen,
+      const llvm_ir::ElementGenerator& init_value_gen,
+      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+      HloComputation* reducer);
+
+  // Emits code to initialize the buffer of `inst` in the given `thunk`.
+  Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
+
+  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
+  // caller needs to make sure `inst` outlives the lifetime of the returned
+  // Thunk object.
+  std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst);
+
+  // Returns a FftThunk that calls cuFFT to implement `inst`.
+  std::unique_ptr<FftThunk> BuildFftThunk(const HloInstruction* inst);
+
+  // Returns a GemmThunk that calls gemm to implement `inst`. The caller
+  // needs to make sure `inst` outlives the lifetime of the returned Thunk
+  // object.
+  std::unique_ptr<GemmThunk> BuildGemmThunk(const HloInstruction* inst);
+
+  // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
+  std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
+
+  // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
+  std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
+      const HloInstruction* inst);
+
+  // Returns an InfeedThunk that performs a device-to-device memcpy to
+  // implement `inst`.
+  std::unique_ptr<InfeedThunk> BuildInfeedThunk(const HloInstruction* inst);
+
+  // Returns a WhileThunk that invokes thunk sequences for the 'condition' and
+  // 'body' sub-computations of the while instruction 'hlo'.
+  std::unique_ptr<WhileThunk> BuildWhileThunk(const HloInstruction* hlo);
+
+  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
+  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
+  std::unique_ptr<ForThunk> BuildForThunk(const HloInstruction* hlo,
+                                          const int64 loop_limit);
+
+  // Returns a ConditionalThunk that executes the thunk sequence for
+  // 'true_computation' or 'false_computation' depending on the value of the
+  // predicate in the given conditional instruction.
+  std::unique_ptr<ConditionalThunk> BuildConditionalThunk(
+      const HloInstruction* hlo);
+
+  Status Postprocess(HloInstruction* hlo) override;
+
+  // Returns the last generated thunk.
+  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
+
+  // The thunk sequence this IrEmitter generates for the input computation.
+  std::unique_ptr<ThunkSequence> thunk_sequence_;
+
+  // The HloComputation that this IrEmitter emits code for.
+  const HloComputation* hlo_computation_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
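The row-versus-column distinction above drives the whole reduction code path,
so a tiny standalone illustration may help. This is not XLA code -- it merely
restates the documented contract: a reduction that keeps the minor-most
dimension behaves like the [height x width] -> [width] column case, while one
that reduces the minor-most dimension behaves like the
[depth x height x width] -> [height] row case. All names here are invented
for the example.

#include <cstdint>
#include <set>
#include <vector>

// Toy classifier mirroring the dispatch EmitReductionToVector documents.
// `dims` lists the input dimensions from major to minor (assuming the layout
// matches that order); `reduced` holds the indices of the reduced dimensions.
enum class ReductionKind { kColumn, kRow };

ReductionKind Classify(const std::vector<int64_t>& dims,
                       const std::set<int64_t>& reduced) {
  const int64_t minor_dim = static_cast<int64_t>(dims.size()) - 1;
  // If the minor-most dimension survives, neighboring input elements flow to
  // different outputs: the [height x width] -> [width] column-reduction case.
  // If it is reduced, neighboring input elements fold into the same output:
  // the [depth x height x width] -> [height] row-reduction case.
  return reduced.count(minor_dim) == 0 ? ReductionKind::kColumn
                                       : ReductionKind::kRow;
}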
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index 242062e..b3b6026 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -32,14 +32,23 @@ limitations under the License.
 
 namespace xla {
 
-// Unlike IrEmitter, this creates host functions which emit IR to generate the
-// output element at the given index. It is used to generate fused operations.
+// FusedIrEmitter is used to generate code for fusion nodes.
+//
+// Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM
+// Module, FusedIrEmitter is better understood as an "IR generator generator".
+// FusedIrEmitter recursively creates a generator (a host function) which the
+// compiler can invoke at a later time. Invoking the generator emits LLVM IR
+// that, when run, produces the value at a particular index of the output.
+//
+// After building this generator, the compiler creates a loop (or its moral
+// equivalent, e.g. a GPU kernel) and calls the generator from within the
+// loop. This generates code that produces each element of the output.
 //
 // This class handles both vanilla fusion and multi-output fusion. In the MOF
-// case, the fusion node ends with a kTuple instruction, and the root generator
-// returned by this emitter returns an LLVM struct with N elements, one for
-// each element of the arrays in the tuple. It follows that the arrays in the
-// tuple must have the same length.
+// case, the fusion node ends with a kTuple instruction, and the generator
+// created produces an LLVM struct with N elements, one for each element of
+// the arrays in the tuple. It follows that the arrays in the tuple must have
+// the same length.
 class FusedIrEmitter : public DfsHloVisitorWithDefault {
  public:
   using Generator = llvm_ir::ElementGenerator;
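To make "IR generator generator" concrete, here is a hedged usage sketch. The
traversal and the EmitTargetElementLoop call mirror declarations shown in this
diff; the accessor GetRootGenerator and the surrounding variables
(`parameter_arrays`, `elemental_emitter`, `fusion`) are recalled or invented
context, not part of the commit.

// Sketch: how a handler like IrEmitterUnnested::HandleFusion might use
// FusedIrEmitter. Nothing is written to the LLVM module until the final loop
// emission step.
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));

// The traversal yields a generator, not IR: invoking it with an output index
// emits the IR that computes the output element at that index.
const llvm_ir::ElementGenerator& generator = fused_emitter.GetRootGenerator();

// Wrapping the generator in a kernel-wide loop is what actually emits code
// for every element of the fusion's output.
return EmitTargetElementLoop(*fusion, generator);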