[XLA:GPU] Split IrEmitter{Unn,N}ested out of ir_emitter.h.
authorJustin Lebar <jlebar@google.com>
Tue, 6 Feb 2018 01:11:39 +0000 (17:11 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Feb 2018 01:15:26 +0000 (17:15 -0800)
Also add a bunch of clarifying comments.

PiperOrigin-RevId: 184610674

tensorflow/compiler/xla/service/gpu/BUILD
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
tensorflow/compiler/xla/service/gpu/ir_emitter.cc
tensorflow/compiler/xla/service/gpu/ir_emitter.h
tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h [new file with mode: 0644]
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h [new file with mode: 0644]
tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h

index 7df01f7..9da4fb9 100644 (file)
@@ -129,6 +129,8 @@ cc_library(
     hdrs = [
         "ir_emitter.h",
         "ir_emitter_context.h",
+        "ir_emitter_nested.h",
+        "ir_emitter_unnested.h",
     ],
     deps = [
         ":cudnn_convolution_runner",
index 12ec266..28ebd03 100644 (file)
@@ -47,8 +47,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
index affd2ff..a3df67a 100644 (file)
@@ -27,6 +27,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
 #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
index 9031a83..b0accc0 100644 (file)
@@ -13,19 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// An XLA HLO graph may contain multiple computations. These computations
-// fall into two types, nested and unnested. We translate each nested
-// computation (e.g. the computation operand of a Map operator) to a device
-// function. For each unnested computation composed of top-level
-// HloInstructions, we generate a CUDA kernel for each HloInstruction.
-//
-// This file declares classes that translate an XLA HLO graph to LLVM IR for
-// GPUs. IrEmitterNested emits LLVM IR for nested computations, and
-// IrEmitterUnnested for unnested computations. The logic of emitting LLVM IR
-// for each individual HloInstruction is largely the same between these two
-// classes. Therefore, we implement the common logic in the Handle* functions in
-// the superclass IrEmitter.
-
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
 
@@ -60,19 +47,28 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-// This class is the top-level API for the XLA HLO --> LLVM IR compiler.
-// It implements the DfsHloVisitor interface and emits an LLVM IR program that
-// implements the input HLO graph.
+// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
+//
+// There are two concrete subclasses of IrEmitter: IrEmitterNested and
+// IrEmitterUnnested.  In the unnested variety, each HLO gets its own kernel
+// function, whereas in the nested version the whole computation is emitted as
+// one *non-kernel* function.
+//
+// In XLA, kernel functions never call other kernel functions.  This means that
+// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
+// an HLO computation as a "subroutine" -- e.g. the HLO computation that
+// specifies how to reduce two elements -- then the subroutine computation must
+// be emitted using IrEmitterNested.
 //
-// Note: if `T` is a subclass of `IrEmitter` and a handler is not overridden in
-//       either `IrEmitter` or `T`, the handler in `DfsHloVisitorWithDefault`
-//       calls `T::DefaultAction`.
+// Fusion nodes are a special case.  A fusion node is emitted using
+// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
+// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
+// generator generator.  See comments on that class.
 class IrEmitter : public DfsHloVisitorWithDefault {
  public:
   IrEmitter(const IrEmitter&) = delete;
   IrEmitter& operator=(const IrEmitter&) = delete;
 
-  // The following methods implement the DfsHloVisitorWithDefault interface.
   Status DefaultAction(HloInstruction* hlo) override;
   Status HandleConstant(HloInstruction* constant) override;
   Status HandleBitcast(HloInstruction* bitcast) override;
@@ -217,199 +213,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
   std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
 };
 
-// Emits LLVM IR for unnested computations. Each HloInstruction is translated to
-// a separate CUDA kernel. These kernels are inserted into the resultant module
-// sorted in reverse postorder of the XLA HLO graph.
-class IrEmitterUnnested : public IrEmitter {
- public:
-  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
-                    const HloComputation* hlo_computation,
-                    IrEmitterContext* ir_emitter_context);
-  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
-  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
-
-  // Transfers the ownship of thunk_sequence_ out.
-  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
-    return std::move(thunk_sequence_);
-  }
-
-  Status DefaultAction(HloInstruction* hlo) override;
-
-  // IrEmitterUnnested handles the following instructions differently from
-  // IrEmitter.
-  Status HandleCopy(HloInstruction* copy) override;
-  Status HandleConditional(HloInstruction* conditional) override;
-  Status HandleConvolution(HloInstruction* convolution) override;
-  Status HandleCustomCall(HloInstruction* custom_call) override;
-  Status HandleDot(HloInstruction* dot) override;
-  Status HandleFft(HloInstruction* fft) override;
-  Status HandleFusion(HloInstruction* fusion) override;
-  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
-  Status HandleReduce(HloInstruction* reduce) override;
-  Status HandleSelectAndScatter(HloInstruction* instruction) override;
-  Status HandleTuple(HloInstruction* tuple) override;
-  Status HandleWhile(HloInstruction* xla_while) override;
-  Status HandleInfeed(HloInstruction* xla_infeed) override;
-  Status HandleRng(HloInstruction* random) override;
-  Status HandleSelect(HloInstruction* select) override;
-
-  Status EmitTargetElementLoop(
-      const HloInstruction& hlo,
-      const llvm_ir::ElementGenerator& body_emitter) override;
-
-  // Same as `EmitTargetElementLoop`, but in given `thunk` rather than
-  // `LastThunk()`.
-  Status EmitTargetElementLoopInThunk(
-      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
-      KernelThunk* thunk);
-
- private:
-  // Builds the appropriate thunk for the instruction hlo and returns the owning
-  // pointer to it. The caller needs to make sure `inst` outlives the lifetime
-  // of the returned Thunk object.
-  std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
-
-  // Builds the prototype of the IR kernel for `inst` and adds it to the module.
-  llvm::Function* BuildKernelPrototype(
-      const HloInstruction& inst,
-      tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos);
-
-  // Emits the base pointers for `hlo` and its operands. `io_hlos` will store
-  // all input/output HLOs among `hlo` and its operands.
-  llvm::Function* EmitBasePointersForHloAndItsOperands(
-      const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos);
-
-  // EmitColumnReduction and EmitRowReduction emit code for column and row
-  // reduction of a matrix and/or 3D tensor. Row and column reduction have
-  // different memory access pattern, so for performance their implementations
-  // are significantly different.
-  //
-  // Emits code that reduces a matrix of shape [height x width] to a vector of
-  // [width]. Other parameters have the same meaning as those of
-  // `EmitReductionToVector`. Note that input shape might not be
-  // [height x width], but can be bitcast to [height x weight] with "height"
-  // being the major dimension.
-  Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
-                             const Shape& input_shape,
-                             const llvm_ir::ElementGenerator& input_gen,
-                             const llvm_ir::ElementGenerator& init_value_gen,
-                             HloComputation* reducer);
-
-  // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
-  // vector of shape [height]. Other parameters have the same meaning as those
-  // of `EmitReductionToVector`. Note that input shape might not be
-  // [depth x height x width], but can be bitcast to [depth x height x weight]
-  // with "depth" being the most major dimension.
-  Status EmitRowReduction(int64 depth, int64 height, int64 width,
-                          HloInstruction* reduce, const Shape& input_shape,
-                          const llvm_ir::ElementGenerator& input_gen,
-                          const llvm_ir::ElementGenerator& init_value_gen,
-                          HloComputation* reducer);
-
-  // Emits code that reduces a tensor of arbitrary rank to a scalar.
-  Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
-                               const llvm_ir::ElementGenerator& input_gen,
-                               const llvm_ir::ElementGenerator& init_value_gen,
-                               HloComputation* reducer);
-
-  // Figures out whether `reduce` is a row or column reduction, and which
-  // dimensions to reduce, and calls either `EmitRowReduction` or
-  // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
-  // input array, which is the operand of the Reduce instruction if unfused or
-  // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
-  // generate elements of the input and the initial value. Other parameters mean
-  // the same as for `HandleReduce`.
-  //
-  // Prerequisite: `IsReductionToVector(*reduce)`
-  Status EmitReductionToVector(
-      HloInstruction* reduce, const Shape& input_shape,
-      const llvm_ir::ElementGenerator& input_gen,
-      const llvm_ir::ElementGenerator& init_value_gen,
-      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
-      HloComputation* reducer);
-
-  // Emits code to initialize buffer of `inst` in given `thunk`.
-  Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
-
-  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
-  // caller needs to make sure `inst` outlives the lifetime of the returned
-  // Thunk object.
-  std::unique_ptr<Thunk> BuildKernelThunk(const HloInstruction* inst);
-
-  // Returns a FftThunk that calls cuFFT to implement `inst`.
-  std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
-
-  // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
-  // to make sure `inst` outlives the lifetime of the returned Thunk object.
-  std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
-
-  // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
-  std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
-
-  // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
-  std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
-      const HloInstruction* inst);
-
-  // Returns an InfeedThunk that performs device-to-device memcpy to implement
-  // `inst`.
-  std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
-
-  // Returns a WhileThunk that invokes thunk sequences for 'condition' and
-  // 'body' sub-computations of while instruction 'hlo'.
-  std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
-
-  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
-  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
-  std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
-                                       const int64 loop_limit);
-
-  // Returns a ConditionalThunk that executes the thunk sequence for
-  // 'true_computation' or 'false_computation' depending on the value of the
-  // predicate in the given conditional instruction.
-  std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
-
-  Status Postprocess(HloInstruction* hlo) override;
-
-  // Returns the last generated thunk.
-  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
-
-  // The thunk sequence this IrEmitter generates for the input computation.
-  std::unique_ptr<ThunkSequence> thunk_sequence_;
-
-  // The HloComputation that this IrEmitter emits code for.
-  const HloComputation* hlo_computation_;
-};
-
-// Emits LLVM IR for a nested computation to the resultant function.
-class IrEmitterNested : public IrEmitter {
- public:
-  // Constructs an LLVM IR emitter for a nested HLO computation. `function` is
-  // the containing IR function this emitter produces IR to. See
-  // IrEmitter::IrEmitter for the meanings of other arguments.
-  IrEmitterNested(const HloModuleConfig& hlo_module_config,
-                  const HloComputation& nested_computation,
-                  IrEmitterContext* ir_emitter_context);
-  IrEmitterNested(const IrEmitterNested&) = delete;
-  IrEmitterNested& operator=(const IrEmitterNested&) = delete;
-
-  // Overrides the default empty implementation. Binds the given instruction
-  // "parameter" with the parameter of the IR function.
-  Status HandleParameter(HloInstruction* parameter) override;
-
-  llvm::Function* GetEmittedFunction() const { return emitted_function_; }
-
-  Status EmitTargetElementLoop(
-      const HloInstruction& hlo,
-      const llvm_ir::ElementGenerator& body_emitter) override;
-
- private:
-  llvm::Function* EmitBasePointersForNestedComputation(
-      const HloComputation& nested_computation,
-      std::vector<const HloInstruction*>* io_hlos);
-
-  llvm::Function* emitted_function_;
-};
-
 }  // namespace gpu
 }  // namespace xla
 
