[Tensor] Add Tensor Wrap method
authorJihoon Lee <jhoon.it.lee@samsung.com>
Tue, 15 Dec 2020 04:50:49 +0000 (13:50 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 24 Dec 2020 05:48:11 +0000 (14:48 +0900)
Add Tensor some factory methods
1. burrows external memory and use from
2. create from shared pointer without copy

To restrict unwanted use, those methods are static methods
called `Tensor::Wrap`

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/tensor/manager.cpp
nntrainer/tensor/manager.h
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h
nntrainer/tensor/var_grad.cpp
nntrainer/tensor/var_grad.h
nntrainer/tensor/weight.cpp
nntrainer/tensor/weight.h
test/unittest/unittest_nntrainer_tensor.cpp

index 2a66266..6bc3780 100644 (file)
@@ -2,27 +2,53 @@
 /**
  * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
  *
- * @file       manager.cpp
- * @date       2 Dec 2020
- * @brief      This is NNtrainer manager for all weights, i/o and intermediate
+ * @file   manager.cpp
+ * @date   2 Dec 2020
+ * @brief  This is NNtrainer manager for all weights, i/o and intermediate
  * tensors
- * @see                https://github.com/nnstreamer/nntrainer
- * @author     Parichay Kapoor <pk.kapoor@samsung.com>
- * @bug                No known bugs except for NYI items
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
  *
  */
 
+#ifdef __ANDROID__
+#include <android/sharedmem.h>
+#endif
+
+#include <cassert>
+#include <fcntl.h>
 #include <functional>
+#include <limits>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
 #include <vector>
 
 #include <manager.h>
+#include <nntrainer_log.h>
 
 namespace nntrainer {
 
+Manager::Manager(bool enable_gradient_memory_opt_, bool use_shared_memory_) :
+  total_weight_size(0),
+  total_grad_size(0),
+  max_grad_size(0),
+  enable_gradient_memory_opt(enable_gradient_memory_opt_),
+  use_shared_memory(use_shared_memory_),
+  fd(-1),
+  buf(nullptr),
+  buf_size(0) {}
+
+Manager::~Manager() { releaseSharedMemory(); }
+
 /**
  * @brief     Add weight to be tracked and updated with nntrainer
  */
 void Manager::trackWeight(std::reference_wrapper<Weight> w) {
+  /// @warning this does not track the weight size etcs.. This might break when
+  /// use_shared_memory = true
   std::vector<std::reference_wrapper<Weight>> temp = {w};
   weights.emplace_back(temp);
 }
@@ -35,37 +61,148 @@ void Manager::trackWeights(std::vector<Weight> &ws) {
   layer_weights.reserve(ws.size());
 
   size_t weight_size = 0;
+  size_t grad_size = 0;
 
   for (auto &w : ws) {
     layer_weights.emplace_back(std::ref(w));
+    size_t len = w.getDim().getDataLen();
+    weight_size += len;
     if (w.getTrainable())
-      weight_size += w.getDim().getDataLen();
+      grad_size += len;
   }
 
   weights.push_back(layer_weights);
 
-  max_weight_size = std::max(max_weight_size, weight_size);
+  total_weight_size += weight_size;
+  total_grad_size += grad_size;
+  max_grad_size = std::max(max_grad_size, grad_size);
 }
 
 /**
  * @brief Allocate and initialize the weight variable
  */
 void Manager::initialize() {
-  Tensor shared_grad;
-  if (max_weight_size > 0 && enable_gradient_memory_opt)
-    shared_grad = Tensor(max_weight_size);
+  if (total_weight_size == 0) {
+    ml_logw("Nothing done on initialize because there is no weight registered");
+    return;
+  }
+  using AllocFunc = std::function<Tensor(const TensorDim &, size_t)>;
+
+  AllocFunc allocate_none = [](const TensorDim &dim, size_t) {
+    return Tensor();
+  };
+
+  AllocFunc allocate_weight = allocate_none;
+  AllocFunc allocate_grad = allocate_none;
+
+  if (use_shared_memory) {
+    size_t grad_size =
+      enable_gradient_memory_opt ? max_grad_size : total_grad_size;
+    size_t total_size = total_weight_size + grad_size;
+
+    if (total_size >= std::numeric_limits<size_t>::max() / sizeof(float)) {
+      throw std::invalid_argument(
+        "weights exceed maximum size supported for shared memory");
+    }
+
+    size_t weight_bytes_size =
+      (total_weight_size + total_grad_size) * sizeof(float);
+    initializeSharedMemory(weight_bytes_size);
+
+    allocate_grad = allocate_weight = [&](const TensorDim &dim, size_t offset) {
+      return Tensor::Wrap(buf, dim, offset);
+    };
+
+  } else {
+    if (max_grad_size > 0 && enable_gradient_memory_opt) {
+      std::shared_ptr<float> window(new float[max_grad_size],
+                                    std::default_delete<float[]>());
+
+      allocate_grad = [window](const TensorDim &dim, size_t offset) {
+        return Tensor::Wrap(window, dim, offset);
+      };
+    }
+  }
+
+  size_t weight_offset = 0;
+  size_t grad_initial_offset = use_shared_memory ? total_weight_size : 0;
+  size_t grad_offset = grad_initial_offset;
 
   for (auto &l_w : weights) {
-    size_t offset = 0;
+    if (enable_gradient_memory_opt) {
+      grad_offset = grad_initial_offset;
+    }
+
     for (auto &w : l_w) {
       Weight &weight = w.get();
-      if (weight.getTrainable() && enable_gradient_memory_opt) {
-        weight.initialize(
-          shared_grad.getSharedDataTensor(weight.getDim(), offset));
-        offset += weight.getDim().getDataLen();
-      } else {
-        weight.initialize();
-      }
+      auto dim = weight.getDim();
+      Tensor weight_prealloc = allocate_weight(dim, weight_offset);
+      Tensor grad_prealloc =
+        weight.getTrainable() ? allocate_grad(dim, grad_offset) : Tensor();
+
+      weight_offset += dim.getDataLen();
+      grad_offset += dim.getDataLen();
+      weight.initialize(weight_prealloc, grad_prealloc);
+    }
+  }
+}
+
+void Manager::initializeSharedMemory(size_t size) {
+
+#ifdef __ANDROID__
+  /// unfortunately, memfd_create is not supported before android level 30
+  auto fd_ = ASharedMemory_create("", size);
+  if (fd_ < 0) {
+    releaseSharedMemory();
+    throw std::runtime_error("[Manager] creating mem fd failed");
+  }
+
+  if (ASharedMemory_setProt(fd_, PROT_READ | PROT_WRITE) < 0) {
+    releaseSharedMemory();
+    throw std::runtime_error("[Manager] Setting prot failed");
+  }
+
+  auto buf_ = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
+#else
+  /// @todo create a file in tmpfs and bind to memfs
+  /// memfd_create is not available for number of platforms so this is commented
+  // auto fd_ = memfd_create("", 0);
+  // if (fd_ < 0) {
+  //   releaseSharedMemory();
+  //   throw std::runtime_error("[Manager] creating mem fd failed");
+  // }
+  // if (ftruncate(fd_, size) < 0) {
+  //   releaseSharedMemory();
+  //   throw std::runtime_error("[Manager] truncating fd failed");
+  // }
+
+  auto fd_ = -1;
+  auto buf_ = mmap(NULL, size, PROT_READ | PROT_WRITE,
+                   MAP_PRIVATE | MAP_ANONYMOUS, fd_, 0);
+#endif
+  if (buf_ == MAP_FAILED) {
+    releaseSharedMemory();
+    throw std::runtime_error("[Manager] mmap failed");
+  }
+
+  buf = reinterpret_cast<float *>(buf_);
+  fd = fd_;
+  buf_size = size;
+}
+
+void Manager::releaseSharedMemory() noexcept {
+  if (buf != nullptr) {
+#ifdef DEBUG
+    assert(buf_size > 0);
+#endif
+    if (munmap(buf, buf_size) < 0) {
+      ml_logw("[Manager] munmap failed on destruction please check");
+    }
+  }
+
+  if (fd != -1) {
+    if (close(fd) < 0) {
+      ml_logw("[Manager] closing fd failed on destruction please check");
     }
   }
 }
index 82e6030..500fecc 100644 (file)
@@ -2,13 +2,13 @@
 /**
  * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
  *
- * @file       manager.h
- * @date       30 Nov 2020
- * @brief      This is NNtrainer manager for all weights, i/o and intermediate
+ * @file   manager.h
+ * @date   30 Nov 2020
+ * @brief  This is NNtrainer manager for all weights, i/o and intermediate
  * tensors
- * @see                https://github.com/nnstreamer/nntrainer
- * @author     Parichay Kapoor <pk.kapoor@samsung.com>
- * @bug                No known bugs except for NYI items
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug           No known bugs except for NYI items
  *
  */
 
@@ -33,12 +33,18 @@ public:
   /**
    * @brief     Constructor of Manager
    */
-  Manager() : max_weight_size(0), enable_gradient_memory_opt(true) {}
+  Manager(bool enable_gradient_memory_opt_ = true,
+          bool use_shared_memory_ = true);
+
+  /// @todo copy ctor / assignment ops but leave move ctor / assignment
+  Manager(const Manager &) = default;
+
+  Manager &operator=(const Manager &) = default;
 
   /**
    * @brief     Destructor of Manager
    */
-  ~Manager() {}
+  ~Manager();
 
   /**
    * @brief     Add weight to be tracked and updated with nntrainer
@@ -72,6 +78,15 @@ public:
   }
 
   /**
+   * @brief Get the File descriptor.
+   * Will return -1 except for android
+   * @todo make this available for other platforms
+   *
+   * @return -1 if not applicable, else file descriptor
+   */
+  int getFd() noexcept { return fd; }
+
+  /**
    * @brief Allocate and initialize the weight variable
    */
   void initialize();
@@ -81,17 +96,42 @@ public:
    */
   void reset() {
     weights.clear();
-    max_weight_size = 0;
+    max_grad_size = 0;
+    total_weight_size = 0;
+    total_grad_size = 0;
+    releaseSharedMemory();
   }
 
 private:
+  /**
+   * @brief initialize shared memory, buf_size is set here
+   *
+   * @param size Byte size
+   */
+  void initializeSharedMemory(size_t size);
+
+  /**
+   * @brief release shared memory, if use_sha
+   *
+   */
+  void releaseSharedMemory() noexcept;
+
   // TODO: ensure that names of these weights are unique
   /**< Weights all the layer in the model to be managed */
   std::vector<std::vector<std::reference_wrapper<Weight>>> weights;
 
-  size_t max_weight_size; /**< max weight required by a layer */
+  size_t total_weight_size; /**< total weight size */
+  size_t total_grad_size;   /**< total weight size */
+  size_t max_grad_size;     /**< max trainable weight required by a layer */
 
   bool enable_gradient_memory_opt; /**< share memory among all the gradients */
+
+  /**< shared memory related */
+  bool use_shared_memory; /**< uses shared memory object which is owned by
+                             manager */
+  int fd;                 /**< fd to access the shared_memory  */
+  float *buf;             /**< buffer object when use_shared_memory */
+  size_t buf_size;        /**< buffer size */
 };
 
 } // namespace nntrainer
