Make CPUAllocator VisitableAllocator, for better RDMA networking.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 21 Feb 2018 22:14:42 +0000 (14:14 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Feb 2018 22:18:29 +0000 (14:18 -0800)
PiperOrigin-RevId: 186518037

tensorflow/core/BUILD
tensorflow/core/common_runtime/bfc_allocator.h
tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
tensorflow/core/common_runtime/gpu/pool_allocator.h
tensorflow/core/common_runtime/mkl_cpu_allocator.h
tensorflow/core/framework/allocator.cc
tensorflow/core/framework/visitable_allocator.h [moved from tensorflow/core/common_runtime/visitable_allocator.h with 94% similarity]

index 04307db..1893967 100644 (file)
@@ -480,6 +480,7 @@ tf_cuda_library(
         "framework/type_index.h",
         "framework/type_traits.h",
         "framework/types.h",
+        "framework/visitable_allocator.h",
         "public/version.h",
         "util/activation_mode.h",
         "util/bcast.h",
@@ -1812,6 +1813,7 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
     "framework/tracking_allocator.h",  # only needed for tests
     "framework/unique_tensor_references.h",
     "framework/variant.h",
+    "framework/visitable_allocator.h",
     "platform/variant_coding.h",
     "util/command_line_flags.h",
     "util/env_var.h",
@@ -2107,7 +2109,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/stats_publisher_interface.h",
     "common_runtime/step_stats_collector.h",
     "common_runtime/threadpool_device.h",
-    "common_runtime/visitable_allocator.h",
     "graph/gradients.h",
     "graph/quantize_training.h",
 ] + if_mkl(["graph/mkl_graph_util.h"])
index b8e7735..e34945d 100644 (file)
@@ -23,7 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/common_runtime/allocator_retry.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/visitable_allocator.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/macros.h"
index 2086973..0a58634 100644 (file)
@@ -19,7 +19,7 @@ limitations under the License.
 #include <memory>
 
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/visitable_allocator.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/types.h"
index adce3a8..0db08dc 100644 (file)
@@ -21,7 +21,7 @@ limitations under the License.
 #include <unordered_map>
 
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/visitable_allocator.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/types.h"
index 91ce830..38d669e 100644 (file)
@@ -24,7 +24,7 @@ limitations under the License.
 #include <map>
 #include <memory>
 #include <vector>
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/visitable_allocator.h"
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
index 2a67c03..77eeb56 100644 (file)
@@ -25,7 +25,7 @@ limitations under the License.
 #include <cstdlib>
 #include <string>
 #include "tensorflow/core/common_runtime/bfc_allocator.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/visitable_allocator.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/mem.h"
index 94bf34a..a382b8b 100644 (file)
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/visitable_allocator.h"
 
 #include "tensorflow/core/framework/allocator_registry.h"
 #include "tensorflow/core/framework/log_memory.h"
@@ -68,15 +68,19 @@ void EnableCPUAllocatorFullStats(bool enable) {
   cpu_allocator_collect_full_stats = enable;
 }
 
-class CPUAllocator : public Allocator {
+class CPUAllocator : public VisitableAllocator {
  public:
-  CPUAllocator() {}
+  CPUAllocator() : allocation_begun_(false) {}
 
   ~CPUAllocator() override {}
 
   string Name() override { return "cpu"; }
 
   void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    if (!allocation_begun_) {
+      allocation_begun_ = true;
+    }
+
     void* p = port::AlignedMalloc(num_bytes, alignment);
     if (cpu_allocator_collect_stats) {
       const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p);
@@ -88,16 +92,38 @@ class CPUAllocator : public Allocator {
       stats_.max_alloc_size =
           std::max<int64>(stats_.max_alloc_size, alloc_size);
     }
+
+    // visit each Visitor in alloc_visitors_
+    if (p != nullptr) {
+      for (const Visitor& v : alloc_visitors_) {
+        v(p, num_bytes);
+      }
+    }
+
     return p;
   }
 
   void DeallocateRaw(void* ptr) override {
+    std::size_t alloc_size;
+    bool init_alloc_size = false;
     if (cpu_allocator_collect_stats) {
-      const std::size_t alloc_size =
-          port::MallocExtension_GetAllocatedSize(ptr);
+      alloc_size = port::MallocExtension_GetAllocatedSize(ptr);
+      init_alloc_size = true;
       mutex_lock l(mu_);
       stats_.bytes_in_use -= alloc_size;
     }
+
+    // visit each Visitor in free_visitors_
+    if (ptr != nullptr) {
+      if (!init_alloc_size) {
+        alloc_size = port::MallocExtension_GetAllocatedSize(ptr);
+        init_alloc_size = true;
+      }
+      for (const Visitor& v : free_visitors_) {
+        v(ptr, alloc_size);
+      }
+    }
+
     port::AlignedFree(ptr);
   }
 
@@ -117,10 +143,36 @@ class CPUAllocator : public Allocator {
     return port::MallocExtension_GetAllocatedSize(ptr);
   }
 
+  // REQUIRES: can only add visitors before the first Allocate call
+
+  void AddAllocVisitor(Visitor visitor) override {
+    mutex_lock lock(visitor_mutex_);
+    CHECK(!allocation_begun_)
+        << "AddAllocVisitor may not be called after allocation has begun.";
+    alloc_visitors_.push_back(visitor);
+  }
+
+  void AddFreeVisitor(Visitor visitor) override {
+    mutex_lock lock(visitor_mutex_);
+    CHECK(!allocation_begun_)
+        << "AddFreeVisitor may not be called after allocation has begun.";
+    free_visitors_.push_back(visitor);
+  }
+
  private:
   mutex mu_;
   AllocatorStats stats_ GUARDED_BY(mu_);
 
+  // visitor_mutex_ protects write access to alloc_visitors_ and free_visitors_.
+  // While write access is mutually exclusive, reads may happen concurrently.
+  // This is okay because we may only append to alloc_visitors_ and
+  // free_visitors_ before first allocation, and subsequently we only read these
+  // vectors.
+  mutex visitor_mutex_;
+  std::vector<Visitor> alloc_visitors_;
+  std::vector<Visitor> free_visitors_;
+  std::atomic<bool> allocation_begun_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
 };
 
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_
 
 #include <functional>
 #include "tensorflow/core/framework/allocator.h"
@@ -76,4 +76,4 @@ class TrackingVisitableAllocator : public TrackingAllocator,
   VisitableAllocator* allocator_;
 };
 }  // namespace tensorflow
-#endif  // TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
+#endif  // TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_