[FIX] modified for checking weight grad
authorJiho Chu <jiho.chu@samsung.com>
Tue, 1 Aug 2023 10:34:02 +0000 (19:34 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 2 Aug 2023 23:25:08 +0000 (08:25 +0900)
This path checks requested memory is weight gradient which information
will be used for planning.

Signed-off-by: Jiho Chu <jiho.chu@samsung.com>
nntrainer/tensor/cache_pool.cpp
nntrainer/tensor/cache_pool.h

index ac14d99..3a4c188 100644 (file)
@@ -154,9 +154,9 @@ void CachePool::invalidate(unsigned int id) {
 unsigned int CachePool::requestMemory(size_t bytes, unsigned int start_time,
                                       unsigned int end_time,
                                       std::vector<unsigned int> exec_order,
-                                      TensorLifespan lifespan) {
+                                      TensorLifespan lifespan, bool is_wgrad) {
   auto id = MemoryPool::requestMemory(bytes, start_time, end_time, exec_order,
-                                      lifespan);
+                                      lifespan, is_wgrad);
 
   const CachePolicy policy = convertTensorLifespanToCachePolicy(lifespan);
 
index b78fa00..8514ebe 100644 (file)
@@ -75,7 +75,8 @@ public:
   virtual unsigned int requestMemory(
     size_t bytes, unsigned int start_time, unsigned int end_time,
     std::vector<unsigned int> exec_order = std::vector<unsigned int>(),
-    TensorLifespan lifespan = TensorLifespan::MAX_LIFESPAN);
+    TensorLifespan lifespan = TensorLifespan::MAX_LIFESPAN,
+    bool is_wgrad = false);
 
   /**
    * @brief Get the allocated cache