hdrs = [
"ir_emitter.h",
"ir_emitter_context.h",
+ "ir_emitter_nested.h",
+ "ir_emitter_unnested.h",
],
deps = [
":cudnn_convolution_runner",
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
limitations under the License.
==============================================================================*/
-// An XLA HLO graph may contain multiple computations. These computations
-// fall into two types, nested and unnested. We translate each nested
-// computation (e.g. the computation operand of a Map operator) to a device
-// function. For each unnested computation composed of top-level
-// HloInstructions, we generate a CUDA kernel for each HloInstruction.
-//
-// This file declares classes that translate an XLA HLO graph to LLVM IR for
-// GPUs. IrEmitterNested emits LLVM IR for nested computations, and
-// IrEmitterUnnested for unnested computations. The logic of emitting LLVM IR
-// for each individual HloInstruction is largely the same between these two
-// classes. Therefore, we implement the common logic in the Handle* functions in
-// the superclass IrEmitter.
-
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
namespace xla {
namespace gpu {
-// This class is the top-level API for the XLA HLO --> LLVM IR compiler.
-// It implements the DfsHloVisitor interface and emits an LLVM IR program that
-// implements the input HLO graph.
+// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
+//
+// There are two concrete subclasses of IrEmitter: IrEmitterNested and
+// IrEmitterUnnested. In the unnested variety, each HLO gets its own kernel
+// function, whereas in the nested version the whole computation is emitted as
+// one *non-kernel* function.
+//
+// In XLA, kernel functions never call other kernel functions. This means that
+// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
+// an HLO computation as a "subroutine" -- e.g. the HLO computation that
+// specifies how to reduce two elements -- then the subroutine computation must
+// be emitted using IrEmitterNested.
//
-// Note: if `T` is a subclass of `IrEmitter` and a handler is not overridden in
-// either `IrEmitter` or `T`, the handler in `DfsHloVisitorWithDefault`
-// calls `T::DefaultAction`.
+// Fusion nodes are a special case. A fusion node is emitted using
+// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
+// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
+// generator generator. See comments on that class.
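+//
+// For example (illustrative HLO, not from any particular module): given
+//
+//   %add (x: f32[], y: f32[]) -> f32[] { ... }
+//   %reduce = f32[10] reduce(f32[10,100] %input, f32[] %zero),
+//       dimensions={1}, to_apply=%add
+//
+// IrEmitterUnnested emits the kernel implementing %reduce, while IrEmitterNested
+// emits %add as a device function that the kernel calls.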
class IrEmitter : public DfsHloVisitorWithDefault {
public:
IrEmitter(const IrEmitter&) = delete;
IrEmitter& operator=(const IrEmitter&) = delete;
- // The following methods implement the DfsHloVisitorWithDefault interface.
Status DefaultAction(HloInstruction* hlo) override;
Status HandleConstant(HloInstruction* constant) override;
Status HandleBitcast(HloInstruction* bitcast) override;
std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
};
-// Emits LLVM IR for unnested computations. Each HloInstruction is translated to
-// a separate CUDA kernel. These kernels are inserted into the resultant module
-// sorted in reverse postorder of the XLA HLO graph.
-class IrEmitterUnnested : public IrEmitter {
- public:
- IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
- const HloComputation* hlo_computation,
- IrEmitterContext* ir_emitter_context);
- IrEmitterUnnested(const IrEmitterUnnested&) = delete;
- IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
-
- // Transfers the ownship of thunk_sequence_ out.
- std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
- return std::move(thunk_sequence_);
- }
-
- Status DefaultAction(HloInstruction* hlo) override;
-
- // IrEmitterUnnested handles the following instructions differently from
- // IrEmitter.
- Status HandleCopy(HloInstruction* copy) override;
- Status HandleConditional(HloInstruction* conditional) override;
- Status HandleConvolution(HloInstruction* convolution) override;
- Status HandleCustomCall(HloInstruction* custom_call) override;
- Status HandleDot(HloInstruction* dot) override;
- Status HandleFft(HloInstruction* fft) override;
- Status HandleFusion(HloInstruction* fusion) override;
- Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
- Status HandleReduce(HloInstruction* reduce) override;
- Status HandleSelectAndScatter(HloInstruction* instruction) override;
- Status HandleTuple(HloInstruction* tuple) override;
- Status HandleWhile(HloInstruction* xla_while) override;
- Status HandleInfeed(HloInstruction* xla_infeed) override;
- Status HandleRng(HloInstruction* random) override;
- Status HandleSelect(HloInstruction* select) override;
-
- Status EmitTargetElementLoop(
- const HloInstruction& hlo,
- const llvm_ir::ElementGenerator& body_emitter) override;
-
- // Same as `EmitTargetElementLoop`, but in given `thunk` rather than
- // `LastThunk()`.
- Status EmitTargetElementLoopInThunk(
- const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
- KernelThunk* thunk);
-
- private:
- // Builds the appropriate thunk for the instruction hlo and returns the owning
- // pointer to it. The caller needs to make sure `inst` outlives the lifetime
- // of the returned Thunk object.
- std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
-
- // Builds the prototype of the IR kernel for `inst` and adds it to the module.
- llvm::Function* BuildKernelPrototype(
- const HloInstruction& inst,
- tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos);
-
- // Emits the base pointers for `hlo` and its operands. `io_hlos` will store
- // all input/output HLOs among `hlo` and its operands.
- llvm::Function* EmitBasePointersForHloAndItsOperands(
- const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos);
-
- // EmitColumnReduction and EmitRowReduction emit code for column and row
- // reduction of a matrix and/or 3D tensor. Row and column reduction have
- // different memory access pattern, so for performance their implementations
- // are significantly different.
- //
- // Emits code that reduces a matrix of shape [height x width] to a vector of
- // [width]. Other parameters have the same meaning as those of
- // `EmitReductionToVector`. Note that input shape might not be
- // [height x width], but can be bitcast to [height x weight] with "height"
- // being the major dimension.
- Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
- const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
- HloComputation* reducer);
-
- // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
- // vector of shape [height]. Other parameters have the same meaning as those
- // of `EmitReductionToVector`. Note that input shape might not be
- // [depth x height x width], but can be bitcast to [depth x height x weight]
- // with "depth" being the most major dimension.
- Status EmitRowReduction(int64 depth, int64 height, int64 width,
- HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
- HloComputation* reducer);
-
- // Emits code that reduces a tensor of arbitrary rank to a scalar.
- Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
- HloComputation* reducer);
-
- // Figures out whether `reduce` is a row or column reduction, and which
- // dimensions to reduce, and calls either `EmitRowReduction` or
- // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
- // input array, which is the operand of the Reduce instruction if unfused or
- // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
- // generate elements of the input and the initial value. Other parameters mean
- // the same as for `HandleReduce`.
- //
- // Prerequisite: `IsReductionToVector(*reduce)`
- Status EmitReductionToVector(
- HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- HloComputation* reducer);
-
- // Emits code to initialize buffer of `inst` in given `thunk`.
- Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
-
- // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
- // caller needs to make sure `inst` outlives the lifetime of the returned
- // Thunk object.
- std::unique_ptr<Thunk> BuildKernelThunk(const HloInstruction* inst);
-
- // Returns a FftThunk that calls cuFFT to implement `inst`.
- std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
-
- // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
- // to make sure `inst` outlives the lifetime of the returned Thunk object.
- std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
-
- // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
- std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
-
- // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
- std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
- const HloInstruction* inst);
-
- // Returns an InfeedThunk that performs device-to-device memcpy to implement
- // `inst`.
- std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
-
- // Returns a WhileThunk that invokes thunk sequences for 'condition' and
- // 'body' sub-computations of while instruction 'hlo'.
- std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
-
- // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
- // sequence from the 'body' sub-computation of the while instruction 'hlo'.
- std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
- const int64 loop_limit);
-
- // Returns a ConditionalThunk that executes the thunk sequence for
- // 'true_computation' or 'false_computation' depending on the value of the
- // predicate in the given conditional instruction.
- std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
-
- Status Postprocess(HloInstruction* hlo) override;
-
- // Returns the last generated thunk.
- Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
-
- // The thunk sequence this IrEmitter generates for the input computation.
- std::unique_ptr<ThunkSequence> thunk_sequence_;
-
- // The HloComputation that this IrEmitter emits code for.
- const HloComputation* hlo_computation_;
-};
-
-// Emits LLVM IR for a nested computation to the resultant function.
-class IrEmitterNested : public IrEmitter {
- public:
- // Constructs an LLVM IR emitter for a nested HLO computation. `function` is
- // the containing IR function this emitter produces IR to. See
- // IrEmitter::IrEmitter for the meanings of other arguments.
- IrEmitterNested(const HloModuleConfig& hlo_module_config,
- const HloComputation& nested_computation,
- IrEmitterContext* ir_emitter_context);
- IrEmitterNested(const IrEmitterNested&) = delete;
- IrEmitterNested& operator=(const IrEmitterNested&) = delete;
-
- // Overrides the default empty implementation. Binds the given instruction
- // "parameter" with the parameter of the IR function.
- Status HandleParameter(HloInstruction* parameter) override;
-
- llvm::Function* GetEmittedFunction() const { return emitted_function_; }
-
- Status EmitTargetElementLoop(
- const HloInstruction& hlo,
- const llvm_ir::ElementGenerator& body_emitter) override;
-
- private:
- llvm::Function* EmitBasePointersForNestedComputation(
- const HloComputation& nested_computation,
- std::vector<const HloInstruction*>* io_hlos);
-
- llvm::Function* emitted_function_;
-};
-
} // namespace gpu
} // namespace xla
#include <memory>
#include <vector>
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
+
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
+
+#include "llvm/IR/Function.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits LLVM IR for a "nested computation" into a non-kernel device function.
+//
+// This is used to emit code for HloComputations that don't require a separate
+// kernel call. For example, IrEmitterNested is used to emit code for a kReduce
+// HLO's elementwise reduction computation. Notably, IrEmitterNested is *not*
+// used to emit code for fusion nodes -- fusion nodes use FusedIrEmitter, which
+// is a different beast altogether.
+//
+// IrEmitterNested generates a non-kernel function with the following
+// parameters:
+//
+// - N pointers to the buffers of each of the N parameters to the computation,
+// - a pointer to the output buffer of the computation, and
+// - a pointer to the top-level temp buffer.
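+//
+// For instance (an illustrative sketch; the exact LLVM types depend on the
+// shapes involved), a computation with two parameters might get a function
+// roughly like
+//
+//   void nested_fn(i8* param0, i8* param1, i8* output, i8* temp_buffer)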
+//
+class IrEmitterNested : public IrEmitter {
+ public:
+  // Constructs an LLVM IR emitter for `nested_computation`, which will be
+  // emitted as a non-kernel device function. See IrEmitter::IrEmitter for
+  // the meanings of the other arguments.
+ IrEmitterNested(const HloModuleConfig& hlo_module_config,
+ const HloComputation& nested_computation,
+ IrEmitterContext* ir_emitter_context);
+ IrEmitterNested(const IrEmitterNested&) = delete;
+ IrEmitterNested& operator=(const IrEmitterNested&) = delete;
+
+  // Overrides the default empty implementation. Binds the given `parameter`
+  // instruction to the corresponding parameter of the emitted IR function.
+ Status HandleParameter(HloInstruction* parameter) override;
+
+ llvm::Function* GetEmittedFunction() const { return emitted_function_; }
+
+ Status EmitTargetElementLoop(
+ const HloInstruction& hlo,
+ const llvm_ir::ElementGenerator& body_emitter) override;
+
+ private:
+ llvm::Function* EmitBasePointersForNestedComputation(
+ const HloComputation& nested_computation,
+ std::vector<const HloInstruction*>* io_hlos);
+
+ llvm::Function* emitted_function_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
#include <string>
#include <vector>
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
+
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits LLVM IR for an "unnested computation".
+//
+// An unnested computation is an HloComputation which you run by executing one
+// or more kernels for each HloInstruction it contains. Examples of unnested
+// computations:
+//
+// - An HloModule's root computation,
+// - The body of an HLO while loop,
+// - The true/false computation of an HLO conditional.
+//
+// Note the opportunity for confusion -- the while loop's computation is nested
+// within the root computation, but it's emitted using IrEmitterUnnested! Don't
+// think about it too hard.
+//
+// Examples of things that are not unnested computations:
+//
+// - The reducer of a kReduce HLO. This is emitted using IrEmitterNested.
+// - The body of a fusion node. IrEmitterUnnested emits the relevant code
+// within a kernel function using FusedIrEmitter. (FusedIrEmitter is not
+// really an IrEmitter, but is more an "IR generator generator".)
+//
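+// A rough usage sketch (the real call sites in the GPU compiler may differ;
+// `config`, `computation`, and `context` are illustrative names):
+//
+//   IrEmitterUnnested emitter(config, computation, &context);
+//   TF_RETURN_IF_ERROR(computation->Accept(&emitter));
+//   std::unique_ptr<ThunkSequence> thunks = emitter.ConsumeThunkSequence();
+//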
+class IrEmitterUnnested : public IrEmitter {
+ public:
+ IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
+ const HloComputation* hlo_computation,
+ IrEmitterContext* ir_emitter_context);
+ IrEmitterUnnested(const IrEmitterUnnested&) = delete;
+ IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
+
+  // Transfers the ownership of thunk_sequence_ out.
+ std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
+ return std::move(thunk_sequence_);
+ }
+
+ Status DefaultAction(HloInstruction* hlo) override;
+
+ // IrEmitterUnnested handles the following instructions differently from
+ // IrEmitter.
+ Status HandleCopy(HloInstruction* copy) override;
+ Status HandleConditional(HloInstruction* conditional) override;
+ Status HandleConvolution(HloInstruction* convolution) override;
+ Status HandleCustomCall(HloInstruction* custom_call) override;
+ Status HandleDot(HloInstruction* dot) override;
+ Status HandleFft(HloInstruction* fft) override;
+ Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
+ Status HandleReduce(HloInstruction* reduce) override;
+ Status HandleSelectAndScatter(HloInstruction* instruction) override;
+ Status HandleTuple(HloInstruction* tuple) override;
+ Status HandleWhile(HloInstruction* xla_while) override;
+ Status HandleInfeed(HloInstruction* xla_infeed) override;
+ Status HandleRng(HloInstruction* random) override;
+ Status HandleSelect(HloInstruction* select) override;
+
+ Status EmitTargetElementLoop(
+ const HloInstruction& hlo,
+ const llvm_ir::ElementGenerator& body_emitter) override;
+
+  // Same as `EmitTargetElementLoop`, but emits into the given `thunk` rather
+  // than `LastThunk()`.
+ Status EmitTargetElementLoopInThunk(
+ const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
+ KernelThunk* thunk);
+
+ private:
+  // Builds the appropriate thunk for the instruction `hlo` and returns the
+  // owning pointer to it. The caller needs to make sure `hlo` outlives the
+  // lifetime of the returned Thunk object.
+ std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
+
+ // Builds the prototype of the IR kernel for `inst` and adds it to the module.
+ llvm::Function* BuildKernelPrototype(
+ const HloInstruction& inst,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos);
+
+ // Emits the base pointers for `hlo` and its operands. `io_hlos` will store
+ // all input/output HLOs among `hlo` and its operands.
+ llvm::Function* EmitBasePointersForHloAndItsOperands(
+ const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos);
+
+ // EmitColumnReduction and EmitRowReduction emit code for column and row
+ // reduction of a matrix and/or 3D tensor. Row and column reduction have
+  // different memory access patterns, so for performance their implementations
+ // are significantly different.
+ //
+ // Emits code that reduces a matrix of shape [height x width] to a vector of
+ // [width]. Other parameters have the same meaning as those of
+ // `EmitReductionToVector`. Note that input shape might not be
+  // [height x width], but can be bitcast to [height x width] with "height"
+ // being the major dimension.
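+  //
+  // For example (illustrative shapes only): reducing f32[100,64] along
+  // dimension 0 to f32[64] is a column reduction with height=100 and
+  // width=64.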
+ Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
+ const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen,
+ HloComputation* reducer);
+
+ // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
+ // vector of shape [height]. Other parameters have the same meaning as those
+ // of `EmitReductionToVector`. Note that input shape might not be
+  // [depth x height x width], but can be bitcast to [depth x height x width]
+ // with "depth" being the most major dimension.
+ Status EmitRowReduction(int64 depth, int64 height, int64 width,
+ HloInstruction* reduce, const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen,
+ HloComputation* reducer);
+
+ // Emits code that reduces a tensor of arbitrary rank to a scalar.
+ Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen,
+ HloComputation* reducer);
+
+ // Figures out whether `reduce` is a row or column reduction, and which
+ // dimensions to reduce, and calls either `EmitRowReduction` or
+ // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
+ // input array, which is the operand of the Reduce instruction if unfused or
+ // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
+ // generate elements of the input and the initial value. Other parameters mean
+ // the same as for `HandleReduce`.
+ //
+ // Prerequisite: `IsReductionToVector(*reduce)`
+ Status EmitReductionToVector(
+ HloInstruction* reduce, const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reducer);
+
+  // Emits code to initialize the buffer of `inst` in the given `thunk`.
+ Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
+
+ // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
+ // caller needs to make sure `inst` outlives the lifetime of the returned
+ // Thunk object.
+ std::unique_ptr<Thunk> BuildKernelThunk(const HloInstruction* inst);
+
+ // Returns a FftThunk that calls cuFFT to implement `inst`.
+ std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
+
+ // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
+ // to make sure `inst` outlives the lifetime of the returned Thunk object.
+ std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
+
+ // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
+ std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
+
+ // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
+ std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
+ const HloInstruction* inst);
+
+ // Returns an InfeedThunk that performs device-to-device memcpy to implement
+ // `inst`.
+ std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
+
+ // Returns a WhileThunk that invokes thunk sequences for 'condition' and
+ // 'body' sub-computations of while instruction 'hlo'.
+ std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
+
+ // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
+ // sequence from the 'body' sub-computation of the while instruction 'hlo'.
+ std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
+ const int64 loop_limit);
+
+ // Returns a ConditionalThunk that executes the thunk sequence for
+ // 'true_computation' or 'false_computation' depending on the value of the
+ // predicate in the given conditional instruction.
+ std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
+
+ Status Postprocess(HloInstruction* hlo) override;
+
+ // Returns the last generated thunk.
+ Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
+
+ // The thunk sequence this IrEmitter generates for the input computation.
+ std::unique_ptr<ThunkSequence> thunk_sequence_;
+
+ // The HloComputation that this IrEmitter emits code for.
+ const HloComputation* hlo_computation_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
namespace xla {
-// Unlike IrEmitter, this creates host functions which emit IR to generate the
-// output element at the given index. It is used to generate fused operations.
+// FusedIrEmitter is used to generate code for fusion nodes.
+//
+// Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM
+// Module, FusedIrEmitter is better understood as an "IR generator generator".
+// FusedIrEmitter recursively creates a generator (a host function) which the
+// compiler can invoke at a later time. Invoking the generator emits LLVM IR
+// that, when run, produces the value at a particular index of the output.
+//
+// After building this generator, the compiler creates a loop (or its moral
+// equivalent, e.g. a GPU kernel) and calls the generator from within the loop.
+// This generates code that produces each element of the output.
//
// This class handles both vanilla fusion and multi-output fusion. In the MOF
-// case, the fusion node ends with a kTuple instruction, and the root generator
-// returned by this emitter returns an LLVM struct with N elements, one for each
-// element of the arrays in the tuple. It follows that the arrays in the tuple
-// must have the same length.
+// case, the fusion node ends with a kTuple instruction, and the root generator
+// this emitter creates produces an LLVM struct with N elements, one for each
+// of the N arrays in the tuple. It follows that the arrays in the tuple must
+// have the same length.
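+//
+// A rough usage sketch (constructor arguments elided; assumes the class
+// exposes the root generator via something like GetRootGenerator()):
+//
+//   FusedIrEmitter fused_emitter(/*...*/);
+//   TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
+//   llvm_ir::ElementGenerator gen = fused_emitter.GetRootGenerator();
+//   // Inside the kernel loop, gen(index) emits the IR that computes the
+//   // output element at `index`.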
class FusedIrEmitter : public DfsHloVisitorWithDefault {
public:
using Generator = llvm_ir::ElementGenerator;