bn_follow(false),
weight_decay(),
weight_ini_type(WEIGHT_UNKNOWN),
- flatten(false) {}
+ flatten(false),
+ trainable(true) {}
/**
* @brief Destructor of Layer Class
*/
int setCost(CostType c);
+ /**
+ * @brief set the trainable flag for this layer
+ * @param[in] train flag to enable/disable training of this layer
+ */
+ void setTrainable(bool train) { trainable = train; }
/**
* @brief get gradients
*/
bool flatten;
+ /**
+ * @brief making this false will skip updating this layer's variables
+ */
+ bool trainable;
+
/**
* @brief Gradient for the weights in this layer
* @note The order of gradients should match the order in weights
gradients.clear();
- // Update K / bias
- for (unsigned int i = 0; i < filter_size; ++i) {
- Tensor djdw = delK[i]
- .chain()
- .applyIf(this->isWeightDecayL2Norm(), _LIFT(add_i),
- filters[i], weight_decay.lambda)
- .run();
-
- gradients.push_back(djdw);
- gradients.push_back(delBias[i]);
+ if (trainable) {
+ // Update K / bias
+ for (unsigned int i = 0; i < filter_size; ++i) {
+ Tensor djdw = delK[i]
+ .chain()
+ .applyIf(this->isWeightDecayL2Norm(), _LIFT(add_i),
+ filters[i], weight_decay.lambda)
+ .run();
+
+ gradients.push_back(djdw);
+ gradients.push_back(delBias[i]);
+ }
+ opt.apply_gradients(weights, gradients, iteration);
}
- opt.apply_gradients(weights, gradients, iteration);
return strip_pad(ret, padding);
}
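
For clarity, a plain-loop sketch of what the chained call above computes (the function and variable names below are illustrative, not nntrainer API; it assumes add_i adds the second tensor scaled by the given factor): with L2 weight decay enabled, the kernel gradient becomes delK + lambda * W before it is handed to the optimizer.

#include <vector>

// Sketch of the gradient adjustment expressed by chain()/applyIf():
// with L2 weight decay, dJ/dW = delK + lambda * W; otherwise dJ/dW = delK.
std::vector<float> kernelGradient(const std::vector<float> &delK,
                                  const std::vector<float> &filter,
                                  bool l2_decay, float lambda) {
  std::vector<float> djdw = delK;
  if (l2_decay) {
    for (size_t i = 0; i < djdw.size(); ++i)
      djdw[i] += lambda * filter[i]; // add the decay term element-wise
  }
  return djdw;
}
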
weight_decay.lambda)
.run();
- gradients.clear();
- gradients.push_back(djdw);
- gradients.push_back(djdb);
+ if (trainable) {
+ gradients.clear();
+ gradients.push_back(djdw);
+ gradients.push_back(djdb);
- opt.apply_gradients(weights, gradients, iteration);
+ opt.apply_gradients(weights, gradients, iteration);
+ }
return ret;
}
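
The same guard is applied in both the convolution and fully connected layers: when trainable is false, backwarding still computes and returns the derivative for the previous layer, but no gradients are pushed to the optimizer. Below is a self-contained sketch of that behaviour; the ToyLayer class is made up for illustration and is not nntrainer code.

#include <iostream>
#include <vector>

// Minimal stand-in for a layer with per-layer freezing: gradients are only
// applied when the layer is marked trainable, mirroring the if (trainable)
// guard added in this patch.
struct ToyLayer {
  std::vector<float> weights{1.0f, 2.0f};
  bool trainable{true};

  void setTrainable(bool train) { trainable = train; }

  void applyGradients(const std::vector<float> &grads, float lr) {
    if (!trainable)
      return; // frozen layer: skip the weight update entirely
    for (size_t i = 0; i < weights.size(); ++i)
      weights[i] -= lr * grads[i];
  }
};

int main() {
  ToyLayer frozen, learning;
  frozen.setTrainable(false);

  std::vector<float> grads{0.5f, 0.5f};
  frozen.applyGradients(grads, 0.1f);   // weights stay {1.0, 2.0}
  learning.applyGradients(grads, 0.1f); // weights become {0.95, 1.95}

  std::cout << "frozen[0]=" << frozen.weights[0]
            << " learning[0]=" << learning.weights[0] << std::endl;
  return 0;
}
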