}
}
-TEST(ML_ANN, Method)
+CV_ENUM(ANN_MLP_METHOD, ANN_MLP::RPROP, ANN_MLP::ANNEAL)
+
+typedef tuple<ANN_MLP_METHOD, string, int> ML_ANN_METHOD_Params;
+typedef TestWithParam<ML_ANN_METHOD_Params> ML_ANN_METHOD;
+
+TEST_P(ML_ANN_METHOD, Test)
{
+ int methodType = get<0>(GetParam());
+ string methodName = get<1>(GetParam());
+ int N = get<2>(GetParam());
+
String folder = string(cvtest::TS::ptr()->get_data_path());
String original_path = folder + "waveform.data";
- String dataname = folder + "waveform";
+ String dataname = folder + "waveform" + '_' + methodName;
Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
- Mat responses(tdata2->getResponses().rows, 3, CV_32FC1, Scalar(0));
- for (int i = 0; i<tdata2->getResponses().rows; i++)
+ Mat samples = tdata2->getSamples()(Range(0, N), Range::all());
+ Mat responses(N, 3, CV_32FC1, Scalar(0));
+ for (int i = 0; i < N; i++)
responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
- Ptr<TrainData> tdata = TrainData::create(tdata2->getSamples(), ml::ROW_SAMPLE, responses);
+ Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
RNG& rng = theRNG();
rng.state = 0;
tdata->setTrainTestSplitRatio(0.8);
- vector<int> methodType;
- methodType.push_back(ml::ANN_MLP::RPROP);
- methodType.push_back(ml::ANN_MLP::ANNEAL);
-// methodType.push_back(ml::ANN_MLP::BACKPROP); -----> NO BACKPROP TEST
- vector<String> methodName;
- methodName.push_back("_rprop");
- methodName.push_back("_anneal");
-// methodName.push_back("_backprop"); -----> NO BACKPROP TEST
+ Mat testSamples = tdata->getTestSamples();
+
#ifdef GENERATE_TESTDATA
{
Ptr<ml::ANN_MLP> xx = ml::ANN_MLP_ANNEAL::create();
fs.release();
}
#endif
- for (size_t i = 0; i < methodType.size(); i++)
{
FileStorage fs;
- fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ + FileStorage::BASE64);
+ fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ);
Ptr<ml::ANN_MLP> x = ml::ANN_MLP_ANNEAL::create();
x->read(fs.root());
- x->setTrainMethod(methodType[i]);
- if (methodType[i] == ml::ANN_MLP::ANNEAL)
+ x->setTrainMethod(methodType);
+ if (methodType == ml::ANN_MLP::ANNEAL)
{
x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff)));
x->setAnnealInitialT(12);
}
x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
- ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName[i];
+ ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName;
+ string filename = dataname + ".yml.gz";
+ Mat r_gold;
#ifdef GENERATE_TESTDATA
- x->save(dataname + methodName[i] + ".yml.gz");
+ x->save(filename);
+ x->predict(testSamples, r_gold);
+ {
+ FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
+ fs_response << "response" << r_gold;
+ }
+#else
+ {
+ FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::READ);
+ fs_response["response"] >> r_gold;
+ }
#endif
- Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + methodName[i] + ".yml.gz");
- ASSERT_TRUE(y != NULL) << "Could not load " << dataname + methodName[i] + ".yml";
- Mat testSamples = tdata->getTestSamples();
- Mat rx, ry, dst;
+ ASSERT_FALSE(r_gold.empty());
+ Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(filename);
+ ASSERT_TRUE(y != NULL) << "Could not load " << filename;
+ Mat rx, ry;
for (int j = 0; j < 4; j++)
{
rx = x->getWeights(j);
ry = y->getWeights(j);
double n = cvtest::norm(rx, ry, NORM_INF);
- EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i] << " layer : " << j;
+ EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for layer: " << j;
}
x->predict(testSamples, rx);
y->predict(testSamples, ry);
- double n = cvtest::norm(rx, ry, NORM_INF);
- EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i];
+ double n = cvtest::norm(ry, rx, NORM_INF);
+ EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to result of the saved model";
+ n = cvtest::norm(r_gold, rx, NORM_INF);
+ EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to 'gold' response";
}
}
+INSTANTIATE_TEST_CASE_P(/*none*/, ML_ANN_METHOD,
+ testing::Values(
+ make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::RPROP, "rprop", 5000),
+ make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::ANNEAL, "anneal", 1000)
+ //make_pair<ANN_MLP_METHOD, string>(ml::ANN_MLP::BACKPROP, "backprop", 5000); -----> NO BACKPROP TEST
+ )
+);
+
// 6. dtree
// 7. boost