namespace nntrainer {
-Var_Grad::Var_Grad(const Var_Grad &rhs) :
- trainable(rhs.trainable),
- name(rhs.name) {
+Var_Grad::Var_Grad(const Var_Grad &rhs) : name(rhs.name) {
var = rhs.var.clone();
- if (rhs.trainable)
+ if (rhs.getTrainable())
grad = rhs.grad.clone();
}
}
/**
 * @brief Construct a Var_Grad with a given dimension.
 *
 * @param dim  dimension of the variable (and gradient, when trainable)
 * @param train if true, allocate a gradient tensor of the same dimension;
 *              otherwise the gradient stays an empty (0-dim) tensor,
 *              which encodes "not trainable"
 * @param name name of this parameter
 */
Var_Grad::Var_Grad(const TensorDim &dim, bool train, const std::string &name) :
  name(name) {
  var = Tensor(dim);
  if (train) {
    grad = Tensor(dim);
  } else {
    grad = Tensor();
  }
  resetGradient();
}
+void Var_Grad::setTrainable(bool train) {
+ if (train == getTrainable())
+ return;
+
+ if (train) {
+ grad = Tensor(var.getDim());
+ } else {
+ grad = Tensor();
+ }
+}
+
} // namespace nntrainer
public:
  /**
   * @brief Var_Grad default constructor
   * @note A default-constructed Var_Grad is not trainable, since its
   *       gradient is an empty 0-dim tensor.
   */
  Var_Grad() = default;
/**
* @brief Construct a new Var_Grad object
using std::swap;
swap(lhs.var, rhs.var);
- swap(lhs.trainable, rhs.trainable);
swap(lhs.grad, rhs.grad);
swap(lhs.name, rhs.name);
}
* @return true if trainable
* @return false is not trainable
*/
- bool getTrainable() { return trainable; }
+ bool getTrainable() const { return !grad.uninitialized(); }
+
+ /**
+ * @brief set if the Var_Grad is trainable
+ * @param train true if trainable, else false
+ */
+ void setTrainable(bool train);
/**
* @brief Get the name of the Var_Grad
Tensor var; /**< variable to be updated and used */
Tensor grad; /**< gradient for the variable */
- bool trainable; /**< if this variable is trainable */
std::string name; /**< name of the parameter */
};