"framework/type_index.h",
"framework/type_traits.h",
"framework/types.h",
+ "framework/visitable_allocator.h",
"public/version.h",
"util/activation_mode.h",
"util/bcast.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
"framework/variant.h",
+ "framework/visitable_allocator.h",
"platform/variant_coding.h",
"util/command_line_flags.h",
"util/env_var.h",
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
- "common_runtime/visitable_allocator.h",
"graph/gradients.h",
"graph/quantize_training.h",
] + if_mkl(["graph/mkl_graph_util.h"])
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/visitable_allocator.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/log_memory.h"
cpu_allocator_collect_full_stats = enable;
}
-class CPUAllocator : public Allocator {
+// A CPU allocator that additionally supports allocation/free "visitors":
+// callbacks invoked on every raw allocation and deallocation. Deriving from
+// VisitableAllocator (instead of plain Allocator) exposes AddAllocVisitor /
+// AddFreeVisitor to clients.
+class CPUAllocator : public VisitableAllocator {
 public:
-  CPUAllocator() {}
+  // allocation_begun_ starts false; visitor registration is only legal while
+  // it remains false (see AddAllocVisitor / AddFreeVisitor below).
+  CPUAllocator() : allocation_begun_(false) {}
  ~CPUAllocator() override {}
  string Name() override { return "cpu"; }
  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    // Latch the "allocation has begun" flag so that any later attempt to
+    // register a visitor CHECK-fails. The load-then-store (rather than an
+    // unconditional store) presumably avoids dirtying the cache line on
+    // every allocation once the flag is set -- NOTE(review): confirm intent.
+    if (!allocation_begun_) {
+      allocation_begun_ = true;
+    }
+
    void* p = port::AlignedMalloc(num_bytes, alignment);
    if (cpu_allocator_collect_stats) {
      const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p);
      // NOTE(review): stats_ is GUARDED_BY(mu_), but no mutex_lock is visible
      // in this hunk before the update below -- confirm the elided context
      // lines acquire mu_ before touching stats_.
      stats_.max_alloc_size =
          std::max<int64>(stats_.max_alloc_size, alloc_size);
    }
+
+    // visit each Visitor in alloc_visitors_
+    // Reading alloc_visitors_ without visitor_mutex_ is safe because the
+    // vector is frozen once allocation_begun_ is true (see the comment on
+    // visitor_mutex_ below). Visitors are skipped on allocation failure.
+    if (p != nullptr) {
+      for (const Visitor& v : alloc_visitors_) {
+        v(p, num_bytes);
+      }
+    }
+
    return p;
  }
  void DeallocateRaw(void* ptr) override {
+    // alloc_size is computed lazily and at most once: either for the stats
+    // update or for the free visitors, whichever needs it first.
+    std::size_t alloc_size;
+    bool init_alloc_size = false;
    if (cpu_allocator_collect_stats) {
-      const std::size_t alloc_size =
-          port::MallocExtension_GetAllocatedSize(ptr);
+      alloc_size = port::MallocExtension_GetAllocatedSize(ptr);
+      init_alloc_size = true;
      mutex_lock l(mu_);
      stats_.bytes_in_use -= alloc_size;
    }
+
+    // visit each Visitor in free_visitors_
+    // Visitors are only invoked for non-null pointers, mirroring the null
+    // check on the allocation side.
+    if (ptr != nullptr) {
+      if (!init_alloc_size) {
+        alloc_size = port::MallocExtension_GetAllocatedSize(ptr);
+        init_alloc_size = true;
+      }
+      for (const Visitor& v : free_visitors_) {
+        v(ptr, alloc_size);
+      }
+    }
+
    port::AlignedFree(ptr);
  }
    return port::MallocExtension_GetAllocatedSize(ptr);
  }
+  // REQUIRES: can only add visitors before the first Allocate call
+
+  // NOTE(review): allocation_begun_ is written in AllocateRaw WITHOUT holding
+  // visitor_mutex_, so this CHECK is best-effort: it is reliable only because
+  // the flag is a std::atomic that transitions one way (false -> true).
+  void AddAllocVisitor(Visitor visitor) override {
+    mutex_lock lock(visitor_mutex_);
+    CHECK(!allocation_begun_)
+        << "AddAllocVisitor may not be called after allocation has begun.";
+    alloc_visitors_.push_back(visitor);
+  }
+
+  // Same registration contract as AddAllocVisitor: must precede the first
+  // AllocateRaw call, enforced by the CHECK below.
+  void AddFreeVisitor(Visitor visitor) override {
+    mutex_lock lock(visitor_mutex_);
+    CHECK(!allocation_begun_)
+        << "AddFreeVisitor may not be called after allocation has begun.";
+    free_visitors_.push_back(visitor);
+  }
+
 private:
  mutex mu_;
  AllocatorStats stats_ GUARDED_BY(mu_);
+  // visitor_mutex_ protects write access to alloc_visitors_ and free_visitors_.
+  // While write access is mutually exclusive, reads may happen concurrently.
+  // This is okay because we may only append to alloc_visitors_ and
+  // free_visitors_ before first allocation, and subsequently we only read these
+  // vectors.
+  mutex visitor_mutex_;
+  std::vector<Visitor> alloc_visitors_;
+  std::vector<Visitor> free_visitors_;
+  std::atomic<bool> allocation_begun_;
+
  TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
};