[XLA] Make XLA's memory allocator return an owning smart pointer.
author     Justin Lebar <jlebar@google.com>
           Wed, 9 May 2018 18:22:31 +0000 (11:22 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Wed, 9 May 2018 18:34:00 +0000 (11:34 -0700)
Previously, xla::DeviceMemoryAllocator::Allocate returned a
stream_executor::DeviceMemoryBase.  This is morally equivalent to a raw
pointer: it is up to you, the user, to call Deallocate().

Unfortunately, we essentially never got this right.  Nearly all users
of Allocate() call it in a loop and use TF_RETURN_IF_ERROR inside that
loop.  If any of those allocations fails (most commonly due to OOM), we
leak everything we've allocated up to that point.
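
To make the failure mode concrete, a minimal sketch of a typical call
site (the loop, `sizes`, and `allocator` are hypothetical stand-ins for
the many real callers):

    std::vector<se::DeviceMemoryBase> buffers;
    for (uint64 size : sizes) {
      // Old API: Allocate() returns a non-owning DeviceMemoryBase.
      TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buf,
                          allocator->Allocate(device_ordinal, size));
      buffers.push_back(buf);
      // If a later iteration's Allocate() fails, TF_ASSIGN_OR_RETURN
      // returns early and nothing ever calls Deallocate() on the
      // buffers already accumulated; they leak.
    }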

This patch changes the API so that Allocate() returns an owning smart
pointer.  Now things mostly Just Work.
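
Under the new API the same sketch cleans up after itself
(OwningDeviceMemory, Forget(), and the Allocate() signature below are
from this patch; the surrounding loop is still hypothetical):

    std::vector<OwningDeviceMemory> buffers;
    for (uint64 size : sizes) {
      // New API: Allocate() returns an owning wrapper that frees itself.
      TF_ASSIGN_OR_RETURN(OwningDeviceMemory buf,
                          allocator->Allocate(device_ordinal, size));
      buffers.push_back(std::move(buf));
    }
    // On an early return, each OwningDeviceMemory destructor calls back
    // into the allocator's Deallocate().  A buffer that must outlive
    // this scope is handed off explicitly: buffers[i].Forget() returns
    // the underlying se::DeviceMemoryBase and deactivates the wrapper.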

Also worth calling out: the lambda in CpuExecutable::ExecuteAsyncOnStream
that invokes ExecuteComputeFunction almost certainly had multithreaded
use-after-free bugs.  This patch fixes them.
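
The old lambda captured a pointer to run_options and an ArraySlice over
the arguments, both owned by the caller and possibly destroyed before
the task ran on the host stream's thread, and it also had to deallocate
the temp buffers itself.  A rough sketch of the new approach (the patch
uses an explicit AsyncRunTask struct rather than a lambda):

    // std::function requires a copyable functor, so the owning buffers
    // are parked behind a shared_ptr held by the task.  They stay alive
    // until the task is destroyed; the compute function only sees
    // unowning DeviceMemoryBase values.
    auto owned_buffers =
        std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers));
    host_stream->EnqueueTask(
        [owned_buffers, unowning_buffers]() { /* run compute function */ });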

PiperOrigin-RevId: 196000535

27 files changed:
tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/xla_launch_util.cc
tensorflow/compiler/jit/xla_launch_util.h
tensorflow/compiler/jit/xla_launch_util_test.cc
tensorflow/compiler/jit/xla_tensor.cc
tensorflow/compiler/xla/map_util.h
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/allocation_tracker.cc
tensorflow/compiler/xla/service/allocation_tracker.h
tensorflow/compiler/xla/service/cpu/cpu_executable.cc
tensorflow/compiler/xla/service/cpu/cpu_executable.h
tensorflow/compiler/xla/service/device_memory_allocator.cc
tensorflow/compiler/xla/service/device_memory_allocator.h
tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
tensorflow/compiler/xla/service/gpu/buffer_allocations.h
tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
tensorflow/compiler/xla/service/gpu/fft_thunk.cc
tensorflow/compiler/xla/service/gpu/fft_thunk.h
tensorflow/compiler/xla/service/gpu/gpu_executable.cc
tensorflow/compiler/xla/service/owning_device_memory.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/owning_device_memory.h [new file with mode: 0644]
tensorflow/compiler/xla/service/shaped_buffer.cc
tensorflow/compiler/xla/service/shaped_buffer.h
tensorflow/compiler/xla/service/transfer_manager.cc
tensorflow/compiler/xla/tests/local_client_test_base.cc
tensorflow/compiler/xla/tests/local_client_test_base.h
tensorflow/stream_executor/stream_executor_pimpl.h

index a6b3ce3..a6d0408 100644 (file)
@@ -217,6 +217,7 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/service:device_memory_allocator",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:gpu_runtime",
index e12e88f..6a0f557 100644 (file)
@@ -60,7 +60,7 @@ XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
 
 XlaAllocator::~XlaAllocator() {}
 
-xla::StatusOr<se::DeviceMemoryBase> XlaAllocator::Allocate(
+xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
     int device_ordinal, uint64 size, bool retry_on_failure) {
   AllocationAttributes attrs;
   attrs.no_retry_on_failure = !retry_on_failure;
@@ -69,13 +69,13 @@ xla::StatusOr<se::DeviceMemoryBase> XlaAllocator::Allocate(
   if (data == nullptr) {
     return errors::ResourceExhausted("Out of memory while trying to allocate ",
                                      size, " bytes.");
-  } else {
-    return se::DeviceMemoryBase(data, size);
   }
+  return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
+                                 device_ordinal, this);
 }
 
-Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) {
-  wrapped_->DeallocateRaw(mem->opaque());
+Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
+  wrapped_->DeallocateRaw(mem.opaque());
   return Status::OK();
 }
 
@@ -241,7 +241,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
       } else {
         Tensor output_tensor = XlaTensorBuffer::MakeTensor(
             ctx->expected_output_dtype(i), shape, buffer, allocator);
-        output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num});
+        output.set_buffer(xla::OwningDeviceMemory(), {output_num});
         ctx->set_output(i, output_tensor);
       }
       ++output_num;
@@ -291,7 +291,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
     } else {
       Tensor output_tensor = XlaTensorBuffer::MakeTensor(
           write.type, write.shape, buffer, allocator);
-      output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num});
+      output.set_buffer(xla::OwningDeviceMemory(), {output_num});
       *variable->tensor() = output_tensor;
     }
     ++output_num;
index a243125..4390701 100644 (file)
@@ -22,6 +22,8 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_tensor.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
@@ -50,9 +52,9 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
  public:
   XlaAllocator(const se::Platform* platform, Allocator* wrapped);
   ~XlaAllocator() override;
-  xla::StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
-                                               bool retry_on_failure) override;
-  Status Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) override;
+  xla::StatusOr<xla::OwningDeviceMemory> Allocate(
+      int device_ordinal, uint64 size, bool retry_on_failure) override;
+  Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
 
   // The Tensorflow BFC allocator used on GPU allows host-side deallocation
   // before GPU execution takes place. Tensorflow uses the ordering of the main
index 27813ef..a459324 100644 (file)
@@ -36,9 +36,9 @@ void BM_ExtractSubBuffer(int iters, int depth, int fan_out) {
   for (int i = 0; i < iters; ++i) {
     // Extract a buffer from approximately the middle of the first level of the
     // tree.
-    tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
-                                                 /*index=*/fan_out / 2,
-                                                 /*allocator=*/nullptr)
+    (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
+                                                       /*index=*/fan_out / 2,
+                                                       /*allocator=*/nullptr)
         .release();
   }
 }
index ce64568..a7211c9 100644 (file)
@@ -52,20 +52,22 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(
           on_host_shape);
 
-  xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(),
-                           device_ordinal);
-  for (auto& index_to_buffer : buffer.buffers()) {
+  xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape,
+                                        client->backend().memory_allocator(),
+                                        device_ordinal);
+  for (auto& index_to_buffer : shaped_buffer.buffers()) {
     xla::Shape subshape =
         xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
     uint64 size =
         client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
-    TF_ASSIGN_OR_RETURN(index_to_buffer.second,
+    TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
                         client->backend().memory_allocator()->Allocate(
                             device_ordinal, size, /*retry_on_failure=*/false));
+    // Move our buffer into shaped_buffer, which takes ownership of it.
+    index_to_buffer.second = buffer.Forget();
   }
 
-  set_shaped_buffer(xla::ScopedShapedBuffer(
-      std::move(buffer), client->backend().memory_allocator()));
+  set_shaped_buffer(std::move(shaped_buffer));
   return Status::OK();
 }
 
index 8db8c6f..3c74e07 100644 (file)
@@ -86,11 +86,10 @@ const typename Collection::value_type::second_type& FindOrDefault(
 
 // Inserts the key-value pair into the collection. Dies if key was already
 // present.
-template <class Collection>
-void InsertOrDie(Collection* const collection,
-                 const typename Collection::value_type::first_type& key,
-                 const typename Collection::value_type::second_type& data) {
-  auto p = collection->insert(std::make_pair(key, data));
+template <class Collection, class Key, class Value>
+void InsertOrDie(Collection* const collection, Key&& key, Value&& value) {
+  auto p = collection->insert(
+      std::make_pair(std::forward<Key>(key), std::forward<Value>(value)));
   CHECK(p.second) << "duplicate key: " << key;
 }
 
@@ -101,9 +100,10 @@ bool ContainsKey(const Collection& collection, const Key& key) {
 }
 
 // Inserts `value` into `set`. Dies if it was already present.
-template <class Set>
-void InsertOrDie(Set* const set, const typename Set::value_type& value) {
-  CHECK(set->insert(value).second) << "duplicate value: " << value;
+template <class Set, class Value>
+void InsertOrDie(Set* const set, Value&& value) {
+  CHECK(set->insert(std::forward<Value>(value)).second)
+      << "duplicate value: " << value;
 }
 
 }  // namespace xla
index aa3a626..fecc257 100644 (file)
@@ -2316,8 +2316,14 @@ tf_cc_test(
 
 cc_library(
     name = "device_memory_allocator",
-    srcs = ["device_memory_allocator.cc"],
-    hdrs = ["device_memory_allocator.h"],
+    srcs = [
+        "device_memory_allocator.cc",
+        "owning_device_memory.cc",
+    ],
+    hdrs = [
+        "device_memory_allocator.h",
+        "owning_device_memory.h",
+    ],
     deps = [
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
index cf1231b..eb52803 100644 (file)
@@ -220,8 +220,10 @@ void AllocationTracker::AddAllocationOrIncrementRefCount(
   AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
   auto it = allocation_map.find(device_memory.opaque());
   if (it == allocation_map.end()) {
-    allocation_map[device_memory.opaque()] = {device_memory, device_ordinal,
-                                              /*ref_count=*/1};
+    allocation_map[device_memory.opaque()] = {
+        OwningDeviceMemory(device_memory, device_ordinal,
+                           backend_->memory_allocator()),
+        /*ref_count=*/1};
   } else {
     it->second.ref_count++;
   }
@@ -235,8 +237,7 @@ Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
   Allocation& allocation = it->second;
   TF_RET_CHECK(allocation.ref_count >= 1);
   if (allocation.ref_count == 1) {
-    TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate(
-        device_ordinal, &device_memory));
+    allocation.device_memory.Free();
     allocation_map.erase(it);
   } else {
     allocation.ref_count--;
index 1174fa6..a7d8927 100644 (file)
@@ -76,10 +76,7 @@ class AllocationTracker {
   // Data structure encapsulating single memory allocation on the device.
   struct Allocation {
     // The pointer to this allocation.
-    se::DeviceMemoryBase device_memory;
-
-    // The device that the memory is allocated on.
-    int device_ordinal;
+    OwningDeviceMemory device_memory;
 
     // This is the number of times this memory allocation is referred to by
     // registered data handles.
@@ -126,7 +123,10 @@ class AllocationTracker {
   int64 next_handle_ GUARDED_BY(mutex_);
 
   // A map from device ordinal to AllocationMap.
-  tensorflow::gtl::FlatMap<int, AllocationMap> opaque_to_allocation_map_
+  //
+  // This is not a TF FlatMap because (currently) FlatMap (and therefore
+  // AllocationMap) is not movable.
+  std::unordered_map<int, AllocationMap> opaque_to_allocation_map_
       GUARDED_BY(mutex_);
 
   // A map from data handle to a vector of shaped buffers that represent the
index 32613b8..cf43b74 100644 (file)
@@ -73,7 +73,7 @@ CpuExecutable::CpuExecutable(
 
 Status CpuExecutable::AllocateBuffers(
     DeviceMemoryAllocator* memory_allocator, int device_ordinal,
-    std::vector<se::DeviceMemoryBase>* buffers) {
+    std::vector<OwningDeviceMemory>* buffers) {
   CHECK_EQ(buffers->size(), assignment_->Allocations().size());
   VLOG(3) << "Allocating " << assignment_->Allocations().size()
           << " allocations for module " << module().name();
@@ -201,60 +201,18 @@ Status CpuExecutable::ExecuteComputeFunction(
   return Status::OK();
 }
 
-static void LogLiveAddresses(
-    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
-    const std::vector<bool>& buffers_in_result) {
-  if (!VLOG_IS_ON(3)) {
-    return;
-  }
-
-  CHECK_EQ(buffers.size(), buffers_in_result.size());
-  std::vector<const void*> live_out_buffers;
-  for (int i = 0; i < buffers.size(); ++i) {
-    if (buffers_in_result[i]) {
-      live_out_buffers.push_back(buffers[i].opaque());
-    }
-  }
-  VLOG(3) << "Live addresses in output marking found "
-          << live_out_buffers.size() << " addresses:\n"
-          << tensorflow::str_util::Join(
-                 live_out_buffers, ", ", [](string* out, const void* address) {
-                   tensorflow::strings::StrAppend(
-                       out, tensorflow::strings::Printf("%p", address));
-                 });
-}
-
-static Status DeallocateTempBuffers(
-    DeviceMemoryAllocator* allocator, se::Stream* stream,
-    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
-    const std::vector<bool>& buffers_in_result) {
-  // Keep those buffers in the output of the marked live because they are needed
-  // by the service. They will be deallocated by the service.
-  for (size_t i = 0; i < buffers.size(); ++i) {
-    se::DeviceMemoryBase alloc = buffers[i];
-    if (!buffers_in_result[i] && !alloc.is_null()) {
-      VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
-              << alloc.opaque() << "]";
-      TF_RETURN_IF_ERROR(
-          allocator->Deallocate(stream->parent()->device_ordinal(), &alloc));
-    }
-  }
-
-  return Status::OK();
-}
-
 StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
     const ServiceExecutableRunOptions* run_options,
-    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> allocated_buffers,
-    std::vector<bool>* buffers_in_result) {
+    tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
   se::Stream* stream = run_options->stream();
   ScopedShapedBuffer result_buffer(
       /*on_host_shape=*/host_result_shape(),
       /*on_device_shape=*/host_result_shape(), run_options->allocator(),
       stream->parent()->device_ordinal());
 
-  // Copy DeviceMemoryBase values which contain the array(s) of the result into
-  // the respective location in ShapedBuffer which is returned to the caller.
+  // Move OwningDeviceMemory values which contain the array(s) of the result
+  // into the respective location in ScopedShapedBuffer which is returned to the
+  // caller.
   TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus(
       [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
         const auto& sources = this->GetRootPointsToSet().element(index);
@@ -273,10 +231,9 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
         CHECK(!slice.allocation()->is_entry_computation_parameter());
 
         const BufferAllocation::Index buffer_index = slice.index();
-        const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index];
+        OwningDeviceMemory& buffer = buffers[buffer_index];
         CHECK(!buffer.is_null() || buffer.size() == 0);
-        *device_memory = buffer;
-        (*buffers_in_result)[buffer_index] = true;
+        *device_memory = buffer.Forget();
         return Status::OK();
       }));
   return std::move(result_buffer);
@@ -292,23 +249,21 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
 
   se::Stream* stream = run_options->stream();
   DeviceMemoryAllocator* memory_allocator = run_options->allocator();
-  std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
+  std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
 
   TF_RETURN_IF_ERROR(AllocateBuffers(
       memory_allocator, stream->parent()->device_ordinal(), &buffers));
-  TF_RETURN_IF_ERROR(ExecuteComputeFunction(
-      &run_options->run_options(), arguments, buffers, hlo_execution_profile));
 
-  std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer result_buffer,
-      CreateResultShapedBuffer(run_options, buffers, &buffers_in_result));
-
-  // Free all buffers not in the result.
-  TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers,
-                                           buffers_in_result));
+  std::vector<se::DeviceMemoryBase> unowning_buffers;
+  unowning_buffers.reserve(buffers.size());
+  for (auto& buffer : buffers) {
+    unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
+  }
+  TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(),
+                                            arguments, unowning_buffers,
+                                            hlo_execution_profile));
 
-  return std::move(result_buffer);
+  return CreateResultShapedBuffer(run_options, &buffers);
 }
 
 StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@@ -324,30 +279,53 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
       run_options->stream()->implementation());
   se::Stream* stream = run_options->stream();
   DeviceMemoryAllocator* memory_allocator = run_options->allocator();
-  std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
-
+  std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
   TF_RETURN_IF_ERROR(AllocateBuffers(
       memory_allocator, stream->parent()->device_ordinal(), &buffers));
 
-  std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer result_buffer,
-      CreateResultShapedBuffer(run_options, buffers, &buffers_in_result));
-
-  LogLiveAddresses(buffers, buffers_in_result);
-
-  host_stream->EnqueueTask([this, run_options, arguments, buffers,
-                            buffers_in_result, memory_allocator, stream]() {
-    // Failing a CHECK here is not great, but I don't see an obvious way to
-    // return a failed Status asynchronously.
-    TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments,
-                                       buffers,
-                                       /*hlo_execution_profile=*/nullptr));
-    TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers,
-                                      buffers_in_result));
-  });
+  std::vector<se::DeviceMemoryBase> unowning_buffers;
+  unowning_buffers.reserve(buffers.size());
+  for (auto& buffer : buffers) {
+    unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
+  }
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+                      CreateResultShapedBuffer(run_options, &buffers));
 
-  return std::move(result_buffer);
+  // At this point, `unowning_buffers` contains unowning pointers to all of our
+  // buffers, and `buffers` contains owning pointers to the non-live-out
+  // buffers.  Enqueue a task which keeps alive the non-live-out buffers.
+  //
+  // Logically we want this lambda to capture `buffers` by move, but
+  // ultimately our functor needs to be wrapped in an std::function, and that
+  // requires its functor to be copyable.  Thus we perpetrate the hack of
+  // capturing buffers "by shared pointer".
+  //
+  // We also need to change the types of some of the variables we capture:
+  // run_options needs to change from a pointer to a value type, and arguments
+  // needs to change from an ArraySlice into a vector.  We use a struct instead
+  // of a lambda to make this explicit.
+  struct AsyncRunTask {
+    CpuExecutable* executable;
+    ServiceExecutableRunOptions run_options;
+    std::vector<const ShapedBuffer*> arguments;
+    std::vector<se::DeviceMemoryBase> unowning_buffers;
+    std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
+
+    void operator()() {
+      // Failing a CHECK here is not great, but I don't see an obvious way to
+      // return a failed Status asynchronously.
+      TF_CHECK_OK(executable->ExecuteComputeFunction(
+          &run_options.run_options(), arguments, unowning_buffers,
+          /*hlo_execution_profile=*/nullptr));
+    }
+  };
+  host_stream->EnqueueTask(AsyncRunTask{
+      this, *run_options,
+      std::vector<const ShapedBuffer*>(arguments.begin(), arguments.end()),
+      unowning_buffers,
+      std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers))});
+
+  return std::move(result);
 }
 
 /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
index 68ad38c..8dd47bf 100644 (file)
@@ -92,7 +92,7 @@ class CpuExecutable : public Executable {
   // buffer is assigned for this element.
   Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator,
                          int device_ordinal,
-                         std::vector<se::DeviceMemoryBase>* buffers);
+                         std::vector<OwningDeviceMemory>* buffers);
 
   // Calls the generated function performing the computation with the given
   // arguments using the supplied buffers.
@@ -102,16 +102,12 @@ class CpuExecutable : public Executable {
       tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
       HloExecutionProfile* hlo_execution_profile);
 
-  // Creates a ScopedShapedBuffer for holding the result of the computation. The
-  // addresses (DeviceMemoryBases) are set according to buffer assignment.
-  // 'buffers_in_result' should point to a vector of the same size as
-  // 'allocated_buffers'. An element in buffers_in_result is set to true if the
-  // corresponding buffer is live out of the computation (and thus contained in
-  // the returned ShapedBuffer).
+  // Creates a ScopedShapedBuffer for holding the result of the computation,
+  // moving buffers out of allocated_buffers and into the result as appropriate.
+  // The addresses are set according to buffer assignment.
   StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
       const ServiceExecutableRunOptions* run_options,
-      tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> allocated_buffers,
-      std::vector<bool>* buffers_in_result);
+      tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers);
 
   // Returns the points-to set of the root instruction of the entry
   // computation. Uses points-to analysis from buffer assignment.
index 35db4fd..e228bb5 100644 (file)
@@ -29,7 +29,7 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
     : DeviceMemoryAllocator(platform),
       stream_executors_(stream_executors.begin(), stream_executors.end()) {}
 
-StatusOr<se::DeviceMemoryBase> StreamExecutorMemoryAllocator::Allocate(
+StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
     int device_ordinal, uint64 size, bool retry_on_failure) {
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor,
                       GetStreamExecutor(device_ordinal));
@@ -40,22 +40,17 @@ StatusOr<se::DeviceMemoryBase> StreamExecutorMemoryAllocator::Allocate(
         tensorflow::strings::HumanReadableNumBytes(size).c_str(), size,
         device_ordinal);
   }
-  return result;
+  return OwningDeviceMemory(result, device_ordinal, this);
 }
 
-tensorflow::Status StreamExecutorMemoryAllocator::Deallocate(
-    int device_ordinal, se::DeviceMemoryBase* mem) {
-  if (!mem->is_null()) {
+Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
+                                                 se::DeviceMemoryBase mem) {
+  if (!mem.is_null()) {
     TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor,
                         GetStreamExecutor(device_ordinal));
-    // We make a local copy of 'mem' so the original is not zeroed out by the
-    // Deallocate() call below. This gives us a better chance of
-    // catching double-free bugs, since Deallocate silently succeeds for null
-    // values.
-    se::DeviceMemoryBase mem_copy(*mem);
-    stream_executor->Deallocate(&mem_copy);
+    stream_executor->Deallocate(&mem);
   }
-  return tensorflow::Status::OK();
+  return Status::OK();
 }
 
 StatusOr<se::StreamExecutor*> StreamExecutorMemoryAllocator::GetStreamExecutor(
index da45c4d..5feb650 100644 (file)
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <vector>
 
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -37,28 +38,30 @@ class DeviceMemoryAllocator {
       : platform_(platform) {}
   virtual ~DeviceMemoryAllocator() {}
 
+  // Allocates memory on the device.
+  //
+  // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory
+  // must not be null.  If size == 0, must return a null OwningDeviceMemory.
+  //
   // 'retry_on_failure': If false, and the first attempt to allocate the memory
   // fails, the allocation should return immediately without retrying.  An
   // example use case is optional scratch spaces where a failure has only
   // performance impact.
-  //
-  // Allocate() should return a null pointer for a size-0 allocation.
-  // Deallocate() must be a no-op for null pointers.
-  virtual StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal,
-                                                  uint64 size,
-                                                  bool retry_on_failure) = 0;
+  virtual StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
+                                                bool retry_on_failure) = 0;
 
   // Two-arg version of Allocate(), which sets retry-on-failure to true.
   //
   // (We don't simply use a default argument on the virtual Allocate function
   // because default args on virtual functions are disallowed by the Google
   // style guide.)
-  StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size) {
+  StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size) {
     return Allocate(device_ordinal, size, /*retry_on_failure=*/true);
   }
 
+  // Must be a nop for null pointers.
   virtual tensorflow::Status Deallocate(int device_ordinal,
-                                        se::DeviceMemoryBase* mem) = 0;
+                                        se::DeviceMemoryBase mem) = 0;
 
   // Return the platform that the allocator allocates memory on.
   const se::Platform* platform() const { return platform_; }
@@ -68,6 +71,7 @@ class DeviceMemoryAllocator {
   virtual bool AllowsAsynchronousDeallocation() const = 0;
 
  protected:
+  friend class OwningDeviceMemory;
   const se::Platform* platform_;
 };
 
@@ -79,14 +83,14 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
       const se::Platform* platform,
       tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
 
-  StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
-                                          bool retry_on_failure) override;
+  StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
+                                        bool retry_on_failure) override;
 
   // Pull in two-arg overload that sets retry_on_failure to true.
   using DeviceMemoryAllocator::Allocate;
 
   tensorflow::Status Deallocate(int device_ordinal,
-                                se::DeviceMemoryBase* mem) override;
+                                se::DeviceMemoryBase mem) override;
 
   bool AllowsAsynchronousDeallocation() const override;
 
index 837f052..cb66d37 100644 (file)
@@ -37,11 +37,11 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index,
 }
 
 StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
-    const BufferAssignment& buffer_assignment, int device_ordinal,
+    const BufferAssignment* buffer_assignment, int device_ordinal,
     DeviceMemoryAllocator* memory_allocator) {
-  const int64 num_buffers = buffer_assignment.Allocations().size();
-  auto buffer_allocations = WrapUnique(
-      new BufferAllocations(num_buffers, device_ordinal, memory_allocator));
+  const int64 num_buffers = buffer_assignment->Allocations().size();
+  auto buffer_allocations = WrapUnique(new BufferAllocations(
+      num_buffers, device_ordinal, memory_allocator, buffer_assignment));
 
   for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
     // If buffer #i's address is already registered (e.g. external arguments or
@@ -62,28 +62,28 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
 
     // Allocate each allocation that might escape, or is the temp buffer.
     bool seen_temp_buffer = false;
-    const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
+    const BufferAllocation& allocation = buffer_assignment->GetAllocation(i);
     if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) {
       const int64 buffer_size = allocation.size();
       se::DeviceMemoryBase buffer_address;
       if (buffer_size > 0) {
-        TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate(
-                                                device_ordinal, buffer_size));
-        if (buffer_address == nullptr) {
-          return ResourceExhausted(
-              "Out of memory when allocating %s for buffer %lld.",
-              tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(),
-              i);
-        }
-        if (reinterpret_cast<uintptr_t>(buffer_address.opaque()) %
+        OwningDeviceMemory buffer;
+        TF_ASSIGN_OR_RETURN(
+            buffer, memory_allocator->Allocate(device_ordinal, buffer_size));
+        if (reinterpret_cast<uintptr_t>(buffer.opaque()) %
                 kCudaMallocAlignBytes !=
             0) {
           return InternalError(
               "Address returned by memory_allocator->Allocate must be a "
               "multiple of %llx, but was %p",
-              kCudaMallocAlignBytes, buffer_address.opaque());
+              kCudaMallocAlignBytes, buffer.opaque());
         }
+        // We do manual memory management within BufferAllocations.  Be sure not
+        // to do a TF_RETURN_IF_ERROR between this line and the
+        // buffer_allocations->SetBuffer(buffer_address) call below!
+        buffer_address = buffer.Forget();
       }
+
       buffer_allocations->SetBuffer(i, buffer_address);
       if (allocation.IsPreallocatedTempBuffer()) {
         if (seen_temp_buffer) {
@@ -103,28 +103,42 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
               << "B)";
     }
   }
-
   return std::move(buffer_allocations);
 }
 
+BufferAllocations::~BufferAllocations() {
+  if (!torn_down_) {
+    // Presumably if we're executing this branch, the caller is in an error
+    // state, otherwise it would have explicitly called TearDown so it could
+    // save some set of live addresses.  So ignoring any errors in TearDown is
+    // sensible.
+    TearDown(/*live_addresses=*/{}).IgnoreError();
+  }
+}
+
 tensorflow::Status BufferAllocations::TearDown(
-    const std::set<se::DeviceMemoryBase>& live_addresses,
-    const BufferAssignment& buffer_assignment) {
-  // Deallocate temporary buffers.
-  const int64 num_buffers = buffer_assignment.Allocations().size();
+    const std::set<se::DeviceMemoryBase>& live_addresses) {
+  // Deallocate temporary buffers, taking care to try to deallocate all of them
+  // even if one of the deallocations fails.
+  Status status;
+  const int64 num_buffers = buffer_assignment_->Allocations().size();
   for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
-    const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
+    const BufferAllocation& allocation = buffer_assignment_->GetAllocation(i);
     se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index());
     // Deallocate buffers marked "maybe_live_out" but aren't actually live out,
     // and temp buffers.
     if ((allocation.maybe_live_out() &&
          !live_addresses.count(buffer_address)) ||
         allocation.IsPreallocatedTempBuffer()) {
-      TF_RETURN_IF_ERROR(
-          memory_allocator_->Deallocate(device_ordinal_, &buffer_address));
+      auto dealloc_result =
+          memory_allocator_->Deallocate(device_ordinal_, buffer_address);
+      if (!dealloc_result.ok() && status.ok()) {
+        status = dealloc_result;
+      }
     }
   }
-  return tensorflow::Status::OK();
+  torn_down_ = true;
+  return status;
 }
 
 se::DeviceMemoryBase BufferAllocations::GetDeviceAddress(
index c2fc35b..a36571d 100644 (file)
@@ -48,13 +48,15 @@ class BufferAllocations {
     // `device_ordinal` is the number of the device this function allocates
     // memory on.
     StatusOr<std::unique_ptr<BufferAllocations>> Build(
-        const BufferAssignment& buffer_assignment, int device_ordinal,
+        const BufferAssignment* buffer_assignment, int device_ordinal,
         DeviceMemoryAllocator* memory_allocator);
 
    private:
     std::map<BufferAllocation::Index, se::DeviceMemoryBase> registered_buffers_;
   };
 
+  ~BufferAllocations();
+
   BufferAllocations(const BufferAllocations&) = delete;
   BufferAllocations& operator=(const BufferAllocations&) = delete;
 
@@ -77,15 +79,16 @@ class BufferAllocations {
   // Tears down all buffers allocated by this object that are not in
   // `live_addresses`.
   tensorflow::Status TearDown(
-      const std::set<se::DeviceMemoryBase>& live_addresses,
-      const BufferAssignment& buffer_assignment);
+      const std::set<se::DeviceMemoryBase>& live_addresses);
 
  private:
   BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal,
-                    DeviceMemoryAllocator* memory_allocator)
+                    DeviceMemoryAllocator* memory_allocator,
+                    const BufferAssignment* buffer_assignment)
       : buffers_(buffer_count),
         device_ordinal_(device_ordinal),
-        memory_allocator_(memory_allocator) {}
+        memory_allocator_(memory_allocator),
+        buffer_assignment_(buffer_assignment) {}
 
   // Sets the device address of buffer `buffer_index`.
   void SetBuffer(BufferAllocation::Index buffer_index,
@@ -100,8 +103,9 @@ class BufferAllocations {
   se::DeviceMemoryBase temp_buffer_base_;
 
   int device_ordinal_;
-
   DeviceMemoryAllocator* memory_allocator_;
+  const BufferAssignment* buffer_assignment_;
+  bool torn_down_ = false;
 };
 
 }  // namespace gpu
index 41ee45f..6a46bdb 100644 (file)
@@ -35,35 +35,22 @@ class ScratchAllocator : public se::ScratchAllocator {
   ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator)
       : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
 
-  ~ScratchAllocator() override;
-
   int64 GetMemoryLimitInBytes(se::Stream* stream) override {
     return 1LL << 32;  // 4GB.  TODO(jlebar): Tune this?
   }
   int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
 
-  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override;
+  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
+                                                  int64 byte_size) override;
 
  private:
   const int device_ordinal_;
   DeviceMemoryAllocator* memory_allocator_;
-  std::vector<se::DeviceMemoryBase> allocated_buffers_;
+  std::vector<OwningDeviceMemory> allocated_buffers_;
   int64 total_allocated_bytes_ = 0;
 };
 
-ScratchAllocator::~ScratchAllocator() {
-  for (auto& allocated_buffer : allocated_buffers_) {
-    if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
-             .ok()) {
-      // The program can still continue with failed deallocation.
-      LOG(ERROR) << "Failed to deallocate the allocated buffer: "
-                 << allocated_buffer.opaque();
-    }
-  }
-}
-
-se::port::StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
+StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
     se::Stream* stream, int64 byte_size) {
   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
   if (byte_size > GetMemoryLimitInBytes(stream)) {
@@ -74,19 +61,14 @@ se::port::StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
             byte_size, GetMemoryLimitInBytes(stream)));
   }
 
-  auto status_or_memory =
-      memory_allocator_->Allocate(device_ordinal_, byte_size,
-                                  /*retry_on_failure=*/false);
-  if (!status_or_memory.ok()) {
-    return se::port::Status(se::port::error::RESOURCE_EXHAUSTED,
-                            tensorflow::strings::Printf(
-                                "Failed to allocate %lld bytes on device %d.",
-                                byte_size, device_ordinal_));
-  }
-  se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
-  allocated_buffers_.push_back(allocated_buffer);
+  TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer,
+                      memory_allocator_->Allocate(device_ordinal_, byte_size,
+                                                  /*retry_on_failure=*/false));
   total_allocated_bytes_ += byte_size;
-  return se::DeviceMemory<uint8>(allocated_buffer);
+
+  se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase();
+  allocated_buffers_.push_back(std::move(allocated_buffer));
+  return se::DeviceMemory<uint8>(buffer_addr);
 }
 
 // Determines whether we can safely perform a winograd non-fused convolution for
index cc747ad..1cea493 100644 (file)
@@ -31,23 +31,12 @@ FftScratchAllocator::FftScratchAllocator(
     int device_ordinal, DeviceMemoryAllocator* memory_allocator)
     : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
 
-FftScratchAllocator::~FftScratchAllocator() {
-  for (auto& allocated_buffer : allocated_buffers_) {
-    if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
-             .ok()) {
-      // The program can still continue with failed deallocation.
-      LOG(ERROR) << "Failed to deallocate the allocated buffer: "
-                 << allocated_buffer.opaque();
-    }
-  }
-}
-
 int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
   constexpr int64 kFftScratchSize = 1LL << 32;  // 4GB by default.
   return kFftScratchSize;
 }
 
-se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
+StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
     se::Stream* stream, int64 byte_size) {
   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
   if (byte_size > GetMemoryLimitInBytes(stream)) {
@@ -58,18 +47,14 @@ se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
             byte_size, GetMemoryLimitInBytes(stream)));
   }
 
-  auto status_or_memory =
-      memory_allocator_->Allocate(device_ordinal_, byte_size,
-                                  /*retry_on_failure=*/false);
-  if (!status_or_memory.ok()) {
-    return tensorflow::errors::ResourceExhausted(
-        "Failed to allocate %lld bytes on device %d.", byte_size,
-        device_ordinal_);
-  }
-  se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
-  allocated_buffers_.push_back(allocated_buffer);
+  TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer,
+                      memory_allocator_->Allocate(device_ordinal_, byte_size,
+                                                  /*retry_on_failure=*/false));
   total_allocated_bytes_ += byte_size;
-  return se::DeviceMemory<uint8>(allocated_buffer);
+
+  se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase();
+  allocated_buffers_.push_back(std::move(allocated_buffer));
+  return se::DeviceMemory<uint8>(buffer_addr);
 }
 
 namespace {
index 24b1dca..ea4270a 100644 (file)
@@ -39,8 +39,6 @@ class FftScratchAllocator : public se::ScratchAllocator {
   FftScratchAllocator(int device_ordinal,
                       DeviceMemoryAllocator* memory_allocator);
 
-  ~FftScratchAllocator() override;
-
   int64 GetMemoryLimitInBytes(se::Stream* stream) override;
 
   int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
@@ -51,7 +49,7 @@ class FftScratchAllocator : public se::ScratchAllocator {
  private:
   const int device_ordinal_;
   DeviceMemoryAllocator* memory_allocator_;
-  std::vector<se::DeviceMemoryBase> allocated_buffers_;
+  std::vector<OwningDeviceMemory> allocated_buffers_;
   int64 total_allocated_bytes_ = 0;
 };
 
index 980cc89..04b4f7a 100644 (file)
@@ -286,8 +286,8 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
   se::StreamExecutor* executor = run_options->stream()->parent();
   TF_ASSIGN_OR_RETURN(
       auto buffer_allocations,
-      buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(),
-                                       memory_allocator));
+      buffer_allocations_builder.Build(
+          assignment_.get(), executor->device_ordinal(), memory_allocator));
 
   bool block_host_until_done =
       !memory_allocator->AllowsAsynchronousDeallocation();
@@ -329,8 +329,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
         buffers_in_result.insert(src_base);
         return Status::OK();
       }));
-  TF_RETURN_IF_ERROR(
-      buffer_allocations->TearDown(buffers_in_result, *assignment_));
+  TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
 
   return std::move(shaped_buffer);
 }
diff --git a/tensorflow/compiler/xla/service/owning_device_memory.cc b/tensorflow/compiler/xla/service/owning_device_memory.cc
new file mode 100644 (file)
index 0000000..c115bc0
--- /dev/null
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+
+namespace xla {
+
+void OwningDeviceMemory::Free() {
+  CHECK(allocator_ != nullptr)
+      << "Can't call Free() on an inactive (i.e. moved from, Forget()'ten, "
+         "or Free()'ed) instance.";
+  auto status = allocator_->Deallocate(device_ordinal_, mem_);
+  if (!status.ok()) {
+    LOG(WARNING) << "Deallocating buffer " << mem_.opaque() << " failed.";
+  }
+
+  allocator_ = nullptr;
+  mem_ = se::DeviceMemoryBase();
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h
new file mode 100644 (file)
index 0000000..9cf071f
--- /dev/null
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// Break circular dependency between this file and device_memory_allocator.h.
+class DeviceMemoryAllocator;
+
+// Owning pointer for memory on a device.
+//
+// OwningDeviceMemory is an owning pointer like std::unique_ptr, but it can
+// point to memory that resides on a "device" (e.g. a GPU).  When an
+// OwningDeviceMemory goes out of scope, it frees the memory it owns.
+//
+// We say that an instance of OwningDeviceMemory is "active" if it currently
+// owns a (possibly empty) slice of memory on the device.  Moving, Forget()'ing,
+// Free()'ing, and other actions can deactivate an active object.
+//
+// Note that we can't simply use stream_executor::ScopedDeviceMemory instead of
+// OwningDeviceMemory, because ScopedDeviceMemory frees its pointer via a
+// StreamExecutor.  This class needs to free via a xla::DeviceMemoryAllocator.
+class OwningDeviceMemory {
+ public:
+  OwningDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {}
+
+  explicit OwningDeviceMemory(se::DeviceMemoryBase mem, int device_ordinal,
+                              DeviceMemoryAllocator* allocator)
+      : mem_(mem), device_ordinal_(device_ordinal), allocator_(allocator) {
+    CHECK(allocator != nullptr) << "allocator cannot be null.";
+  }
+
+  OwningDeviceMemory(OwningDeviceMemory&& other)
+      : mem_(other.mem_),
+        device_ordinal_(other.device_ordinal_),
+        allocator_(other.allocator_) {
+    other.mem_ = se::DeviceMemoryBase();
+    other.allocator_ = nullptr;
+  }
+
+  OwningDeviceMemory& operator=(OwningDeviceMemory&& other) {
+    if (allocator_ != nullptr) {
+      Free();
+    }
+    mem_ = other.mem_;
+    device_ordinal_ = other.device_ordinal_;
+    allocator_ = other.allocator_;
+
+    other.mem_ = se::DeviceMemoryBase();
+    other.allocator_ = nullptr;
+    return *this;
+  }
+
+  // Deactivates this instance if it's active.  Nop if it's not active.
+  OwningDeviceMemory& operator=(std::nullptr_t) {
+    if (allocator_ != nullptr) {
+      Free();
+    }
+    return *this;
+  }
+
+  ~OwningDeviceMemory() {
+    if (allocator_ != nullptr) {
+      Free();
+    }
+  }
+
+  // The returned allocator is nonnull iff this object is active.
+  DeviceMemoryAllocator* allocator() const { return allocator_; }
+
+  int device_ordinal() const { return device_ordinal_; }
+
+  // Gets the device memory pointer.
+  const void* opaque() const { return mem_.opaque(); }
+  void* opaque() { return mem_.opaque(); }
+
+  uint64 size() const { return mem_.size(); }
+
+  // Determines whether this wraps a null pointer.
+  //
+  // !is_null() is sufficient but not necessary to imply `this` is active.
+  bool is_null() const { return mem_.is_null(); }
+
+  se::DeviceMemoryBase AsDeviceMemoryBase() {
+    return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false);
+  }
+
+  // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates
+  // this object.  Precondition: `this` is active.
+  TF_MUST_USE_RESULT se::DeviceMemoryBase Forget() {
+    CHECK(allocator_ != nullptr)
+        << "Can't call Forget() on an inactive (i.e. moved from, Forget()'ten, "
+           "or Free()'ed) instance.";
+    allocator_ = nullptr;
+    se::DeviceMemoryBase mem(mem_);
+    mem_ = se::DeviceMemoryBase();
+    return mem;
+  }
+
+  // Frees the wrapped DeviceMemoryBase and deactivates this object.
+  // Precondition: `this` is active.
+  void Free();
+
+ private:
+  se::DeviceMemoryBase mem_;
+  int device_ordinal_;
+  DeviceMemoryAllocator* allocator_;  // Null if this object is inactive.
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_
index fb3b5f0..6bacb37 100644 (file)
@@ -15,7 +15,6 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 
-#include <set>
 #include <string>
 #include <utility>
 
@@ -25,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 
@@ -138,14 +138,12 @@ ScopedShapedBuffer::~ScopedShapedBuffer() {
   // Deallocate all non-null buffers. A buffer may appear in more than one spot
   // in the shape (eg, a tuple with a repeated element) so keep track of what
   // has been deallocated.
-  std::set<void*> deallocated_opaques;
+  tensorflow::gtl::FlatSet<void*> deallocated_ptrs;
   for (auto& pair : buffers_) {
     se::DeviceMemoryBase& memory_base = pair.second;
     if (!memory_base.is_null() &&
-        deallocated_opaques.count(memory_base.opaque()) == 0) {
-      deallocated_opaques.insert(memory_base.opaque());
-      TF_CHECK_OK(
-          this->allocator_->Deallocate(this->device_ordinal(), &memory_base));
+        deallocated_ptrs.insert(memory_base.opaque()).second) {
+      TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base));
     }
   }
 }
index e10fca9..25b7095 100644 (file)
@@ -148,11 +148,25 @@ class ScopedShapedBuffer : public ShapedBuffer {
   // ScopedShapedBuffer.
   DeviceMemoryAllocator* memory_allocator() const { return allocator_; }
 
-  // Releases all device memory owned by this ScopedShapedBuffer and returns the
-  // device memory pointers in the form of a ShapedBuffer. The returned
-  // ShapedBuffer takes over the memory from the ScopedShapedBuffer. The
-  // resulting ScopedShapedBuffer can only be destroyed.
-  ShapedBuffer release();
+  // Sets the device memory buffer at the given index.
+  //
+  // If the given buffer's device memory is non-null, its device_ordinal and
+  // allocator must match those in `this`.
+  void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) {
+    if (!buffer.is_null()) {
+      CHECK_EQ(buffer.device_ordinal(), device_ordinal());
+      CHECK_EQ(buffer.allocator(), allocator_);
+      *buffers_.mutable_element(index) = buffer.Forget();
+    } else {
+      *buffers_.mutable_element(index) = se::DeviceMemoryBase();
+    }
+  }
+
+  // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from
+  // this ScopedShapedBuffer, without freeing any of the associated memory.
+  //
+  // It's the caller's job to ensure that the memory contained therein is freed.
+  TF_MUST_USE_RESULT ShapedBuffer release();
 
  protected:
   DeviceMemoryAllocator* allocator_;
index 8b71a41..3e7338f 100644 (file)
@@ -196,9 +196,11 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
     const ShapeIndex& index = pair.first;
     se::DeviceMemoryBase& memory_base = pair.second;
     const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index);
-    TF_ASSIGN_OR_RETURN(memory_base,
+    TF_ASSIGN_OR_RETURN(auto memory,
                         allocator->Allocate(shaped_buffer.device_ordinal(),
                                             GetByteSizeRequirement(subshape)));
+    // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
+    memory_base = memory.Forget();
   }
 
   return std::move(shaped_buffer);
index e859b30..758a4aa 100644 (file)
@@ -35,9 +35,9 @@ namespace xla {
 
 /* static */ TestAllocator* LocalClientTestBase::allocator_;
 
-StatusOr<se::DeviceMemoryBase> TestAllocator::Allocate(int device_ordinal,
-                                                       uint64 size,
-                                                       bool retry_on_failure) {
+StatusOr<OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
+                                                     uint64 size,
+                                                     bool retry_on_failure) {
   VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
   {
     tensorflow::mutex_lock lock(count_mutex_);
@@ -49,7 +49,7 @@ StatusOr<se::DeviceMemoryBase> TestAllocator::Allocate(int device_ordinal,
 }
 
 tensorflow::Status TestAllocator::Deallocate(int device_ordinal,
-                                             se::DeviceMemoryBase* mem) {
+                                             se::DeviceMemoryBase mem) {
   VLOG(2) << "Deallocate(" << device_ordinal << ")";
   {
     tensorflow::mutex_lock lock(count_mutex_);
index 3bbb760..6374c79 100644 (file)
@@ -46,10 +46,10 @@ class TestAllocator : public StreamExecutorMemoryAllocator {
             platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
   }
 
-  StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
-                                          bool retry_on_failure) override;
+  StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
+                                        bool retry_on_failure) override;
   tensorflow::Status Deallocate(int device_ordinal,
-                                se::DeviceMemoryBase* mem) override;
+                                se::DeviceMemoryBase mem) override;
 
   // Return the number of allocations that have been performed.
   int64 allocation_count() const;
index ab6b00f..e426cf9 100644 (file)
@@ -177,6 +177,9 @@ class StreamExecutor {
   //
   // Resets the internal contents of mem to be null-representative, but this
   // null-out effect should not be relied upon in client code.
+  //
+  // TODO(jlebar): Change this to accept a DeviceMemoryBase by value, see
+  // discussion in cl/195744342.
   void Deallocate(DeviceMemoryBase *mem);
 
   // Retrieves a mapping of active opaque device memory pointer to a string