[model/graph] Clip the gradients and then apply
authorParichay Kapoor <pk.kapoor@samsung.com>
Wed, 24 Nov 2021 15:01:30 +0000 (00:01 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 1 Dec 2021 09:54:01 +0000 (18:54 +0900)
- Calculate the global norm for the gradients which needs to be clipped
- Clip the gradients with the calculated global norm
- apply the clipped gradients

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h
nntrainer/models/neuralnet.cpp
nntrainer/tensor/manager.cpp
nntrainer/tensor/manager.h
nntrainer/tensor/var_grad.h
nntrainer/tensor/weight.h

index 4bf5699..bb30cac 100644 (file)
@@ -12,6 +12,7 @@
  * @todo    Support multi-input graph.
  */
 
+#include <cmath>
 #include <sstream>
 
 #include <activation_layer.h>
@@ -289,8 +290,8 @@ void NetworkGraph::setBatchSize(unsigned int batch_size) {
     label_dims[idx] = tensor_manager->getTensor(label_list[idx])->getDim();
 }
 
-void NetworkGraph::applyGradients(LayerNode *node,
-                                  std::function<void(Weight &)> apply_func) {
+void NetworkGraph::applyGradients(
+  LayerNode *node, const std::function<void(Weight &)> &apply_func) {
   auto &rc = node->getRunContext();
   auto num_weight = rc.getNumWeights();
   for (unsigned i = 0; i < num_weight; ++i) {
@@ -339,7 +340,8 @@ sharedConstTensors NetworkGraph::forwarding(bool training) const {
 
 void NetworkGraph::backwarding(
   int iteration,
-  std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op) const {
+  std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
+  std::function<void(Weight &, int)> &apply_grad_clip_op) const {
   /**
    * last layer backwarding is run out of this loop
    */
@@ -363,6 +365,28 @@ void NetworkGraph::backwarding(
     backwarding_op(ln, iteration);
     END_PROFILE(profile_keys.at(ln->getType()));
   }
+
+  /** perform clipping of the gradients by global norm if any */
+  if (clip_weights.empty())
+    return;
+
+  /** calculate the global norm */
+  Tensor global_norm_t(
+    TensorDim({1u, 1u, 1u, (unsigned int)clip_weights.size()}));
+  float *global_norm_data = global_norm_t.getData();
+  for (unsigned int idx = 0; idx < clip_weights.size(); idx++) {
+    auto const &w = clip_weights[idx];
+    global_norm_data[idx] = w->getGradientNorm();
+  }
+  float global_norm = global_norm_t.l2norm();
+  /** apply the gradient with the above global norm */
+  for (auto w : clip_weights) {
+    w->clipGradientByGlobalNorm(global_norm);
+  }
+  /** apply the gradient with the above global norm */
+  for (auto w : clip_weights) {
+    apply_grad_clip_op(*w, iteration);
+  }
 }
 
 void NetworkGraph::setMaxExecutionOrder(bool skip_optimize) {
@@ -545,26 +569,27 @@ NetworkGraph::canExecuteInPlace(const std::shared_ptr<LayerNode> &lnode) {
 
   /**
    * @note Conditions to decide if this layer node can be in-place:
-   * This is a generic case where the layer can support in-place but will modify
-   * its input in-place. This includes layers like activation, etc. Apply checks
-   * below to ensure that the layers can work in-place:
+   * This is a generic case where the layer can support in-place but will
+   * modify its input in-place. This includes layers like activation, etc.
+   * Apply checks below to ensure that the layers can work in-place:
    * - if any of the input layer are restriction, then this layer cannot work
    *   as layers behind this layer have added restrictions.
    * - if all of the input layers are either not inplace or have no
    * restrictions, then this layer can operate in-place.
    *
    * @note Conditions to decide the type of inplace for this layer:
-   * This is a generic case, and always restrictions on the next nodes to be not
-   * inplace.
+   * This is a generic case, and always restrictions on the next nodes to be
+   * not inplace.
    *
    * @note This logic is prone to change as more layers are allowed to
    * work in-place such as concat layer, split layer, addition layer, dropout
    * layer, etc.
    *
-   * @todo This logic sets layers to in-place one-by-one as they arrive. However
-   * setting some layers to in-place can save more memory than others (like
-   * multiout layer vs activaiton layer). The layers need to sorted based on the
-   * memory save they provide and then make them in-place in that order.
+   * @todo This logic sets layers to in-place one-by-one as they arrive.
+   * However setting some layers to in-place can save more memory than others
+   * (like multiout layer vs activaiton layer). The layers need to sorted
+   * based on the memory save they provide and then make them in-place in that
+   * order.
    */
   if (lnode->getType() == ActivationLayer::type ||
       lnode->getType() == BatchNormalizationLayer::type) {
@@ -577,8 +602,8 @@ NetworkGraph::canExecuteInPlace(const std::shared_ptr<LayerNode> &lnode) {
 
     /**
      * if the layer does io_independent_backwarding where the input and output
-     * is not requried during backwarding, then it is a non-restricting in-place
-     * layer.
+     * is not requried during backwarding, then it is a non-restricting
+     * in-place layer.
      */
     if (io_independent_backwarding(lnode))
       return InPlace::NON_RESTRICTING;
@@ -618,11 +643,11 @@ setInplaceSharedMemoryConfigByLayer(const std::shared_ptr<LayerNode> &lnode,
   }
   /** @todo for addition layer, variables are not shared but gradients are */
   /**
-   * @todo for layers which support in-place, both variables and gradients will
-   * be be shared.
+   * @todo for layers which support in-place, both variables and gradients
+   * will be be shared.
    *
-   * @todo add a check here is the layer being checked here can support in-place
-   * or not
+   * @todo add a check here is the layer being checked here can support
+   * in-place or not
    */
 }
 
@@ -641,8 +666,8 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
 
   /**
    * Request manager for either a pre-allocated output as input or a newly
-   * allocated input. This is necesary for manager to know when this input node
-   * is going to be used.
+   * allocated input. This is necesary for manager to know when this input
+   * node is going to be used.
    */
   std::vector<std::string> input_names;
   input_names.reserve(prev_inputs.size());
@@ -664,8 +689,8 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
 
   /**
    * Request manager for either a pre-allocated input as output or a newly
-   * allocated input. This is necesary for manager to know when this output node
-   * is going to be used with in-place optimizations.
+   * allocated input. This is necesary for manager to know when this output
+   * node is going to be used with in-place optimizations.
    */
   const std::vector<Var_Grad *> &outputs =
     tensor_manager->requestOutputs(gnode, init_context.getOutputDimensions(),
@@ -679,7 +704,8 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
     /// later(#1707)
     // auto shared_node = getLayerNode(shared_node_str).get();
     // NNTR_THROW_IF(shared_node == nullptr, std::invalid_argument)
-    //   << "shared_node requested but it is not registered in the graph, name:
+    //   << "shared_node requested but it is not registered in the graph,
+    //   name:
     //   "
     //   << shared_node_str << " requested from " << lnode->getName();
     // NNTR_THROW_IF(shared_node->getType() != lnode->getType(),
@@ -689,7 +715,8 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
     //   lnode->getType()
     //   << " depedent node name: " << lnode->getName();
     // NNTR_THROW_IF(!shared_node->isFinalized(), std::invalid_argument)
-    //   << "shared node must be prior to the dependent node and it should be "
+    //   << "shared node must be prior to the dependent node and it should be
+    //   "
     //      "finalized beforehand, shared node name: "
     //   << shared_node_str << " dependent node name: " << lnode->getName();
     // auto num_weight = shared_node->getNumWeights();
@@ -858,8 +885,8 @@ int NetworkGraph::initialize(
         if (!pred(lnode)) {
           continue;
         }
-        /// when name is empty, we identify everything as the node, all of them
-        /// must be having identical dimensions
+        /// when name is empty, we identify everything as the node, all of
+        /// them must be having identical dimensions
         identify(lnode);
       }
     } else {
@@ -900,6 +927,13 @@ int NetworkGraph::initialize(
     return ML_ERROR_INVALID_PARAMETER;
   }
 
+  /** select weights which would require clipping of the gradients by global
+   * norm if any */
+  clip_weights = tensor_manager->getWeights([](const Weight *w) {
+    return w->hasGradient() && w->isGradientLastAccess() &&
+           w->isGradientClipByGlobalNorm();
+  });
+
   return ML_ERROR_NONE;
 }
 
index a8deb3f..d064b71 100644 (file)
@@ -142,7 +142,7 @@ public:
    * @param apply_func apply function
    */
   static void applyGradients(LayerNode *node,
-                             std::function<void(Weight &)> apply_func);
+                             const std::function<void(Weight &)> &apply_func);
 
   /**
    * @brief     forwarding network graph
@@ -155,10 +155,12 @@ public:
    * @brief     backwarding the network graph
    * @param[in] iteration current iteration number
    * @param[in] backwarding_op operation for the backwarding
+   * @param[in] apply_grad_clip_op operation for applying the clip gradients
    */
   void backwarding(
     int iteration,
-    std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op) const;
+    std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
+    std::function<void(Weight &, int)> &apply_grad_clip_op) const;
 
   /**
    * @brief     get begin iterator for the graph
@@ -368,6 +370,8 @@ private:
 
   std::unordered_map<std::string, int>
     profile_keys; /**< profile keys based on the layer type */
+  std::vector<Weight *>
+    clip_weights; /**< weights with global norm based clipping enabled */
 
   /**
    * @brief     topological sort
index 88dcbd3..bb5dc6a 100644 (file)
@@ -311,7 +311,13 @@ void NeuralNetwork::backwarding(int iteration) {
     }
   };
 
-  model_graph.backwarding(iteration, backwarding_op);
+  std::function<void(Weight &, int)> apply_grad_clip_op =
+    [opt_ = opt.get()](Weight &w, int iteration) -> void {
+    RunOptimizerContext opt_context(&w, iteration);
+    opt_->applyGradient(opt_context);
+  };
+
+  model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op);
 }
 
 void NeuralNetwork::save(const std::string &file_path,
index c534bca..58d2198 100644 (file)
@@ -578,14 +578,16 @@ std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
   return ret;
 }
 
-std::vector<Weight *> Manager::getWeights() {
-  std::vector<Weight *> all_weights;
+std::vector<Weight *>
+Manager::getWeights(const std::function<bool(const Weight *)> &condition) {
+  std::vector<Weight *> conditional_weights;
 
   for (auto &w : weights_v2) {
-    all_weights.push_back(w.get());
+    if (!condition || condition(w.get()))
+      conditional_weights.push_back(w.get());
   }
 
-  return all_weights;
+  return conditional_weights;
 }
 
 void Manager::finalizeTensorPool(TensorPool &pool, unsigned int start,
index 9425bb9..ed1c00c 100644 (file)
@@ -245,11 +245,12 @@ public:
                  bool shared_var = true, bool shared_grad = true);
 
   /**
-   * @brief     Get all the weights
+   * @brief     Get all the weights which match the above condition
    *
-   * @return    return all the weights
+   * @return    return the weights with satisfying the above condition
    */
-  std::vector<Weight *> getWeights();
+  std::vector<Weight *>
+  getWeights(const std::function<bool(const Weight *)> &condition = nullptr);
 
   /**
    * @brief Get the Min Max of a tensor execution order
index 30b8452..dfe1b9a 100644 (file)
@@ -291,6 +291,13 @@ public:
    */
   bool isGradientLastAccess() const { return is_last_access_gradient; }
 
+  /**
+   * @brief Get the norm of the gradient
+   *
+   * @return float l2 norm of the gradient
+   */
+  float getGradientNorm() const { return grad->l2norm(); }
+
   inline static const std::string grad_suffix = ":grad";
 
 protected:
index 723687d..957203b 100644 (file)
@@ -246,7 +246,19 @@ public:
    * @return true if it is to be clipped
    * @return false otherwise
    */
-  bool isGradientClipByGlobalNorm() { return clip_by_global_norm > epsilon; }
+  bool isGradientClipByGlobalNorm() const {
+    return clip_by_global_norm > epsilon;
+  }
+
+  /**
+   * @brief clip the gradient value based on the given global norm
+   *
+   * @param global_norm the global norm for all the weights
+   */
+  void clipGradientByGlobalNorm(const float global_norm) {
+    if (global_norm > clip_by_global_norm)
+      grad->multiply_i(clip_by_global_norm / (global_norm + epsilon));
+  }
 
 private:
   static constexpr float epsilon = 1e-8; /**< epsilon for zero comparison */