index 5225ff3..71aada0 100644 (file)
@@ -16,12 +16,13 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
+
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h
new file mode 100644 (file)
index 0000000..ca11cf2
--- /dev/null
@@ -0,0 +1,72 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
+
+#include "llvm/IR/Function.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits LLVM IR for a "nested computation" into a non-kernel device function.
+//
+// This is used to emit code for HloComputations that don't require a separate
+// kernel call.  For example, IrEmitterNested is used to emit code for a kReduce
+// HLO's elementwise reduction computation.  Notably, IrEmitterNested is *not*
+// used to emit code for fusion nodes -- fusion nodes use FusedIrEmitter, which
+// is a different beast altogether.
+//
+// IrEmitterNested generates a non-kernel function with the following
+// parameters:
+//
+//   - N pointers to the buffers of each of the N parameters to the computation,
+//   - a pointer to the output buffer of the computation, and
+//   - a pointer to the top-level temp buffer.
+//
+class IrEmitterNested : public IrEmitter {
+ public:
+  // Constructs an LLVM IR emitter for a nested HLO computation. `function` is
+  // the containing IR function this emitter produces IR to. See
+  // IrEmitter::IrEmitter for the meanings of other arguments.
+  IrEmitterNested(const HloModuleConfig& hlo_module_config,
+                  const HloComputation& nested_computation,
+                  IrEmitterContext* ir_emitter_context);
+  IrEmitterNested(const IrEmitterNested&) = delete;
+  IrEmitterNested& operator=(const IrEmitterNested&) = delete;
+
+  // Overrides the default empty implementation. Binds the given instruction
+  // "parameter" with the parameter of the IR function.
+  Status HandleParameter(HloInstruction* parameter) override;
+
+  llvm::Function* GetEmittedFunction() const { return emitted_function_; }
+
+  Status EmitTargetElementLoop(
+      const HloInstruction& hlo,
+      const llvm_ir::ElementGenerator& body_emitter) override;
+
+ private:
+  llvm::Function* EmitBasePointersForNestedComputation(
+      const HloComputation& nested_computation,
+      std::vector<const HloInstruction*>* io_hlos);
+
+  llvm::Function* emitted_function_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_
index a4847f6..08fea34 100644 (file)
@@ -17,6 +17,8 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
+
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
@@ -40,7 +42,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
new file mode 100644 (file)
index 0000000..56ab820
--- /dev/null
@@ -0,0 +1,209 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits LLVM IR for an "unnested computation".
+//
+// An unnested computation is an HloComputation which you run by executing one
+// or more kernels for each HloInstruction it contains.  Examples of unnested
+// computations:
+//
+//  - An HloModule's root computation,
+//  - The body of an HLO while loop,
+//  - The true/false computation of an HLO conditional.
+//
+// Note the opportunity for confusion -- the while loop's computation is nested
+// within the root computation, but it's emitted using IrEmitterUnnested!  Don't
+// think about it too hard.
+//
+// Examples of things that are not unnested computations:
+//
+//  - The reducer of a kReduce HLO.  This is emited using IrEmitterNested.
+//  - The body of a fusion node.  IrEmitterUnenested emits the relevant code
+//    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
+//    really an IrEmitter, but is more an "IR generator generator".)
+//
+class IrEmitterUnnested : public IrEmitter {
+ public:
+  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
+                    const HloComputation* hlo_computation,
+                    IrEmitterContext* ir_emitter_context);
+  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
+  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
+
+  // Transfers the ownship of thunk_sequence_ out.
+  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
+    return std::move(thunk_sequence_);
+  }
+
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  // IrEmitterUnnested handles the following instructions differently from
+  // IrEmitter.
+  Status HandleCopy(HloInstruction* copy) override;
+  Status HandleConditional(HloInstruction* conditional) override;
+  Status HandleConvolution(HloInstruction* convolution) override;
+  Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status HandleDot(HloInstruction* dot) override;
+  Status HandleFft(HloInstruction* fft) override;
+  Status HandleFusion(HloInstruction* fusion) override;
+  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
+  Status HandleReduce(HloInstruction* reduce) override;
+  Status HandleSelectAndScatter(HloInstruction* instruction) override;
+  Status HandleTuple(HloInstruction* tuple) override;
+  Status HandleWhile(HloInstruction* xla_while) override;
+  Status HandleInfeed(HloInstruction* xla_infeed) override;
+  Status HandleRng(HloInstruction* random) override;
+  Status HandleSelect(HloInstruction* select) override;
+
+  Status EmitTargetElementLoop(
+      const HloInstruction& hlo,
+      const llvm_ir::ElementGenerator& body_emitter) override;
+
+  // Same as `EmitTargetElementLoop`, but in given `thunk` rather than
+  // `LastThunk()`.
+  Status EmitTargetElementLoopInThunk(
+      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
+      KernelThunk* thunk);
+
+ private:
+  // Builds the appropriate thunk for the instruction hlo and returns the owning
+  // pointer to it. The caller needs to make sure `inst` outlives the lifetime
+  // of the returned Thunk object.
+  std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
+
+  // Builds the prototype of the IR kernel for `inst` and adds it to the module.
+  llvm::Function* BuildKernelPrototype(
+      const HloInstruction& inst,
+      tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos);
+
+  // Emits the base pointers for `hlo` and its operands. `io_hlos` will store
+  // all input/output HLOs among `hlo` and its operands.
+  llvm::Function* EmitBasePointersForHloAndItsOperands(
+      const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos);
+
+  // EmitColumnReduction and EmitRowReduction emit code for column and row
+  // reduction of a matrix and/or 3D tensor. Row and column reduction have
+  // different memory access pattern, so for performance their implementations
+  // are significantly different.
+  //
+  // Emits code that reduces a matrix of shape [height x width] to a vector of
+  // [width]. Other parameters have the same meaning as those of
+  // `EmitReductionToVector`. Note that input shape might not be
+  // [height x width], but can be bitcast to [height x weight] with "height"
+  // being the major dimension.
+  Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
+                             const Shape& input_shape,
+                             const llvm_ir::ElementGenerator& input_gen,
+                             const llvm_ir::ElementGenerator& init_value_gen,
+                             HloComputation* reducer);
+
+  // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
+  // vector of shape [height]. Other parameters have the same meaning as those
+  // of `EmitReductionToVector`. Note that input shape might not be
+  // [depth x height x width], but can be bitcast to [depth x height x weight]
+  // with "depth" being the most major dimension.
+  Status EmitRowReduction(int64 depth, int64 height, int64 width,
+                          HloInstruction* reduce, const Shape& input_shape,
+                          const llvm_ir::ElementGenerator& input_gen,
+                          const llvm_ir::ElementGenerator& init_value_gen,
+                          HloComputation* reducer);
+
+  // Emits code that reduces a tensor of arbitrary rank to a scalar.
+  Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
+                               const llvm_ir::ElementGenerator& input_gen,
+                               const llvm_ir::ElementGenerator& init_value_gen,
+                               HloComputation* reducer);
+
+  // Figures out whether `reduce` is a row or column reduction, and which
+  // dimensions to reduce, and calls either `EmitRowReduction` or
+  // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
+  // input array, which is the operand of the Reduce instruction if unfused or
+  // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
+  // generate elements of the input and the initial value. Other parameters mean
+  // the same as for `HandleReduce`.
+  //
+  // Prerequisite: `IsReductionToVector(*reduce)`
+  Status EmitReductionToVector(
+      HloInstruction* reduce, const Shape& input_shape,
+      const llvm_ir::ElementGenerator& input_gen,
+      const llvm_ir::ElementGenerator& init_value_gen,
+      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+      HloComputation* reducer);
+
+  // Emits code to initialize buffer of `inst` in given `thunk`.
+  Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
+
+  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
+  // caller needs to make sure `inst` outlives the lifetime of the returned
+  // Thunk object.
+  std::unique_ptr<Thunk> BuildKernelThunk(const HloInstruction* inst);
+
+  // Returns a FftThunk that calls cuFFT to implement `inst`.
+  std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
+
+  // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
+  // to make sure `inst` outlives the lifetime of the returned Thunk object.
+  std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
+
+  // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
+  std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
+
+  // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
+  std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
+      const HloInstruction* inst);
+
+  // Returns an InfeedThunk that performs device-to-device memcpy to implement
+  // `inst`.
+  std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
+
+  // Returns a WhileThunk that invokes thunk sequences for 'condition' and
+  // 'body' sub-computations of while instruction 'hlo'.
+  std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
+
+  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
+  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
+  std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
+                                       const int64 loop_limit);
+
+  // Returns a ConditionalThunk that executes the thunk sequence for
+  // 'true_computation' or 'false_computation' depending on the value of the
+  // predicate in the given conditional instruction.
+  std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
+
+  Status Postprocess(HloInstruction* hlo) override;
+
+  // Returns the last generated thunk.
+  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
+
+  // The thunk sequence this IrEmitter generates for the input computation.
+  std::unique_ptr<ThunkSequence> thunk_sequence_;
+
+  // The HloComputation that this IrEmitter emits code for.
+  const HloComputation* hlo_computation_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
index 242062e..b3b6026 100644 (file)
@@ -32,14 +32,23 @@ limitations under the License.
 
 namespace xla {
 
-// Unlike IrEmitter, this creates host functions which emit IR to generate the
-// output element at the given index. It is used to generate fused operations.
+// FusedIrEmitter is used to generate code for fusion nodes.
+//
+// Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM
+// Module, FusedIrEmitter is better understood as "IR generator generator".
+// FusedIrEmitter recursively creates a generator (a host function) which the
+// compiler can invoke at a later time.  Invoking the generator emits LLVM IR
+// that, when run, produces the value at a particular index of the output.
+//
+// After building this generator, the compiler creates a loop (or its moral
+// equivalent, e.g. a GPU kernel) and calls the generator from within the loop.
+// This generates code that produces each element of the output.
 //
 // This class handles both vanilla fusion and multi-output fusion.  In the MOF
-// case, the fusion node ends with a kTuple instruction, and the root generator
-// returned by this emitter returns an LLVM struct with N elements, one for each
-// element of the arrays in the tuple.  It follows that the arrays in the tuple
-// must have the same length.
+// case, the fusion node ends with a kTuple instruction, and the generator
+// created produces an LLVM struct with N elements, one for each element of the
+// arrays in the tuple.  It follows that the arrays in the tuple must have the
+// same length.
 class FusedIrEmitter : public DfsHloVisitorWithDefault {
  public:
   using Generator = llvm_ir::ElementGenerator;