"framework/type_index.h",
"framework/type_traits.h",
"framework/types.h",
- "framework/visitable_allocator.h",
"public/version.h",
"util/activation_mode.h",
"util/bcast.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
"framework/variant.h",
- "framework/visitable_allocator.h",
"platform/variant_coding.h",
"util/command_line_flags.h",
"util/env_var.h",
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
+ "common_runtime/visitable_allocator.h",
"graph/gradients.h",
"graph/quantize_training.h",
] + if_mkl(["graph/mkl_graph_util.h"])
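For reference, the header moved above (added under common_runtime, removed from the framework lists) declares the visitor-registration interface. A sketch of its contents, reconstructed from the pre-move framework/visitable_allocator.h, with comments paraphrased:

    // common_runtime/visitable_allocator.h (sketch)
    class VisitableAllocator : public Allocator {
     public:
      // Visitor gets called with a pointer to a memory area and its
      // size in bytes.
      typedef std::function<void(void*, size_t)> Visitor;

      // Register a visitor guaranteed to be called exactly once on each
      // chunk of memory newly allocated from the underlying device.
      virtual void AddAllocVisitor(Visitor visitor) = 0;

      // Register a visitor guaranteed to be called on each chunk of
      // memory about to be returned to the underlying device.
      virtual void AddFreeVisitor(Visitor visitor) = 0;
    };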
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/framework/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/log_memory.h"
cpu_allocator_collect_full_stats = enable;
}
-class CPUAllocator : public VisitableAllocator {
+class CPUAllocator : public Allocator {
public:
- CPUAllocator()
- : total_allocation_warning_triggered_(false), allocation_begun_(false) {}
+ CPUAllocator() : total_allocation_warning_triggered_(false) {}
~CPUAllocator() override {}
string Name() override { return "cpu"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- if (!allocation_begun_) {
- allocation_begun_ = true;
- }
-
    if (num_bytes > LargeAllocationWarningBytes()) {
      LOG(WARNING) << "Allocation of " << num_bytes << " exceeds "
                   << 100 * kLargeAllocationWarningThreshold
                   << "% of system memory.";
    }
        total_allocation_warning_triggered_ = true;
      }
    }
-
- // visit each Visitor in alloc_visitors_
- if (p != nullptr) {
- for (const Visitor& v : alloc_visitors_) {
- v(p, num_bytes);
- }
- }
-
return p;
}
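The two warning thresholds used in AllocateRaw are fractions of available system memory. A sketch of the helper definitions consistent with the constants referenced above, assuming port::AvailableRam() from the platform layer (the exact definitions live earlier in this file):

    // Individual allocations larger than this fraction of available RAM
    // trigger the LOG(WARNING) in AllocateRaw.
    static const double kLargeAllocationWarningThreshold = 0.1;

    // With stats collection enabled, warn (once) when total allocated
    // memory exceeds this fraction of available RAM.
    static const double kTotalAllocationWarningThreshold = 0.5;

    static int64 LargeAllocationWarningBytes() {
      return static_cast<int64>(port::AvailableRam() *
                                kLargeAllocationWarningThreshold);
    }

    static int64 TotalAllocationWarningBytes() {
      return static_cast<int64>(port::AvailableRam() *
                                kTotalAllocationWarningThreshold);
    }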
void DeallocateRaw(void* ptr) override {
- std::size_t alloc_size;
- bool init_alloc_size = false;
if (cpu_allocator_collect_stats) {
- alloc_size = port::MallocExtension_GetAllocatedSize(ptr);
- init_alloc_size = true;
+ const std::size_t alloc_size =
+ port::MallocExtension_GetAllocatedSize(ptr);
mutex_lock l(mu_);
stats_.bytes_in_use -= alloc_size;
}
-
- // visit each Visitor in free_visitors_
- if (ptr != nullptr) {
- if (!init_alloc_size) {
- alloc_size = port::MallocExtension_GetAllocatedSize(ptr);
- init_alloc_size = true;
- }
- for (const Visitor& v : free_visitors_) {
- v(ptr, alloc_size);
- }
- }
-
port::AlignedFree(ptr);
}
return port::MallocExtension_GetAllocatedSize(ptr);
}
- // REQUIRES: can only add visitors before the first Allocate call
-
- void AddAllocVisitor(Visitor visitor) override {
- mutex_lock lock(visitor_mutex_);
- CHECK(!allocation_begun_)
- << "AddAllocVisitor may not be called after allocation has begun.";
- alloc_visitors_.push_back(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- mutex_lock lock(visitor_mutex_);
- CHECK(!allocation_begun_)
- << "AddFreeVisitor may not be called after allocation has begun.";
- free_visitors_.push_back(visitor);
- }
-
private:
mutex mu_;
AllocatorStats stats_ GUARDED_BY(mu_);
bool total_allocation_warning_triggered_ GUARDED_BY(mu_);
- // visitor_mutex_ protects write access to alloc_visitors_ and free_visitors_.
- // While write access is mutually exclusive, reads may happen concurrently.
- // This is okay because we may only append to alloc_visitors_ and
- // free_visitors_ before first allocation, and subsequently we only read these
- // vectors.
- mutex visitor_mutex_;
- std::vector<Visitor> alloc_visitors_;
- std::vector<Visitor> free_visitors_;
- std::atomic<bool> allocation_begun_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
};
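With the visitor hooks gone, CPUAllocator's remaining instrumentation is the stats path guarded by cpu_allocator_collect_stats. A minimal caller-side sketch, assuming the EnableCPUAllocatorStats(bool) toggle and the void GetStats(AllocatorStats*) accessor from allocator.h of this era:

    // Hypothetical usage: enable stats before the first allocation, then
    // allocate through the process-wide CPU allocator and read stats back.
    EnableCPUAllocatorStats(true);
    Allocator* a = cpu_allocator();
    void* p = a->AllocateRaw(Allocator::kAllocatorAlignment, 1 << 20);
    AllocatorStats stats;
    a->GetStats(&stats);
    LOG(INFO) << "bytes_in_use = " << stats.bytes_in_use;
    a->DeallocateRaw(p);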