From 26c5fd529b0d1c774d1e095f8c7db30480334148 Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Tue, 1 Dec 2020 10:48:32 +0900 Subject: [PATCH] [var_grad] Trainable inferred from gradient Trainable property of a variable was earlier inferred by storing a trainable variable. Now, trainable will be inferred using gradient.uninitialized() **Self evaluation:** 1. Build test: [x]Passed [ ]Failed [ ]Skipped 2. Run test: [x]Passed [ ]Failed [ ]Skipped Signed-off-by: Parichay Kapoor --- nntrainer/tensor/var_grad.cpp | 20 ++++++++++++++------ nntrainer/tensor/var_grad.h | 13 +++++++++---- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/nntrainer/tensor/var_grad.cpp b/nntrainer/tensor/var_grad.cpp index f79ad685..a95d9f49 100644 --- a/nntrainer/tensor/var_grad.cpp +++ b/nntrainer/tensor/var_grad.cpp @@ -16,11 +16,9 @@ namespace nntrainer { -Var_Grad::Var_Grad(const Var_Grad &rhs) : - trainable(rhs.trainable), - name(rhs.name) { +Var_Grad::Var_Grad(const Var_Grad &rhs) : name(rhs.name) { var = rhs.var.clone(); - if (rhs.trainable) + if (rhs.getTrainable()) grad = rhs.grad.clone(); } @@ -31,15 +29,25 @@ Var_Grad &Var_Grad::operator=(const Var_Grad &rhs) { } Var_Grad::Var_Grad(const TensorDim &dim, bool train, const std::string &name) : - trainable(train), name(name) { var = Tensor(dim); grad = Tensor(); - if (trainable) { + if (train) { grad = Tensor(dim); } resetGradient(); } +void Var_Grad::setTrainable(bool train) { + if (train == getTrainable()) + return; + + if (train) { + grad = Tensor(var.getDim()); + } else { + grad = Tensor(); + } +} + } // namespace nntrainer diff --git a/nntrainer/tensor/var_grad.h b/nntrainer/tensor/var_grad.h index 2be67e44..9f967e94 100644 --- a/nntrainer/tensor/var_grad.h +++ b/nntrainer/tensor/var_grad.h @@ -27,8 +27,9 @@ class Var_Grad { public: /** * @brief Var_Grad default constructor + * @note Default variable is not trainable as gradient is 0 dim tensor */ - Var_Grad() : trainable(false) {} + Var_Grad() = default; /** * 
@brief Construct a new Var_Grad object @@ -50,7 +51,6 @@ public: using std::swap; swap(lhs.var, rhs.var); - swap(lhs.trainable, rhs.trainable); swap(lhs.grad, rhs.grad); swap(lhs.name, rhs.name); } @@ -98,7 +98,13 @@ public: * @return true if trainable * @return false is not trainable */ - bool getTrainable() { return trainable; } + bool getTrainable() const { return !grad.uninitialized(); } + + /** + * @brief set if the Var_Grad is trainable + * @param train true if trainable, else false + */ + void setTrainable(bool train); /** * @brief Get the name of the Var_Grad @@ -145,7 +151,6 @@ protected: Tensor var; /**< variable to be updated and used */ Tensor grad; /**< gradient for the variable */ - bool trainable; /**< if this variable is trainable */ std::string name; /**< name of the parameter */ }; -- 2.34.1