index 09fc60c..f6ae6f8 100644 (file)
@@ -92,6 +92,35 @@ Tensor::Tensor(const TensorDim &d, const float *buf) : Tensor() {
   }
 }
 
+Tensor Tensor::Wrap(float *buf, const TensorDim &d, int offset) {
+  if (d.getDataLen() == 0 || buf == nullptr) {
+    throw std::invalid_argument(
+      "[Tensor::Wrap] empty tensor dim is not allowed");
+  }
+
+  Tensor tmp;
+  tmp.dim = d;
+  tmp.strides = d.computeStrides();
+  /// Tensor does not own the memory
+  tmp.data = std::shared_ptr<float>(buf + offset, [](void *) {});
+
+  return tmp;
+}
+
+Tensor Tensor::Wrap(std::shared_ptr<float> buf, const TensorDim &d,
+                    int offset) {
+  if (d.getDataLen() == 0 || buf == nullptr) {
+    throw std::invalid_argument(
+      "[Tensor::Wrap] empty tensor dim is not allowed");
+  }
+
+  Tensor tmp;
+  tmp.dim = d;
+  tmp.data = std::shared_ptr<float>(buf, buf.get() + offset);
+
+  return tmp;
+}
+
 bool Tensor::operator==(const Tensor &rhs) const {
   if (this->dim != rhs.dim)
     return false;
index 573f7c0..29dad5f 100644 (file)
@@ -59,6 +59,31 @@ public:
   Tensor(const TensorDim &d, const float *buf = nullptr);
 
   /**
+   * @brief Construct a new Tensor object from a buffer
+   * This will not copy buffer to a new tensor but directly uses it
+   *
+   * @param d tensor dim
+   * @param buf buffer
+   * @param offset offset to be used from current
+   * @return Tensor object
+   * @throws std::invalid_argument if buf is null
+   */
+  static Tensor Wrap(float *buf, const TensorDim &d, int offset = 0);
+
+  /**
+   * @brief Construct a new Tensor object from a buffer
+   * This will shared the buf
+   *
+   * @param d tensor dim
+   * @param buf buffer
+   * @param offset offset to be used
+   * @return Tensor object
+   * @throws std::invalid_argument if buf is null
+   */
+  static Tensor Wrap(std::shared_ptr<float> buf, const TensorDim &d,
+                     int offset = 0);
+
+  /**
    * @brief     Constructor of Tensor
    * @param[in] batch Batch of Tensor
    * @param[in] channel Channel of Tensor
index 604b2a1..458fb68 100644 (file)
@@ -24,15 +24,20 @@ Var_Grad::Var_Grad(const TensorDim &dim, bool train, const std::string &name) :
   grad = std::make_shared<Tensor>();
 }
 
-void Var_Grad::initialize(const Tensor &grad_shared) {
-  var = std::make_shared<Tensor>(dim);
+void Var_Grad::initialize(const Tensor &weights_preallocated,
+                          const Tensor &grad_preallocated) {
+  if (!weights_preallocated.uninitialized()) {
+    var = std::make_shared<Tensor>(weights_preallocated);
+  } else {
+    var = std::make_shared<Tensor>(dim);
+  }
 
-  if (!grad_shared.uninitialized()) {
+  if (!grad_preallocated.uninitialized()) {
     /**
      * Making a new tensor is intentional here as this tensor is not shared
      * with other layers but the internal memory is.
      */
-    grad = std::make_shared<Tensor>(grad_shared);
+    grad = std::make_shared<Tensor>(grad_preallocated);
   } else {
     grad = std::make_shared<Tensor>();
     if (trainable) {
index 673e162..876d59f 100644 (file)
@@ -92,8 +92,12 @@ public:
 
   /**
    * @brief Allocate and initialize the weight variable
+   *
+   * @param weight_preallocated if initialized, use this tensor for weight
+   * @param grad_preallocated if initialized, use this tensor for grad
    */
-  virtual void initialize(const Tensor &grad_shared = Tensor());
+  virtual void initialize(const Tensor &weight_preallocated = Tensor(),
+                          const Tensor &grad_preallocated = Tensor());
 
   /**
    * @brief Get the TensorDim
index fb9bd03..562c42f 100644 (file)
@@ -24,8 +24,9 @@ Weight::Weight(const TensorDim &dim, const WeightInitializer init, bool train,
     throw std::invalid_argument("Weight initializer unknown");
 }
 
-void Weight::initialize(const Tensor &grad_shared) {
-  Var_Grad::initialize(grad_shared);
+void Weight::initialize(const Tensor &weights_preallocated,
+                        const Tensor &grad_preallocated) {
+  Var_Grad::initialize(weights_preallocated, grad_preallocated);
 
   Tensor &var_ref = getVariableRef();
   const TensorDim dim = var_ref.getDim();
index 284082a..d0d0943 100644 (file)
@@ -79,9 +79,10 @@ public:
     bool train = true, std::string name = "");
 
   /**
-   * @brief Allocate and initialize the weight variable
+   * @copydoc var_grad::initialize(const Tensor &, const Tensor &)
    */
-  void initialize(const Tensor &grad_shared = Tensor());
+  void initialize(const Tensor &weight_preallocated = Tensor(),
+                  const Tensor &grad_preallocated = Tensor());
 
   /**
    * @brief Swap for weight
index f9a8350..89a533f 100644 (file)
@@ -78,6 +78,28 @@ TEST(nntrainer_TensorDim, setTensorDim_04_p) {
   EXPECT_EQ(d.width(), 7);
 }
 
+TEST(nntrainer_Tensor, TensorWrap_p) {
+  float dat[] = {1, 2, 3};
+
+  {
+    nntrainer::Tensor a = nntrainer::Tensor::Wrap(dat, {3});
+    /// check if a.getData() has same address with dat
+    EXPECT_EQ(dat, a.getData());
+    {
+      /// check if b.getData() has same address with data
+      nntrainer::Tensor b = a;
+      EXPECT_EQ(dat, b.getData());
+    }
+  }
+  /// check if dat is accessible after destruction of all the tensor
+  EXPECT_FLOAT_EQ(dat[2], 3);
+}
+
+TEST(nntrainer_Tensor, TensorWrap_n) {
+  float dat[] = {1, 2, 3};
+  EXPECT_THROW(nntrainer::Tensor::Wrap(dat, {}), std::invalid_argument);
+}
+
 TEST(nntrainer_Tensor, Tensor_01_p) {
   int status = ML_ERROR_NONE;
   nntrainer::Tensor tensor = nntrainer::Tensor(1, 2, 3);