return;
}
+ if (weights_initialized)
+ return;
+
AllocFunc allocate_weight = getAllocFunc(true);
unsigned int weight_offset = 0;
}
weights_initialized = true;
+ /** weights are allocated without delay */
weights_allocated = true;
}
// NOTE(review): this span appears to be unresolved diff/merge residue — the
// '+'/'-' prefixed lines are patch-hunk markers, not C++, and gtest EXPECT_*
// assertions from a model unit-test file are spliced into the middle of
// Manager::allocateWeights(). The file is not valid C++ as-is; the original
// patches need to be re-applied cleanly before this can be reviewed as code.
void Manager::allocateWeights() {
+ if (weights_allocated)
+ return;
+
// NOTE(review): iterates per-layer weight lists; presumably `weights` is a
// container of layer weight groups — TODO confirm against Manager's decl.
for (auto &l_w : weights) {
for (auto &w : l_w) {
Weight &weight = w.get();
// Lazily run weight initialization before allocation if it has not happened.
if (!weights_initialized)
initializeWeights();
- // Allocate gradients
+ if (tensors_initialized)
+ return;
+
+ // Initialize gradients
if (trainable)
initializeGradients();
- // Allocate shared derivative memory
+ // Initialize shared derivative memory
if (max_derivative_size > 0 && enable_activation_memory_opt && trainable)
shared_deriv = Tensor(TensorDim({max_derivative_size}), false);
// @todo Do not count memory of the input tensor of the input layer in the
// estimate of max_shared_inout as it is not used
- // Allocate shared input/output memory for inference
+ // Initialize shared input/output memory for inference
// @note Memory for label is not allocated here as inference doesn't have a label
if (!trainable && enable_inference_inout_memory_opt)
shared_inout = Tensor(TensorDim({max_shared_inout}), false);
// NOTE(review): the following gtest assertions belong to a model unit test
// (loss-value regression checks), not to Manager — further evidence of a
// bad merge; the closing brace below ends the spliced test, not this method.
EXPECT_EQ(model->setProperty({"batch_size=4"}), ML_ERROR_NONE);
EXPECT_NO_THROW(model->train());
- EXPECT_FLOAT_EQ(model->getTrainingLoss(), 1.9582682);
- EXPECT_FLOAT_EQ(model->getValidationLoss(), 2.1831701);
- EXPECT_FLOAT_EQ(model->getLoss(), 2.1985414);
+ EXPECT_FLOAT_EQ(model->getTrainingLoss(), 1.9613363);
+ EXPECT_FLOAT_EQ(model->getValidationLoss(), 2.1835098);
+ EXPECT_FLOAT_EQ(model->getLoss(), 2.1977143);
}
/**
mkModelTc(conv_same_padding_multi_stride, "3:1:1:10", 10),
mkModelTc(conv_no_loss_validate, "3:1:1:10", 1),
mkModelTc(conv_none_loss_validate, "3:1:1:10", 1)
-), [](const testing::TestParamInfo<nntrainerModelTest::ParamType>& info){
- return std::get<0>(info.param).getName();
-});
+// / #if gtest_version <= 1.7.0
+));
+/// #else gtest_version > 1.8.0
+// ), [](const testing::TestParamInfo<nntrainerModelTest::ParamType>& info){
+// return std::get<0>(info.param).getName();
+// });
+/// #end if */
// clang-format on
/**