[XLA] Add a DeviceAllocator* argument to compilation.
author     Justin Lebar <jlebar@google.com>
           Sat, 27 Jan 2018 01:12:23 +0000 (17:12 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Sat, 27 Jan 2018 01:16:06 +0000 (17:16 -0800)
In a later change, the GPU backend will use this allocator to reserve
scratch memory when trying out different convolution algorithms during
compilation.

PiperOrigin-RevId: 183469579
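
Below is a rough, hypothetical sketch of the kind of autotuning the GPU backend will do with this allocator (the autotuner itself lands in the later change; PickBestConvAlgorithm, CandidateAlgorithmsFor, ScratchBytesFor, and RunConvOnce are illustrative stand-ins, and the DeviceMemoryAllocator calls are lightly paraphrased):

// Hypothetical sketch, not part of this change.  Shows the intended pattern:
// reserve scratch memory through the allocator threaded into compilation,
// time each candidate algorithm, and free everything before returning.
#include <limits>

StatusOr<int64> PickBestConvAlgorithm(
    const HloInstruction* conv,
    perftools::gputools::StreamExecutor* stream_exec,
    DeviceMemoryAllocator* device_allocator) {
  int64 best_algorithm = -1;
  double best_seconds = std::numeric_limits<double>::infinity();
  for (int64 algorithm : CandidateAlgorithmsFor(conv)) {
    // Scratch space comes from the caller-provided allocator: on GPU,
    // TensorFlow typically owns most of device memory already, so allocating
    // behind its back is likely to fail.
    TF_ASSIGN_OR_RETURN(
        auto scratch,
        device_allocator->Allocate(stream_exec->device_ordinal(),
                                   ScratchBytesFor(conv, algorithm)));
    double seconds = RunConvOnce(conv, algorithm, scratch, stream_exec);
    if (seconds < best_seconds) {
      best_seconds = seconds;
      best_algorithm = algorithm;
    }
    // Per the contract documented in compiler.h below, anything allocated
    // during compilation must be deallocated before the compile step returns.
    TF_RETURN_IF_ERROR(device_allocator->Deallocate(
        stream_exec->device_ordinal(), &scratch));
  }
  return best_algorithm;
}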

23 files changed:
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/xla_compilation_cache.cc
tensorflow/compiler/tf2xla/xla_compiler.h
tensorflow/compiler/xla/client/local_client.cc
tensorflow/compiler/xla/client/local_client.h
tensorflow/compiler/xla/service/compiler.h
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/cpu/cpu_compiler.h
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
tensorflow/compiler/xla/service/gpu/gpu_compiler.h
tensorflow/compiler/xla/service/hlo_runner.cc
tensorflow/compiler/xla/service/interpreter/compiler.cc
tensorflow/compiler/xla/service/interpreter/compiler.h
tensorflow/compiler/xla/service/llvm_compiler.cc
tensorflow/compiler/xla/service/llvm_compiler.h
tensorflow/compiler/xla/service/local_service.cc
tensorflow/compiler/xla/service/local_service.h
tensorflow/compiler/xla/service/service.cc
tensorflow/compiler/xla/service/service.h
tensorflow/compiler/xla/tests/codegen_test_base.cc
tensorflow/compiler/xla/tests/llvm_compiler_test.cc
tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
tensorflow/compiler/xla/tools/dumped_computation_to_text.cc

diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 4842877..1d7bd22 100644
@@ -248,12 +248,16 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
+  // Builds an XLA allocator for the device.
+  XlaAllocator xla_allocator(client->platform(), ctx);
+
   XlaCompiler::Options options;
   options.client = client;
   options.device_type = &cache->device_type();
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
+  options.device_allocator = &xla_allocator;
 
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
@@ -264,9 +268,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   VLOG(1) << "Executing XLA Computation...";
 
-  // Builds an XLA allocator for the device.
-  XlaAllocator xla_allocator(client->platform(), ctx);
-
   std::unique_ptr<xla::ShapedBuffer> output;
   // Build xla::ShapedBuffers that point directly to the Tensor buffers.
   std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index bfff52c..21d3a54 100644
@@ -223,6 +223,7 @@ Status XlaCompilationCache::BuildExecutable(
   xla::ExecutableBuildOptions build_options;
   build_options.set_device_ordinal(client_->default_device_ordinal());
   build_options.set_result_layout(result.xla_output_shape);
+  build_options.set_device_allocator(options.device_allocator);
 
   auto compile_result =
       client_->Compile(*result.computation, argument_layouts, build_options);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 6a46e54..30d3c05 100644
@@ -235,6 +235,19 @@ class XlaCompiler {
     // device is created, and can be used to create metadata objects
     // that can be accessed by XLA op kernels.
     std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
+
+    // If not nullptr, this memory allocator can be used by the compiler for
+    // temporary allocations it might want to make during compilation.
+    //
+    // For example, the compiler may want to try out different algorithms and
+    // choose the fastest one, and it might run those algorithms over buffers
+    // created using this allocator.
+    //
+    // The compiler can function correctly without an explicit allocator given
+    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
+    // allocate most or all available memory on the device, leaving none for the
+    // compiler to access, unless it can use TensorFlow's allocator.
+    xla::DeviceMemoryAllocator* device_allocator = nullptr;
   };
 
   explicit XlaCompiler(Options options);
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index fbeedfc..e45787f 100644
@@ -49,6 +49,16 @@ const Shape* ExecutableBuildOptions::result_layout() const {
   return result_layout_set_ ? &result_layout_ : nullptr;
 }
 
+ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator(
+    DeviceMemoryAllocator* allocator) {
+  device_allocator_ = allocator;
+  return *this;
+}
+
+DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
+  return device_allocator_;
+}
+
 namespace {
 StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
                                                    Backend* backend) {
@@ -270,10 +280,11 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
   int device_ordinal = options.device_ordinal() == -1
                            ? default_device_ordinal()
                            : options.device_ordinal();
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
-                      local_service_->CompileExecutable(
-                          computation.handle(), argument_layouts,
-                          options.result_layout(), device_ordinal));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Executable> executable,
+      local_service_->CompileExecutable(computation.handle(), argument_layouts,
+                                        options.result_layout(), device_ordinal,
+                                        options.device_allocator()));
   return WrapUnique(new LocalExecutable(std::move(executable),
                                         local_service_->mutable_backend(),
                                         device_ordinal, options));
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 19fd14f..843ad7a 100644
@@ -53,10 +53,22 @@ class ExecutableBuildOptions {
   ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
   const Shape* result_layout() const;
 
+  // If set, this specifies an allocator that can be used to allocate temporary
+  // space on the device during compilation.  For example, the compiler might
+  // want to run various algorithms on the device and pick the fastest one -- it
+  // might allocate buffers for use by these algorithms using this allocator.
+  //
+  // This does not need to be the same as the DeviceMemoryAllocator passed when
+  // running the executable.
+  ExecutableBuildOptions& set_device_allocator(
+      DeviceMemoryAllocator* allocator);
+  DeviceMemoryAllocator* device_allocator() const;
+
  private:
   int device_ordinal_ = -1;
   Shape result_layout_;
   bool result_layout_set_ = false;
+  DeviceMemoryAllocator* device_allocator_ = nullptr;
 };
 
 class LocalExecutable {
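
Putting the two local-client pieces together, a caller opts in roughly like this (a minimal sketch; client, computation, argument_layouts, and allocator are assumed to already exist in the caller's scope, as in xla_compilation_cache.cc above):

xla::ExecutableBuildOptions build_options;
build_options.set_device_ordinal(client->default_device_ordinal());
// Optional: lets the compiler allocate temporary device memory while it
// compiles.  Leaving this unset is still valid; the compiler simply has no
// allocator to draw scratch space from.
build_options.set_device_allocator(allocator);
auto executable_or =
    client->Compile(computation, argument_layouts, build_options);
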
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index fc67330..74fd24e 100644
@@ -72,8 +72,18 @@ class AotCompilationOptions {
   // Returns the ID of the platform to which these options apply.
   virtual perftools::gputools::Platform::Id PlatformId() const = 0;
 
+  // Optional allocator that may be used for allocating temp space on the device
+  // during compilation.
+  DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
+  void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
+    device_allocator_ = device_allocator;
+  }
+
  protected:
   AotCompilationOptions() = default;
+
+ private:
+  DeviceMemoryAllocator* device_allocator_ = nullptr;
 };
 
 // Abstract compiler interface that is subclassed for compilation on a
@@ -99,9 +109,16 @@ class Compiler {
 
   // Runs Hlo passes to optimize the given Hlo module, returns the optimized
   // module.
+  //
+  // If device_allocator is not null, the compiler may use it to allocate temp
+  // space on the device for use during compilation.  For example, the compiler
+  // may allocate buffers on the device and then run variants of a given
+  // algorithm over those buffers, to see which variant is fastest.  Any space
+  // allocated should be deallocated before this function returns.
   virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* executor) = 0;
+      perftools::gputools::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) = 0;
 
   // Compiles the HLO module for execution on a device given by the executor,
   // and returns an executable object or an error status. No HLO passes are
@@ -112,21 +129,27 @@ class Compiler {
   // The compiler may optionally specialize to the individual device
   // (not just type of device) indicated by the executor.
   //
+  // device_allocator is optional; see RunHloPasses.
+  //
   // Use the overload below to compile computations that run in parallel.
   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* executor) = 0;
+      perftools::gputools::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) = 0;
 
   // Compiles a set of HLO modules that can run in parallel, potentially
   // communicating data between the modules, and returns a corresponding
   // sequence of executable objects.
   //
+  // device_allocator is optional; see RunHloPasses.
+  //
   // TODO(b/68666782): Remove this method after adding support for multiple
   // modules to RunHloPasses and RunBackends.
   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::vector<std::unique_ptr<HloModule>> modules,
       std::vector<std::vector<perftools::gputools::StreamExecutor*>>
-          stream_exec) = 0;
+          stream_exec,
+      DeviceMemoryAllocator* device_allocator) = 0;
 
   // Compiles the HLO module for ahead-of-time execution.  This is intended for
   // use in static compilation.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 33af77e..3fdb3d5 100644
@@ -437,7 +437,8 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) {
 
 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module,
-    perftools::gputools::StreamExecutor* /*stream_exec*/) {
+    perftools::gputools::StreamExecutor* /*stream_exec*/,
+    DeviceMemoryAllocator* /*device_allocator*/) {
   VLOG(2) << "Before optimization:";
   XLA_VLOG_LINES(2, module->ToString());
 
@@ -450,7 +451,8 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
 
 StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
     std::unique_ptr<HloModule> module,
-    perftools::gputools::StreamExecutor* stream_exec) {
+    perftools::gputools::StreamExecutor* stream_exec,
+    DeviceMemoryAllocator* /*device_allocator*/) {
   const string timer_message =
       "Compiling [" + module->name() + "] for CPU using JIT";
   XLA_SCOPED_LOGGING_TIMER(timer_message);
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index ebed705..3498139 100644
@@ -118,11 +118,13 @@ class CpuCompiler : public LLVMCompiler {
 
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* stream_exec) override;
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* stream_exec) override;
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 0cca3ca..495ae17 100644
@@ -212,7 +212,9 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) {
 
 // Modifies the given HLO module so that it will be accepted by IrEmitter.
 // Unlike optimization passes, the passes are necessary for correctness.
-tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
+tensorflow::Status PrepareHloModuleForIrEmitting(
+    HloModule* hlo_module, se::StreamExecutor* stream_exec,
+    DeviceMemoryAllocator* /*device_allocator*/) {
   // In some cases, we have to place the result of an instruction in a temporary
   // buffer. For instance, the buffer that holds an external parameter is
   // assumed immutable at this point, and should not be reused for output
@@ -410,7 +412,8 @@ GpuCompiler::GpuCompiler()
                         .getPointerSize(0 /* default address space */)) {}
 
 StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/) {
+    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
+    DeviceMemoryAllocator* device_allocator) {
   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
   Tracing::TraceMe annotation("HLO Transforms", module->name(),
                               /*is_expensive=*/true);
@@ -419,12 +422,14 @@ StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
 }
 
 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
-    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
+    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
+    DeviceMemoryAllocator* device_allocator) {
   XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
 
   TF_RET_CHECK(stream_exec != nullptr);
 
-  TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
+  TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get(), stream_exec,
+                                                   device_allocator));
 
   llvm::LLVMContext llvm_context;
   std::string buffer;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index 18e3434..c352d4d 100644
@@ -51,11 +51,13 @@ class GpuCompiler : public LLVMCompiler {
 
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* stream_exec) override;
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* stream_exec) override;
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> module,
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 204a8bf..e281538 100644
@@ -121,12 +121,14 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::ExecuteInternal(
   if (run_hlo_passes) {
     TF_ASSIGN_OR_RETURN(
         module, backend().compiler()->RunHloPasses(
-                    std::move(module), backend().default_stream_executor()));
+                    std::move(module), backend().default_stream_executor(),
+                    /*device_allocator=*/nullptr));
   }
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<Executable> executable,
       backend().compiler()->RunBackend(std::move(module),
-                                       backend().default_stream_executor()));
+                                       backend().default_stream_executor(),
+                                       /*device_allocator=*/nullptr));
 
   se::Stream stream(backend().default_stream_executor());
   stream.Init();
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index dc63a22..c83880e 100644
@@ -70,15 +70,16 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
 }
 
 StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
-    std::unique_ptr<HloModule> hlo_module,
-    se::StreamExecutor* /*stream_exec*/) {
+    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
+    DeviceMemoryAllocator* /*device_allocator*/) {
   VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
   TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
   return std::move(hlo_module);
 }
 
 StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
-    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec) {
+    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
+    DeviceMemoryAllocator* /*device_allocator*/) {
   TF_RET_CHECK(stream_exec != nullptr);
 
   VLOG(1) << "Run backend " << hlo_module->name();
@@ -96,7 +97,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
 
 StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
     std::vector<std::unique_ptr<HloModule>> /*hlo_modules*/,
-    std::vector<std::vector<se::StreamExecutor*>> /*stream_execs*/) {
+    std::vector<std::vector<se::StreamExecutor*>> /*stream_execs*/,
+    DeviceMemoryAllocator* /*device_allocator*/) {
   return tensorflow::errors::Unimplemented(
       "Compilation of multiple HLO modules is not supported on Interpreter.");
 }
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h
index 278cf51..c8660c0 100644
@@ -45,16 +45,19 @@ class InterpreterCompiler : public Compiler {
 
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> hlo_module,
-      perftools::gputools::StreamExecutor* stream_exec) override;
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> hlo_module,
-      perftools::gputools::StreamExecutor* stream_exec) override;
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::vector<std::unique_ptr<HloModule>> hlo_modules,
       std::vector<std::vector<perftools::gputools::StreamExecutor*>>
-          stream_exec) override;
+          stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> hlo_modules,
diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc
index 34f3419..f98fc04 100644
@@ -18,8 +18,8 @@ limitations under the License.
 namespace xla {
 StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
     std::vector<std::unique_ptr<HloModule>> modules,
-    std::vector<std::vector<perftools::gputools::StreamExecutor*>>
-        stream_execs) {
+    std::vector<std::vector<perftools::gputools::StreamExecutor*>> stream_execs,
+    DeviceMemoryAllocator* device_allocator) {
   std::vector<std::unique_ptr<Executable>> result;
   for (size_t i = 0; i < modules.size(); i++) {
     if (stream_execs[i].size() != 1) {
@@ -27,10 +27,12 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
           "Model partitioning not implemented for the CPU/GPU compilers!");
     }
 
-    TF_ASSIGN_OR_RETURN(
-        modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0]));
+    TF_ASSIGN_OR_RETURN(modules[i],
+                        RunHloPasses(std::move(modules[i]), stream_execs[i][0],
+                                     device_allocator));
     TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
-                        RunBackend(std::move(modules[i]), stream_execs[i][0]));
+                        RunBackend(std::move(modules[i]), stream_execs[i][0],
+                                   device_allocator));
     result.push_back(std::move(executable));
   }
 
diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h
index c5393ce..d74e81b 100644
@@ -60,17 +60,20 @@ class LLVMCompiler : public Compiler {
   // Bring in
   //   StatusOr<std::unique_ptr<Executable>> RunBackend(
   //       std::unique_ptr<HloModule> module,
-  //       perftools::gputools::StreamExecutor* stream_exec)
+  //       perftools::gputools::StreamExecutor* stream_exec,
+  //       DeviceMemoryAllocator* device_allocator)
   //   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
   //       std::unique_ptr<HloModule> module,
-  //       perftools::gputools::StreamExecutor* stream_exec)
+  //       perftools::gputools::StreamExecutor* stream_exec,
+  //       DeviceMemoryAllocator* device_allocator)
   using Compiler::RunBackend;
   using Compiler::RunHloPasses;
 
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::vector<std::unique_ptr<HloModule>> modules,
       std::vector<std::vector<perftools::gputools::StreamExecutor*>>
-          stream_execs) override;
+          stream_execs,
+      DeviceMemoryAllocator* device_allocator) override;
 
  protected:
   ModuleHook user_pre_optimization_hook_;
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index f30530d..bb9fd44 100644
@@ -71,7 +71,8 @@ LocalService::LocalService(const ServiceOptions& options,
 StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
     const ComputationHandle& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
-    const Shape* result_layout, int device_ordinal) {
+    const Shape* result_layout, int device_ordinal,
+    DeviceMemoryAllocator* device_allocator) {
   TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
                       computation_tracker_.Resolve(computation));
   VersionedComputationHandle versioned_handle =
@@ -135,7 +136,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
                       execute_backend_->stream_executor(device_ordinal));
 
   return BuildExecutable(versioned_handle, std::move(module_config),
-                         execute_backend_.get(), executor);
+                         execute_backend_.get(), executor, device_allocator);
 }
 
 StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index acbc726..16c71b2 100644
@@ -41,11 +41,14 @@ class LocalService : public Service {
 
   // Builds an Executable with the given argument layouts and options. If
   // result_layout is non-null, then the executable is compiled to produce a
-  // result of the given layout.
+  // result of the given layout.  If device_allocator is non-null, then the
+  // compiler may use it to allocate temp space on the device.  The compiler is
+  // responsible for freeing any memory it allocates this way.
   StatusOr<std::unique_ptr<Executable>> CompileExecutable(
       const ComputationHandle& computation,
       const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
-      const Shape* result_layout, int device_ordinal);
+      const Shape* result_layout, int device_ordinal,
+      DeviceMemoryAllocator* device_allocator);
 
   // Returns the device ordinal that corresponds to the given replica number.
   //
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 849df1d..fea6956 100644
@@ -337,7 +337,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
     std::vector<VersionedComputationHandle> versioned_handles,
     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
     Backend* backend,
-    std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors) {
+    std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
+    DeviceMemoryAllocator* device_allocator) {
   VLOG(1) << Printf("BuildExecutable on service %p", this);
 
   // Dump computation proto state if flag is set.
@@ -383,7 +384,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
 
   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<Executable>> executables,
-      backend->compiler()->Compile(std::move(modules), std::move(executors)));
+      backend->compiler()->Compile(std::move(modules), std::move(executors),
+                                   device_allocator));
 
   for (size_t i = 0; i < versioned_handles.size(); ++i) {
     if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) {
@@ -396,8 +398,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
 
 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
     const VersionedComputationHandle& versioned_handle,
-    std::unique_ptr<HloModuleConfig> module_config,
-    Backend* backend, se::StreamExecutor* executor) {
+    std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
+    se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
   VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this,
                     versioned_handle.ToString().c_str());
 
@@ -430,11 +432,12 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
   TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
 
   TF_ASSIGN_OR_RETURN(
-      module, backend->compiler()->RunHloPasses(std::move(module), executor));
+      module, backend->compiler()->RunHloPasses(std::move(module), executor,
+                                                device_allocator));
 
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<Executable> executable,
-      backend->compiler()->RunBackend(std::move(module), executor));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                      backend->compiler()->RunBackend(
+                          std::move(module), executor, device_allocator));
 
   if (!other_directory_path.empty()) {
     executable->set_session_module(std::move(session_module));
@@ -445,9 +448,9 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
 
 StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
     const VersionedComputationHandle& versioned_handle,
-    std::unique_ptr<HloModuleConfig> module_config,
-    Backend* backend, perftools::gputools::StreamExecutor* executor,
-    ExecutionProfile* profile) {
+    std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
+    perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile,
+    DeviceMemoryAllocator* device_allocator) {
   std::shared_ptr<Executable> executable =
       compilation_cache_.LookUp(versioned_handle, *module_config);
 
@@ -469,7 +472,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<Executable> executable_unique_ptr,
       BuildExecutable(versioned_handle, std::move(module_config), backend,
-                      executor));
+                      executor, device_allocator));
 
   if (profile != nullptr) {
     uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -771,10 +774,14 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
 
   // Build the user computations into HloModules and compile to generate the
   // executables.
+  //
+  // TODO(jlebar): There's currently no way to pass a device allocator to
+  // ExecuteParallel, so we have to pass a null device_allocator below.
   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<Executable>> executables,
       BuildExecutables(versioned_handles, std::move(module_configs),
-                       execute_backend_.get(), all_executors));
+                       execute_backend_.get(), all_executors,
+                       /*device_allocator=*/nullptr));
   std::vector<Executable*> executable_ptrs;
   executable_ptrs.reserve(executables.size());
   for (const auto& executable : executables) {
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index ca77e8f..6ce2419 100644
@@ -280,10 +280,15 @@ class Service : public ServiceInterface {
       const UserComputation& user_computation);
 
   // Builds an Executable for the given parameters.
+  //
+  // If device_allocator is not null, the compiler may use it to allocate temp
+  // buffers, which the compiler is responsible for freeing.  The allocator
+  // given here need not match the allocator used when running the executable.
   StatusOr<std::unique_ptr<Executable>> BuildExecutable(
       const VersionedComputationHandle& versioned_handle,
-      std::unique_ptr<HloModuleConfig> module_config,
-      Backend* backend, perftools::gputools::StreamExecutor* executor);
+      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
+      perftools::gputools::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator = nullptr);
 
   // Same as BuildExecutable() above, but builds a list of Executables for the
   // given computations that may interact with each other.
@@ -291,16 +296,17 @@ class Service : public ServiceInterface {
       std::vector<VersionedComputationHandle> versioned_handles,
       std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
       Backend* backend,
-      std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors);
+      std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
+      DeviceMemoryAllocator* device_allocator);
 
   // Similar to BuildExecutable, but look in the compilation cache for the
   // executable first. If the executable is not in the cache, it is built and
   // inserted into the cache.
   StatusOr<std::shared_ptr<Executable>> BuildAndCacheExecutable(
       const VersionedComputationHandle& versioned_handle,
-      std::unique_ptr<HloModuleConfig> module_config,
-      Backend* backend, perftools::gputools::StreamExecutor* executor,
-      ExecutionProfile* profile);
+      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
+      perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile,
+      DeviceMemoryAllocator* device_allocator = nullptr);
 
   // Runs the given executable with the given arguments and register the result
   // in the allocation tracker. The handle of the result from the tracker is
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc
index e472408..0226413 100644
@@ -21,9 +21,11 @@ StatusOr<std::unique_ptr<Executable>> CodegenTestBase::CompileToExecutable(
     std::unique_ptr<HloModule> hlo_module) {
   TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses(
                                       std::move(hlo_module),
-                                      backend().default_stream_executor()));
+                                      backend().default_stream_executor(),
+                                      /*device_allocator=*/nullptr));
   return backend().compiler()->RunBackend(std::move(hlo_module),
-                                          backend().default_stream_executor());
+                                          backend().default_stream_executor(),
+                                          /*device_allocator=*/nullptr);
 }
 
 StatusOr<std::unique_ptr<AotCompilationResult>>
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index b5b9596..7e92439 100644
@@ -74,7 +74,8 @@ class LLVMCompilerTest : public ::testing::Test {
 
     ASSERT_TRUE(compiler
                     ->RunBackend(std::move(hlo_module),
-                                 backend_->default_stream_executor())
+                                 backend_->default_stream_executor(),
+                                 /*device_allocator=*/nullptr)
                     .ok());
 
     // Test that hooks were called.
@@ -98,7 +99,8 @@ class LLVMCompilerTest : public ::testing::Test {
     executors.push_back({backend_->default_stream_executor()});
     executors.push_back({backend_->default_stream_executor()});
 
-    EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors)));
+    EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors),
+                                   /*device_allocator=*/nullptr));
   }
 
  private:
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
index 5ede37b..4ad356d 100644
@@ -86,9 +86,9 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
       layouts.push_back(&program_shape->parameters(i));
     }
     StatusOr<std::unique_ptr<Executable>> executable =
-        local_service->CompileExecutable(computation.handle(), layouts,
-                                         &program_shape->result(),
-                                         /*device_ordinal=*/0);
+        local_service->CompileExecutable(
+            computation.handle(), layouts, &program_shape->result(),
+            /*device_ordinal=*/0, /*device_allocator=*/nullptr);
 
     const HloModule& module = executable.ValueOrDie()->module();
 
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
index 24417a0..5ebb75a 100644
@@ -61,9 +61,9 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
         layouts.push_back(&program_shape->parameters(i));
       }
       StatusOr<std::unique_ptr<Executable>> executable =
-          local_service->CompileExecutable(computation.handle(), layouts,
-                                           &program_shape->result(),
-                                           /*device_ordinal=*/0);
+          local_service->CompileExecutable(
+              computation.handle(), layouts, &program_shape->result(),
+              /*device_ordinal=*/0, /*device_allocator=*/nullptr);
 
       const HloModule& module = executable.ValueOrDie()->module();