[var_grad] Trainable inferred from gradient submit/tizen/20201201.095355 submit/tizen/20201202.082821
authorParichay Kapoor <pk.kapoor@samsung.com>
Tue, 1 Dec 2020 01:48:32 +0000 (10:48 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 1 Dec 2020 06:15:26 +0000 (15:15 +0900)
Trainable property of a variable was earlier inferred by storing a trainable variable
Now, trainable will be inferred using gradient.uninitialized()

**Self evaluation:**
1. Build test: [x]Passed [ ]Failed [ ]Skipped
2. Run test: [x]Passed [ ]Failed [ ]Skipped

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/tensor/var_grad.cpp
nntrainer/tensor/var_grad.h

index f79ad6853fdf1948cca9dcb0b952fda8e657af0e..a95d9f49007d653695eee0661020b2e8acd06e1f 100644 (file)
 
 namespace nntrainer {
 
-Var_Grad::Var_Grad(const Var_Grad &rhs) :
-  trainable(rhs.trainable),
-  name(rhs.name) {
+Var_Grad::Var_Grad(const Var_Grad &rhs) : name(rhs.name) {
   var = rhs.var.clone();
-  if (rhs.trainable)
+  if (rhs.getTrainable())
     grad = rhs.grad.clone();
 }
 
@@ -31,15 +29,25 @@ Var_Grad &Var_Grad::operator=(const Var_Grad &rhs) {
 }
 
 Var_Grad::Var_Grad(const TensorDim &dim, bool train, const std::string &name) :
-  trainable(train),
   name(name) {
   var = Tensor(dim);
 
   grad = Tensor();
-  if (trainable) {
+  if (train) {
     grad = Tensor(dim);
   }
   resetGradient();
 }
 
+void Var_Grad::setTrainable(bool train) {
+  if (train == getTrainable())
+    return;
+
+  if (train) {
+    grad = Tensor(var.getDim());
+  } else {
+    grad = Tensor();
+  }
+}
+
 } // namespace nntrainer
index 2be67e44971f7511be83baadee3910edf6448582..9f967e94571a4394eb495aaf6264357c35efb1e0 100644 (file)
@@ -27,8 +27,9 @@ class Var_Grad {
 public:
   /**
    * @brief Var_Grad default constructor
+   * @note Default variable is not trainable as gradient is 0 dim tensor
    */
-  Var_Grad() : trainable(false) {}
+  Var_Grad() = default;
 
   /**
    * @brief Construct a new Var_Grad object
@@ -50,7 +51,6 @@ public:
     using std::swap;
 
     swap(lhs.var, rhs.var);
-    swap(lhs.trainable, rhs.trainable);
     swap(lhs.grad, rhs.grad);
     swap(lhs.name, rhs.name);
   }
@@ -98,7 +98,13 @@ public:
    * @return true if trainable
    * @return false is not trainable
    */
-  bool getTrainable() { return trainable; }
+  bool getTrainable() const { return !grad.uninitialized(); }
+
+  /**
+   * @brief set if the Var_Grad is trainable
+   * @param train true if trainable, else false
+   */
+  void setTrainable(bool train);
 
   /**
    * @brief Get the name of the Var_Grad
@@ -145,7 +151,6 @@ protected:
 
   Tensor var;       /**< variable to be updated and used */
   Tensor grad;      /**< gradient for the variable */
-  bool trainable;   /**< if this variable is trainable */
   std::string name; /**< name of the parameter */
 };