[SWAP] Add memory swap policy
authorJiho Chu <jiho.chu@samsung.com>
Thu, 15 Dec 2022 07:41:30 +0000 (16:41 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 31 Jan 2023 11:27:17 +0000 (20:27 +0900)
Signed-off-by: Jiho Chu <jiho.chu@samsung.com>
nntrainer/tensor/cache_pool.cpp
nntrainer/tensor/cache_pool.h
nntrainer/tensor/memory_pool.cpp
nntrainer/tensor/memory_pool.h
nntrainer/tensor/swap_device.cpp
nntrainer/tensor/swap_device.h
test/unittest/models/models_golden_test.h

index a454544..b970d5d 100644 (file)
  *
  */
 
-#include "memory_pool.h"
+#include "cache_pool.h"
+
 #include <limits>
 #include <numeric>
 #include <stdexcept>
 #include <vector>
 
-#include <cache_pool.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <profiler.h>
 
 namespace nntrainer {
 
-void CacheElem::swapIn() {
-  std::lock_guard<std::mutex> lock(device_mutex);
-  void *buf = device->getBuffer(offset, length);
-  mem_data->setAddr((float *)buf);
-  mem_data->setValid(true);
-  active = true;
+namespace {
 
-  std::string msg("CacheElem(");
-  msg += device->getDevicePath() + ") #" + std::to_string(id);
-  PROFILE_MEM_ALLOC(buf, length, msg);
+/**
+ * @brief convert tensor lifespan to cache policy
+ *
+ * @param lifespand tensor lifespan
+ * @return cache policy
+ * @note cache policy is defined by tensor's lifetime. If it needs to be maintained its value,
+ * ALWAYS_SYNCED or ITERATION_CONSIST is proper. If not, TEMPORAL doesnot keep its value.
+ */
+inline const CachePolicy
+convertTensorLifespanToCachePolicy(const TensorLifespan lifespan) {
+  CachePolicy policy;
+
+  switch (lifespan) {
+  case TensorLifespan::UNMANAGED:
+    policy = CachePolicy::ALWAYS_SYNCED;
+    break;
+  case TensorLifespan::FORWARD_FUNC_LIFESPAN:
+    policy = CachePolicy::ALWAYS_SYNCED;
+    break;
+  case TensorLifespan::CALC_DERIV_LIFESPAN:
+    policy = CachePolicy::TEMPORAL;
+    break;
+  case TensorLifespan::CALC_GRAD_LIFESPAN:
+    policy = CachePolicy::TEMPORAL;
+    break;
+  case TensorLifespan::CALC_GRAD_DERIV_LIFESPAN:
+    policy = CachePolicy::TEMPORAL;
+    break;
+  case TensorLifespan::FORWARD_GRAD_LIFESPAN:
+    policy = CachePolicy::ITERATION_CONSIST;
+    break;
+  case TensorLifespan::FORWARD_DERIV_LIFESPAN:
+    policy = CachePolicy::ALWAYS_SYNCED;
+    break;
+  case TensorLifespan::ITERATION_LIFESPAN:
+    policy = CachePolicy::ITERATION_CONSIST;
+    break;
+  case TensorLifespan::EPOCH_LIFESPAN:
+    policy = CachePolicy::ITERATION_CONSIST;
+    break;
+  case TensorLifespan::MAX_LIFESPAN:
+    policy = CachePolicy::ALWAYS_SYNCED;
+    break;
+  default:
+    policy = CachePolicy::ALWAYS_SYNCED;
+    break;
+  }
+
+  return policy;
 }
 
-void CacheElem::swapOut() {
-  std::lock_guard<std::mutex> lock(device_mutex);
-  void *buf = (void *)mem_data->getAddr();
-  device->putBuffer(buf);
-  mem_data->setAddr(nullptr);
-  mem_data->setValid(false);
-  active = false;
+std::atomic_int pool_id = 0;
 
-  PROFILE_MEM_DEALLOC(buf);
-}
+} // namespace
 
-CachePool::CachePool(const std::string &name) :
-  swap_device(std::make_shared<SwapDevice>(name + std::to_string(getpid()))) {}
+CachePool::CachePool(const std::string &n) :
+  name(n),
+  swap_device(std::make_shared<SwapDevice>(n + "_" + std::to_string(getpid()) +
+                                           "_" + std::to_string(pool_id++))) {}
 
-CachePool::CachePool(const std::string &path, const std::string &name) {
+CachePool::CachePool(const std::string &path, const std::string &n) : name(n) {
   if (path.empty())
-    swap_device = std::make_shared<SwapDevice>(name + std::to_string(getpid()));
+    swap_device = std::make_shared<SwapDevice>(
+      n + "_" + std::to_string(getpid()) + "_" + std::to_string(pool_id++));
   else
     swap_device =
-      std::make_shared<SwapDevice>(path, name + std::to_string(getpid()));
+      std::make_shared<SwapDevice>(path, n + "_" + std::to_string(getpid()) +
+                                           "_" + std::to_string(pool_id++));
 }
 
 CachePool::~CachePool() {
@@ -103,6 +141,23 @@ void CachePool::invalidate(unsigned int id) {
   }
 }
 
+unsigned int CachePool::requestMemory(size_t bytes, unsigned int start_time,
+                                      unsigned int end_time,
+                                      std::vector<unsigned int> exec_order,
+                                      TensorLifespan lifespan) {
+  auto id = MemoryPool::requestMemory(bytes, start_time, end_time, exec_order,
+                                      lifespan);
+
+  const CachePolicy policy = convertTensorLifespanToCachePolicy(lifespan);
+
+  policies.push_back(policy);
+
+  NNTR_THROW_IF(id != policies.size(), std::runtime_error)
+    << "Invalid requqestMemory call exist";
+
+  return id;
+}
+
 std::shared_ptr<MemoryData<float>> CachePool::getMemory(unsigned int id) {
   NNTR_THROW_IF(!swap_device->isOperating(), std::invalid_argument)
     << "Allocate memory before allocation";
@@ -110,16 +165,17 @@ std::shared_ptr<MemoryData<float>> CachePool::getMemory(unsigned int id) {
   off_t offset = getMemoryOffset().at(id - 1);
   size_t len = getMemorySize().at(id - 1);
   auto exe_order = getMemoryExecOrder().at(id - 1);
+  auto policy = getCachePolicy().at(id - 1);
   auto mem_data = std::make_shared<MemoryData<float>>(
     id, std::bind(&CachePool::validate, this, std::placeholders::_1),
     std::bind(&CachePool::invalidate, this, std::placeholders::_1));
-  auto elem = std::make_shared<CacheElem>(swap_device, id, offset, len,
-                                          mem_data, exe_order);
-
+  auto elem =
+    std::make_shared<CacheElem>(swap_device, id, offset, len, mem_data, policy);
   elems[id] = elem;
 
   std::string ords;
   for (auto &o : exe_order) {
+    exec_ids[o].push_back(id);
     ords.append(std::to_string(o));
     ords.append(" ");
   }
@@ -130,69 +186,111 @@ std::shared_ptr<MemoryData<float>> CachePool::getMemory(unsigned int id) {
 }
 
 void CachePool::flush() {
-  for (auto elem : actives)
-    elem->swapOut();
+  for (auto &elem : actives)
+    elem->swapOut(CacheElem::LAST_ACCESS);
+
+  for (auto &[id, elem] : elems)
+    elem->reset();
 
   actives.clear();
 }
 
 void CachePool::flushExcept(unsigned int order) {
+  auto exe_orders = getMemoryExecOrder();
+
   actives.remove_if([&, order](auto elem) -> bool {
-    auto exe_order = elem->getExeOrder();
+    auto id = elem->getId();
+    auto exe_order = exe_orders.at(id - 1);
     auto found = std::find(exe_order.begin(), exe_order.end(), order);
     if (found == exe_order.end()) {
-      elem->swapOut();
+      /**
+       * We assumes that flushExcept will be called in front of each execution
+       * order, and the order is incremental. So, we can conclude that, if the
+       * order passes by the max order of the cache element, it was LAST access
+       * of the element.
+       */
+      CacheElem::Options opt = CacheElem::NONE;
+      if (*std::max_element(exe_order.begin(), exe_order.end()) < order)
+        opt = CacheElem::LAST_ACCESS;
+      elem->swapOut(opt);
       return true;
     }
     return false;
   });
 }
 
+void CachePool::flushExcept(std::vector<unsigned int> order) {
+  auto exe_orders = getMemoryExecOrder();
+
+  actives.remove_if([&, order](const auto elem) -> bool {
+    auto id = elem->getId();
+    auto exe_order = exe_orders.at(id - 1);
+    for (auto &o : order) {
+      auto found = std::find(exe_order.begin(), exe_order.end(), o);
+      if (found != exe_order.end())
+        return false;
+    }
+    /**
+     * We assumes that flushExcept will be called in front of each execution
+     * order, and the order is incremental. So, we can conclude that, if the
+     * order passes by the max order of the cache element, it was LAST access of
+     * the element.
+     */
+    CacheElem::Options opt = CacheElem::NONE;
+    if (*std::max_element(exe_order.begin(), exe_order.end()) < order[0])
+      opt = CacheElem::LAST_ACCESS;
+    elem->swapOut(opt);
+    return true;
+  });
+}
+
 void CachePool::clear() {
   flush();
   deallocate();
+  policies.clear();
   MemoryPool::clear();
 }
 
 bool CachePool::isAllocated() const { return swap_device->isOperating(); }
 
 void CachePool::loadExec(unsigned int order) {
-  for (auto &[id, elem] : elems) {
-    auto exe_order = elem->getExeOrder();
-    auto found = std::find(exe_order.begin(), exe_order.end(), order);
-    if (found != exe_order.end())
-      validate(elem->getId());
-  }
+  for (auto &id : exec_ids[order])
+    validate(id);
 }
 
 void CachePool::initCacheElemIter(CacheElemsIter &iter) {
   iter = elems.begin();
 }
 
-bool CachePool::isLastCacheElemIter(const CacheElemsIter &iter) const {
+bool CachePool::isLastCacheElemIter(const CacheElemsIter &iter) {
   return iter == elems.end();
 }
 
-bool CachePool::loadExecOnce(unsigned int order, CacheElemsIter &iter) {
-  if (iter == elems.end())
+void CachePool::initExecIdsIter(unsigned int order, ExecIdsIter &iter) {
+  iter = exec_ids[order].begin();
+}
+
+bool CachePool::isLastExecIdsIter(unsigned int order, const ExecIdsIter &iter) {
+  return iter == exec_ids[order].end();
+}
+
+bool CachePool::loadExecOnce(unsigned int order, ExecIdsIter &iter) {
+  if (iter == exec_ids[order].end())
     return true;
 
-  auto elem = iter->second;
-  auto exe_order = elem->getExeOrder();
-  auto found = std::find(exe_order.begin(), exe_order.end(), order);
-  if (found != exe_order.end())
-    validate(elem->getId());
+  validate(*iter);
 
   iter++;
   return false;
 }
 
 void CachePool::unloadExec(unsigned int order) {
+  auto exe_orders = getMemoryExecOrder();
   for (auto &[id, elem] : elems) {
-    auto exe_order = elem->getExeOrder();
+    auto exe_order = exe_orders.at(id - 1);
     auto found = std::find(exe_order.begin(), exe_order.end(), order);
     if (found != exe_order.end())
-      invalidate(elem->getId());
+      invalidate(id);
   }
 }
 
index 5076715..b78fa00 100644 (file)
 
 #include <list>
 #include <mutex>
+#include <vector>
 
+#include <cache_elem.h>
 #include <memory_pool.h>
 #include <swap_device.h>
 
 namespace nntrainer {
 
 /**
- * @class   CacheElem
- * @brief   Cache element containing swap address
- */
-class CacheElem {
-public:
-  /**
-   * @brief CacheElem default constructor
-   *
-   */
-  explicit CacheElem(std::shared_ptr<SwapDevice> dev, unsigned int mem_id,
-                     off_t off, size_t len,
-                     std::shared_ptr<MemoryData<float>> data,
-                     std::vector<unsigned int> order) :
-    device(dev),
-    active(false),
-    id(mem_id),
-    offset(off),
-    length(len),
-    mem_data(data),
-    exe_order(order) {}
-
-  /**
-   * @brief CacheElem destructor
-   *
-   */
-  virtual ~CacheElem() {}
-
-  /**
-   * @brief load data from swap device
-   *
-   */
-  void swapIn();
-
-  /**
-   * @brief unload data to swap device
-   *
-   */
-  void swapOut();
-
-  /**
-   * @brief unload data to swap device
-   *
-   * @return active status
-''   */
-  bool isActive() const { return active; }
-
-  /**
-   * @brief get execution orders
-   *
-   * @return execution orders
-   */
-  std::vector<unsigned int> &getExeOrder() { return exe_order; }
-
-  /**
-   * @brief get length of cache element
-   *
-   * @return length of cache element in byte
-   */
-  size_t getLength() const { return length; }
-
-  /**
-   * @brief get id of cache element
-   *
-   * @return cache element id
-   */
-  unsigned int getId() const { return id; }
-
-private:
-  std::mutex device_mutex;            /**< protect device */
-  std::shared_ptr<SwapDevice> device; /**< swap device */
-  bool active;                        /**< element is loaded */
-  unsigned int id;                    /**< memory id */
-  off_t offset;                       /**< element offset from swap device */
-  size_t length;                      /**< element size */
-  std::shared_ptr<MemoryData<float>> mem_data; /**< allocated memory data */
-  std::vector<unsigned int> exe_order;         /**< execution order */
-};
-
-/**
  * @class   CachePool
  * @brief   Cache memory with fixed size utilizing swap device
  */
 class CachePool : public MemoryPool {
 public:
-  using CacheElems = std::map<unsigned int, std::shared_ptr<CacheElem>>;
+  using CacheElems =
+    std::map<unsigned int,
+             std::shared_ptr<CacheElem>>; /**< cache id, cache elem */
   using CacheElemsIter = CacheElems::iterator;
+  using ExecIds = std::vector<unsigned int>;
+  using ExecIdsIter = ExecIds::iterator;
 
   /**
    * @brief CachePool default constructor
    *
+   * @param name name of the cache pool
    */
   explicit CachePool(const std::string &name);
 
@@ -141,6 +69,15 @@ public:
   virtual void deallocate();
 
   /**
+   * @brief Request Memory from memory pool
+   * @note start_time is inclusive, but end_time is exclusive
+   */
+  virtual unsigned int requestMemory(
+    size_t bytes, unsigned int start_time, unsigned int end_time,
+    std::vector<unsigned int> exec_order = std::vector<unsigned int>(),
+    TensorLifespan lifespan = TensorLifespan::MAX_LIFESPAN);
+
+  /**
    * @brief Get the allocated cache
    *
    * @param id The token received from the requestMemory
@@ -161,6 +98,7 @@ public:
   /**
    * @brief Flush cache data to device
    *
+   * @note it must be called only when epoch ends.
    */
   virtual void flush();
 
@@ -172,6 +110,13 @@ public:
   virtual void flushExcept(unsigned int order);
 
   /**
+   * @brief Flush cache data to device except given order
+   *
+   * @param order except execution order
+   */
+  virtual void flushExcept(std::vector<unsigned int> order);
+
+  /**
    * @brief Clear the memory pool
    *
    */
@@ -196,14 +141,28 @@ public:
    *
    * @param order execution order
    */
-  virtual bool isLastCacheElemIter(const CacheElemsIter &iter) const;
+  virtual bool isLastCacheElemIter(const CacheElemsIter &iter);
+
+  /**
+   * @brief Load cache data by execution order
+   *
+   * @param order execution order
+   */
+  virtual void initExecIdsIter(unsigned int order, ExecIdsIter &iter);
+
+  /**
+   * @brief Check iterator is last element
+   *
+   * @param order execution order
+   */
+  virtual bool isLastExecIdsIter(unsigned int order, const ExecIdsIter &iter);
 
   /**
    * @brief Load cache data by execution order
    *
    * @param order execution order
    */
-  virtual bool loadExecOnce(unsigned int order, CacheElemsIter &iter);
+  virtual bool loadExecOnce(unsigned int order, ExecIdsIter &iter);
 
   /**
    * @brief Unload cache data by execution order
@@ -222,6 +181,13 @@ public:
    */
   virtual void unloadActives();
 
+  /**
+   * @brief Get name
+   *
+   * @return cache pool name
+   */
+  virtual std::string getName() { return name; }
+
 protected:
   /**
    * @brief validate cache element
@@ -237,10 +203,23 @@ protected:
    */
   virtual void invalidate(unsigned int id);
 
+  /**
+   * @brief Get cache policies
+   *
+   * @return Cache polices
+   */
+  std::vector<CachePolicy> &getCachePolicy() { return policies; }
+
+private:
+  std::string name;                        /**< pool name */
   std::shared_ptr<SwapDevice> swap_device; /**< swap device */
   CacheElems elems;                        /**< cache elements */
 
   std::list<std::shared_ptr<CacheElem>> actives;
+  std::vector<CachePolicy> policies;
+  std::map<unsigned int, ExecIds> exec_ids;
+
+  std::mutex mod_mutex;
 };
 
 } // namespace nntrainer
index 01f2705..b90a17f 100644 (file)
@@ -27,7 +27,8 @@ namespace nntrainer {
  */
 unsigned int MemoryPool::requestMemory(size_t bytes, unsigned int start_time,
                                        unsigned int end_time,
-                                       std::vector<unsigned int> exec_order) {
+                                       std::vector<unsigned int> exec_order,
+                                       TensorLifespan lifespan) {
   if (bytes == 0)
     throw std::invalid_argument("Requesting memory of 0 size");
 
index 523bed9..18a950a 100644 (file)
@@ -25,6 +25,7 @@
 
 #include <memory_data.h>
 #include <memory_planner.h>
+#include <tensor_wrap_specs.h>
 
 namespace nntrainer {
 
@@ -52,6 +53,8 @@ public:
    * @param bytes The size of the memory requested in bytes
    * @param start_time The start of the validity interval of this memory
    * @param end_time The end of the validity interval of this memory
+   * @param exec_order execution orders of this memory
+   * @param lifespan lifespan of memory
    *
    * @return The token to get the pointer for this memory after allocation
    * @note start_time is inclusive, but end_time is exclusive
@@ -59,7 +62,8 @@ public:
    */
   virtual unsigned int requestMemory(
     size_t bytes, unsigned int start_time, unsigned int end_time,
-    std::vector<unsigned int> exec_order = std::vector<unsigned int>());
+    std::vector<unsigned int> exec_order = std::vector<unsigned int>(),
+    TensorLifespan lifespan = TensorLifespan::MAX_LIFESPAN);
 
   /**
    * @brief Plan the layout with memory planner
@@ -195,7 +199,7 @@ private:
   std::vector<std::vector<unsigned int>>
     memory_exec_order; /**< execution order for the requested memory */
 
-  void *mem_pool;   /**< memory pool allocated at once */
+  void *mem_pool; /**< memory pool allocated at once */
 
   size_t pool_size; /**< memory requirement for this pool */
 
index a18e95a..732c0dc 100644 (file)
@@ -11,6 +11,7 @@
  *
  */
 
+#include <fcntl.h>
 #include <malloc.h>
 #include <profiler.h>
 #include <stdlib.h>
@@ -28,25 +29,31 @@ void SwapDevice::start(size_t size) {
   if (fd > 0)
     return;
 
-  fd = open(dev_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, (mode_t)0666);
-  NNTR_THROW_IF(fd < 0, std::runtime_error) << "open file: " << dev_path;
+  fd =
+    open(dev_path.c_str(), O_RDWR | O_CREAT | O_TRUNC | O_SYNC, (mode_t)0666);
+  NNTR_THROW_IF(fd < 0, std::runtime_error)
+    << "SwapDevice: open file: " << dev_path;
 
   off_t off;
 
   /* make sparse file */
-  off = lseek(fd, (off_t)size - 1, SEEK_SET);
-  NNTR_THROW_IF(off < 0, std::runtime_error) << "seek file: " << dev_path;
+  off = lseek(fd, size - 1, SEEK_SET);
+  NNTR_THROW_IF(off < 0, std::runtime_error)
+    << "SwapDevice: seek file: " << dev_path;
 
   ssize_t len;
   len = write(fd, "", 1);
-  NNTR_THROW_IF(len != 1, std::runtime_error) << "write file: " << dev_path;
+  NNTR_THROW_IF(len != 1, std::runtime_error)
+    << "SwapDevice: write file: " << dev_path;
 
   off = lseek(fd, 0, SEEK_SET);
-  NNTR_THROW_IF(off < 0, std::runtime_error) << "seek file: " << dev_path;
+  NNTR_THROW_IF(off < 0, std::runtime_error)
+    << "SwapDevice: seek file: " << dev_path;
 }
 
-void *SwapDevice::getBuffer(off_t offset, size_t size) {
-  NNTR_THROW_IF(fd <= 0, std::runtime_error) << "SwapDevice is not started";
+void *SwapDevice::getBuffer(off_t offset, size_t size, bool alloc_only) {
+  NNTR_THROW_IF(fd <= 0, std::runtime_error)
+    << "SwapDevice: Device is not started";
 
 #ifdef USE_MMAP
   // page aligned
@@ -57,7 +64,7 @@ void *SwapDevice::getBuffer(off_t offset, size_t size) {
   char *ptr = static_cast<char *>(
     mmap(NULL, len, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, off));
   NNTR_THROW_IF(ptr == (void *)-1, std::runtime_error)
-    << "mmap: " << std::string(strerror(errno));
+    << "SwapDevice: mmap: " << std::string(strerror(errno));
 
   void *buf = static_cast<void *>(ptr + diff);
   mapped[buf] = std::make_pair(ptr, len);
@@ -69,14 +76,18 @@ void *SwapDevice::getBuffer(off_t offset, size_t size) {
   void *ptr;
 
   ptr = calloc(1, size);
-  NNTR_THROW_IF(ptr == NULL, std::runtime_error) << "memory alloc failed";
+  NNTR_THROW_IF(ptr == NULL, std::runtime_error)
+    << "SwapDevice: memory alloc failed";
 
-  off = lseek(fd, offset, SEEK_SET);
-  NNTR_THROW_IF(off < 0, std::runtime_error) << "seek file: " << dev_path;
+  if (!alloc_only) {
+    off = lseek(fd, offset, SEEK_SET);
+    NNTR_THROW_IF(off < 0, std::runtime_error)
+      << "SwapDevice: seek file: " << dev_path;
 
-  len = read(fd, ptr, size);
-  NNTR_THROW_IF(len != (ssize_t)size, std::runtime_error)
-    << "read file: " << dev_path;
+    len = read(fd, ptr, size);
+    NNTR_THROW_IF(len != (ssize_t)size, std::runtime_error)
+      << "SwapDevice: read file: " << dev_path;
+  }
 
   allocated[ptr] = std::make_pair(offset, (ssize_t)size);
 
@@ -84,8 +95,9 @@ void *SwapDevice::getBuffer(off_t offset, size_t size) {
 #endif
 }
 
-void SwapDevice::putBuffer(void *ptr) {
-  NNTR_THROW_IF(fd <= 0, std::runtime_error) << "SwapDevice is not started";
+void SwapDevice::putBuffer(void *ptr, bool dealloc_only) {
+  NNTR_THROW_IF(fd <= 0, std::runtime_error)
+    << "SwapDevice: Device is not started";
 #ifdef USE_MMAP
   int ret;
 
@@ -95,7 +107,7 @@ void SwapDevice::putBuffer(void *ptr) {
   auto info = mapped[ptr];
   ret = munmap(std::get<void *>(info), std::get<size_t>(info));
   NNTR_THROW_IF(ret == -1, std::runtime_error)
-    << "munmap: " << std::string(strerror(errno));
+    << "SwapDevice: munmap: " << std::string(strerror(errno));
 
   mapped.erase(ptr);
 
@@ -108,15 +120,19 @@ void SwapDevice::putBuffer(void *ptr) {
   ssize_t len;
 
   NNTR_THROW_IF(allocated.find(ptr) == allocated.end(), std::invalid_argument)
-    << "Couldn't find buffer";
+    << "SwapDevice: Couldn't find buffer";
 
   auto [offset, size] = allocated[ptr];
 
-  off = lseek(fd, offset, SEEK_SET);
-  NNTR_THROW_IF(off < 0, std::runtime_error) << "seek file: " << dev_path;
+  if (!dealloc_only) {
+    off = lseek(fd, offset, SEEK_SET);
+    NNTR_THROW_IF(off < 0, std::runtime_error)
+      << "SwapDevice: seek file: " << dev_path;
 
-  len = write(fd, ptr, size);
-  NNTR_THROW_IF(len != size, std::runtime_error) << "write file: " << dev_path;
+    len = write(fd, ptr, size);
+    NNTR_THROW_IF(len != size, std::runtime_error)
+      << "SwapDevice: write file: " << dev_path;
+  }
 
   free(ptr);
   allocated.erase(ptr);
@@ -151,7 +167,7 @@ void SwapDevice::finish() {
   int status = std::remove(dev_path.c_str());
 
   NNTR_THROW_IF(status, std::runtime_error)
-    << "Couldn't remove " << dev_path.c_str();
+    << "SwapDevice: Couldn't remove " << dev_path.c_str();
 }
 
 } // namespace nntrainer
index 84b47a1..11d5dcd 100644 (file)
@@ -77,19 +77,20 @@ public:
    *
    * @param offset Requested offset of swap device file
    * @param size Requested size.
+   * @param alloc_only only allocate buffer without reading data
    *
    * @return The pointer of the swap space
    *
    */
-  void *getBuffer(off_t offset, size_t size);
+  void *getBuffer(off_t offset, size_t size, bool alloc_only = false);
 
   /**
    * @brief Deallocate and put data
    *
    * @param ptr The pointer obtained from getBuffer
-   *
+   * @param dealloc_only only deallocate buffer without writing data
    */
-  void putBuffer(void *ptr);
+  void putBuffer(void *ptr, bool dealloc_only = false);
 
   /**
    * @brief Close device
index daa43d1..851f13c 100644 (file)
@@ -19,6 +19,7 @@
 
 #include <functional>
 #include <ini_wrapper.h>
+#include <neuralnet.h>
 #include <tensor_dim.h>
 
 inline constexpr const char *DIM_UNUSED = "1:1:1";
@@ -71,6 +72,7 @@ protected:
    *
    */
   nntrainerModelTest() :
+    props({"memory_swap=false"}),
     iteration(0),
     name(""),
     options(ModelTestOption::NO_THROW_RUN) {}
@@ -109,7 +111,9 @@ protected:
    * @return std::unique_ptr<nntrainer::NeuralNetwork> created model
    */
   std::unique_ptr<nntrainer::NeuralNetwork> createModel() {
-    return nn_creator();
+    auto nn = nn_creator();
+    nn->setProperty(props);
+    return nn;
   }
 
   /**
@@ -175,6 +179,8 @@ protected:
                 std::function<std::unique_ptr<nntrainer::NeuralNetwork>()>
                   creator = nullptr);
 
+  std::vector<std::string> props; /**< property to be initially set */
+
 private:
   /**
    * @brief Get the Golden Name object