[Manager] Use TensorPool for Gradients
authorParichay Kapoor <pk.kapoor@samsung.com>
Wed, 25 Aug 2021 04:48:46 +0000 (13:48 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 30 Sep 2021 23:13:38 +0000 (08:13 +0900)
Use TensorPool for gradients of the weights.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
jni/Android.mk
nntrainer/graph/network_graph.cpp
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
nntrainer/layers/layer_node.h
nntrainer/tensor/manager.cpp
nntrainer/tensor/manager.h

index 9a21c91..796772c 100644 (file)
@@ -139,6 +139,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/var_grad.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/weight.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/tensor_dim.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/tensor/tensor_pool.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/memory_pool.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/basic_planner.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/blas_interface.cpp \
index c978565..63d7ee9 100644 (file)
@@ -430,12 +430,15 @@ void NetworkGraph::setBatchSize(unsigned int batch_size) {
 
   for (auto iter = cbegin(); iter != cend(); iter++) {
     (*iter)->setBatch(batch_size);
-    const InitLayerContext &init_context = (*iter)->getInitContext();
-    // resize tensors spec
-    for (auto const &ts : init_context.getTensorsSpec()) {
-      tensor_manager->setBatchSize(std::get<3>(ts), batch_size);
-      tensor_manager->setBatchSize(std::get<3>(ts) + Var_Grad::grad_suffix,
-                                   batch_size);
+    if ((*iter)->isRunContextAvailable()) {
+      const RunLayerContext &context = (*iter)->getRunContext();
+      // resize tensors spec
+      for (unsigned int idx = 0; idx < context.getNumTensors(); idx++) {
+        auto const &ts = context.getTensor(idx);
+        tensor_manager->setBatchSize(ts.getName(), batch_size);
+        auto const &ts_grad = context.getTensorGrad(idx);
+        tensor_manager->setBatchSize(ts_grad.getName(), batch_size);
+      }
     }
   }
   tensor_manager->setBatchSize(batch_size);
index c65b0ed..bccf574 100644 (file)
@@ -181,6 +181,16 @@ Tensor &RunLayerContext::getTensor(unsigned int idx) {
 }
 
 /**
+ * @brief Get the Tensor object
+ *
+ * @param idx Identifier of the tensor
+ * @return Tensor& Reference to the tensor
+ */
+const Tensor &RunLayerContext::getTensor(unsigned int idx) const {
+  return tensors[idx]->getVariableRef();
+}
+
+/**
  * @brief Get the Tensor Grad object
  *
  * @param idx Identifier of the tensor
@@ -194,6 +204,19 @@ Tensor &RunLayerContext::getTensorGrad(unsigned int idx) {
 }
 
 /**
+ * @brief Get the Tensor Grad object
+ *
+ * @param idx Identifier of the tensor
+ * @return Tensor& Reference to the tensor grad tensor
+ */
+const Tensor &RunLayerContext::getTensorGrad(unsigned int idx) const {
+  if (!tensors[idx]->hasGradient())
+    throw std::invalid_argument(
+      "Requesting gradient for a non-trainable tensor.");
+  return tensors[idx]->getGradientRef();
+}
+
+/**
  * @brief check if the tensor has gradient
  *
  * @param idx Identifier of the tensor
index 54e19ff..a09a040 100644 (file)
@@ -429,6 +429,14 @@ public:
   Tensor &getTensor(unsigned int idx);
 
   /**
+   * @brief Get the Tensor object
+   *
+   * @param idx Identifier of the tensor
+   * @return Tensor& Reference to the tensor
+   */
+  const Tensor &getTensor(unsigned int idx) const;
+
+  /**
    * @brief Get the Tensor Grad object
    *
    * @param idx Identifier of the tensor
@@ -437,6 +445,14 @@ public:
   Tensor &getTensorGrad(unsigned int idx);
 
   /**
+   * @brief Get the Tensor Grad object
+   *
+   * @param idx Identifier of the tensor
+   * @return Tensor& Reference to the tensor grad tensor
+   */
+  const Tensor &getTensorGrad(unsigned int idx) const;
+
+  /**
    * @brief check if the tensor has gradient
    *
    * @param idx Identifier of the tensor
index 88a6670..abeb71e 100644 (file)
@@ -561,6 +561,18 @@ public:
   }
 
   /**
+   * @brief   check if run layer context is available
+   *
+   * @retval  bool true if context is available else false
+   */
+  bool isRunContextAvailable() const {
+    if (!run_context)
+      return false;
+
+    return true;
+  }
+
+  /**
    * @brief Set the Run Context object with given tensor packs
    *
    * @param weights weights
index f769291..f2536ab 100644 (file)
@@ -324,17 +324,14 @@ void Manager::deallocateWeights() {
 }
 
 void Manager::allocateGradients() {
-  /** Allocate the source tensors for shared memories */
-  if (!shared_grad.empty())
-    shared_grad.allocate();
-
   if (LAYER_V2) {
     for (auto &w : weights_v2) {
       w->allocateOptimizerVariables();
     }
-    if (tensor_pool.minMemoryRequirement() > 0)
-      tensor_pool.allocate();
   } else {
+    /** Allocate the source tensors for shared memories */
+    if (!shared_grad.empty())
+      shared_grad.allocate();
     for (auto &l_w : weights) {
       for (auto &w : l_w) {
         Weight &weight = w.get();
@@ -345,14 +342,12 @@ void Manager::allocateGradients() {
 }
 
 void Manager::deallocateGradients() {
-  shared_grad.deallocate();
-
   if (LAYER_V2) {
     for (auto &w : weights_v2) {
       w->deallocateOptimizerVariables();
     }
-    tensor_pool.deallocate();
   } else {
+    shared_grad.deallocate();
     for (auto &l_w : weights) {
       for (auto &w : l_w) {
         Weight &weight = w.get();
@@ -366,9 +361,7 @@ void Manager::deallocateGradients() {
  * @brief Initialize the weight gradients
  */
 void Manager::initializeGradients() {
-  if (LAYER_V2) {
-    tensor_pool.finalize(BasicPlanner(), 0, max_exec_order);
-  } else {
+  if (!LAYER_V2) {
     if (total_weight_size == 0) {
       ml_logw(
         "Nothing done on initialize because there is no weight registered");
@@ -510,9 +503,9 @@ void Manager::allocateInOuts() {
     for (auto &out : outputs_v2) {
       out->allocateVariable();
     }
-    for (auto &t : tensors_v2) {
-      t->allocateVariable();
-    }
+    // for (auto &t : tensors_v2) {
+    //   t->allocateVariable();
+    // }
   } else {
     for (auto &l_io : in_outs) {
       for (auto &io : l_io) {
@@ -532,9 +525,9 @@ void Manager::deallocateInOuts() {
     for (auto &out : outputs_v2) {
       out->deallocateVariable();
     }
-    for (auto &t : tensors_v2) {
-      t->deallocateVariable();
-    }
+    // for (auto &t : tensors_v2) {
+    //   t->deallocateVariable();
+    // }
   } else {
     for (auto &l_io : in_outs) {
       for (auto &io : l_io) {
@@ -556,9 +549,9 @@ void Manager::allocateDerivatives() {
     for (auto &out : outputs_v2) {
       out->allocateGradient();
     }
-    for (auto &t : tensors_v2) {
-      t->allocateGradient();
-    }
+    // for (auto &t : tensors_v2) {
+    //   t->allocateGradient();
+    // }
   } else {
     for (auto &l_io : in_outs) {
       for (auto &io : l_io) {
@@ -578,9 +571,9 @@ void Manager::deallocateDerivatives() {
     for (auto &out : outputs_v2) {
       out->deallocateGradient();
     }
-    for (auto &t : tensors_v2) {
-      t->deallocateGradient();
-    }
+    // for (auto &t : tensors_v2) {
+    //   t->deallocateGradient();
+    // }
   } else {
     for (auto &l_io : in_outs) {
       for (auto &io : l_io) {
@@ -653,9 +646,9 @@ void Manager::initializeTensorsInference() {
     }
 
     // Inference Mode without optimizations
-    for (auto &ts : tensors_v2) {
-      ts->initialize(Tensor(), Tensor(), false);
-    }
+    // for (auto &ts : tensors_v2) {
+    //   ts->initialize(Tensor(), Tensor(), false);
+    // }
 
     // In inference mode, do not allocate the memory for the input of the first
     // layer. These is the first entry in the in_outs. Inference() will override
@@ -672,11 +665,10 @@ void Manager::initializeTensorsTrain() {
   // Initialize gradients
   initializeGradients();
 
-  // Initialize shared derivative memory
-  if (max_derivative_size > 0 && enable_activation_memory_opt)
-    shared_deriv = Tensor(TensorDim({max_derivative_size}), false);
-
   if (!LAYER_V2) {
+    // Initialize shared derivative memory
+    if (max_derivative_size > 0 && enable_activation_memory_opt)
+      shared_deriv = Tensor(TensorDim({max_derivative_size}), false);
     for (unsigned int idx = 0; idx < in_outs.size(); idx++) {
       auto &l_io = in_outs[idx];
       unsigned int offset = 0;
@@ -703,15 +695,17 @@ void Manager::initializeTensorsTrain() {
       }
     }
   } else {
+    tensor_pool.finalize(BasicPlanner(), 0, max_exec_order);
+
     // Training Mode without optimizations
     for (auto &outs : outputs_v2) {
       outs->initialize(Tensor(), Tensor(), true);
     }
 
     // Training Mode without optimizations
-    for (auto &ts : tensors_v2) {
-      ts->initialize(Tensor(), Tensor(), true);
-    }
+    // for (auto &ts : tensors_v2) {
+    //   ts->initialize(Tensor(), Tensor(), true);
+    // }
 
     // Training Mode without optimizations
     for (auto &ins : inputs_v2) {
@@ -814,34 +808,61 @@ Manager::requestWeights(const GraphNode &node,
 std::vector<Var_Grad *>
 Manager::requestTensors(const GraphNode &node,
                         const std::vector<Var_Grad::Spec> &tensors_spec) {
-  auto ret = requestTensors<Var_Grad>(node, tensors_spec, tensors_v2);
   const auto &exec_order = node.getExecutionOrder();
-  for (unsigned int idx = 0; idx < ret.size(); idx++) {
-    auto const &t = ret[idx];
-    auto const &vname = t->getName();
-    auto const &gname = t->getGradientName();
-    auto const &tspan = std::get<4>(tensors_spec[idx]);
+
+  std::vector<Var_Grad *> ret;
+  size_t current_size = tensors_v2.size();
+
+  for (auto const &ts : std::as_const(tensors_spec)) {
+    auto const &tspan = std::get<4>(ts);
+    std::vector<unsigned int> var_exec_order;
+    std::vector<unsigned int> grad_exec_order;
 
     /** usage for tensors */
     if (enum_class_logical_and<TensorLifespan>(
           tspan, TensorLifespan::FORWARD_FUNC_LIFESPAN))
-      tensor_exec_order[vname].push_back(std::get<0>(exec_order));
+      var_exec_order.push_back(std::get<0>(exec_order));
 
     /** usage for tensors gradient in backwarding */
     if (enum_class_logical_and<TensorLifespan>(
           tspan, TensorLifespan::BACKWARD_FUNC_LIFESPAN)) {
-      tensor_exec_order[vname].push_back(std::get<1>(exec_order));
-      tensor_exec_order[gname].push_back(std::get<1>(exec_order));
+      var_exec_order.push_back(std::get<1>(exec_order));
+      grad_exec_order.push_back(std::get<1>(exec_order));
 
-      tensor_exec_order[vname].push_back(std::get<2>(exec_order));
-      tensor_exec_order[gname].push_back(std::get<2>(exec_order));
+      var_exec_order.push_back(std::get<2>(exec_order));
+      grad_exec_order.push_back(std::get<2>(exec_order));
     }
 
-    /** set tensor lifespan */
-    expandLifespan(vname, tspan);
-    expandLifespan(gname, tspan);
+    Tensor *var =
+      tensor_pool.requestTensor(std::get<0>(ts), /// tensor dim
+                                var_exec_order,
+                                tspan,           /// lifespan
+                                std::get<3>(ts), /// name
+                                std::get<1>(ts)  /// tensor initializer
+      );
+    max_exec_order =
+      std::max(max_exec_order,
+               *std::max_element(var_exec_order.begin(), var_exec_order.end()));
+
+    Tensor *grad = nullptr;
+    // TODO: change to enum_class_and
+    if (std::get<2>(ts) /** need gradient */ &&
+        enum_class_or(tspan, TensorLifespan::FORWARD_FUNC_LIFESPAN) !=
+          TensorLifespan::FORWARD_FUNC_LIFESPAN)
+      grad = tensor_pool.requestTensor(
+        std::get<0>(ts), /// tensor dim
+        grad_exec_order, tspan,
+        std::get<3>(ts) + Var_Grad::grad_suffix, /// name
+        Tensor::Initializer::ZEROS               /// tensor initializer
+      );
+
+    tensors_v2.emplace_back(std::make_unique<Var_Grad>(var, grad));
   }
 
+  std::transform(tensors_v2.begin() + current_size, tensors_v2.end(),
+                 std::back_inserter(ret),
+                 [](auto const &elem) { return elem.get(); });
+
   return ret;
 }
 
index b4254e2..789e33d 100644 (file)
@@ -26,6 +26,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include <basic_planner.h>
 #include <graph_node.h>
 #include <tensor_pool.h>
 #include <var_grad.h>
@@ -366,6 +367,11 @@ public:
    */
   void initializeTensors(bool training);
 
+  /**
+   * @brief   Check if the manager has allocated tensors
+   *
+   * @return true if tensors allocated, else false
+   */
   bool isAllocated() const { return tensors_allocated; }
 
   /**
@@ -412,11 +418,15 @@ public:
       allocateWeights();
 
     if (!tensors_allocated) {
+      tensor_pool.finalize(BasicPlanner(), 0, max_exec_order);
       if (model_training)
         allocateGradients();
       allocateInOuts();
       if (model_training)
         allocateDerivatives();
+
+      if (tensor_pool.minMemoryRequirement() > 0)
+        tensor_pool.allocate();
       tensors_allocated = true;
     }
   }
@@ -435,6 +445,7 @@ public:
       if (model_training)
         deallocateDerivatives();
 
+      tensor_pool.deallocate();
       tensors_allocated = false;
     }
   }