ml: refactor ML_ANN test
authorAlexander Alekhin <alexander.alekhin@intel.com>
Mon, 19 Feb 2018 16:45:04 +0000 (19:45 +0300)
committerAlexander Alekhin <alexander.alekhin@intel.com>
Mon, 19 Feb 2018 18:35:48 +0000 (21:35 +0300)
modules/ml/test/test_mltests2.cpp

index 7d6bc1d..2ff0c93 100644 (file)
@@ -252,31 +252,35 @@ TEST(ML_ANN, ActivationFunction)
     }
 }
 
-TEST(ML_ANN, Method)
+CV_ENUM(ANN_MLP_METHOD, ANN_MLP::RPROP, ANN_MLP::ANNEAL)
+
+typedef tuple<ANN_MLP_METHOD, string, int> ML_ANN_METHOD_Params;
+typedef TestWithParam<ML_ANN_METHOD_Params> ML_ANN_METHOD;
+
+TEST_P(ML_ANN_METHOD, Test)
 {
+    int methodType = get<0>(GetParam());
+    string methodName = get<1>(GetParam());
+    int N = get<2>(GetParam());
+
     String folder = string(cvtest::TS::ptr()->get_data_path());
     String original_path = folder + "waveform.data";
-    String dataname = folder + "waveform";
+    String dataname = folder + "waveform" + '_' + methodName;
 
     Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
-    Mat responses(tdata2->getResponses().rows, 3, CV_32FC1, Scalar(0));
-    for (int i = 0; i<tdata2->getResponses().rows; i++)
+    Mat samples = tdata2->getSamples()(Range(0, N), Range::all());
+    Mat responses(N, 3, CV_32FC1, Scalar(0));
+    for (int i = 0; i < N; i++)
         responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
-    Ptr<TrainData> tdata = TrainData::create(tdata2->getSamples(), ml::ROW_SAMPLE, responses);
+    Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);
 
     ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
     RNG& rng = theRNG();
     rng.state = 0;
     tdata->setTrainTestSplitRatio(0.8);
 
-    vector<int> methodType;
-    methodType.push_back(ml::ANN_MLP::RPROP);
-    methodType.push_back(ml::ANN_MLP::ANNEAL);
-//    methodType.push_back(ml::ANN_MLP::BACKPROP); -----> NO BACKPROP TEST
-    vector<String> methodName;
-    methodName.push_back("_rprop");
-    methodName.push_back("_anneal");
-//    methodName.push_back("_backprop"); -----> NO BACKPROP TEST
+    Mat testSamples = tdata->getTestSamples();
+
 #ifdef GENERATE_TESTDATA
     {
     Ptr<ml::ANN_MLP> xx = ml::ANN_MLP_ANNEAL::create();
@@ -296,14 +300,13 @@ TEST(ML_ANN, Method)
     fs.release();
     }
 #endif
-    for (size_t i = 0; i < methodType.size(); i++)
     {
         FileStorage fs;
-        fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ + FileStorage::BASE64);
+        fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ);
         Ptr<ml::ANN_MLP> x = ml::ANN_MLP_ANNEAL::create();
         x->read(fs.root());
-        x->setTrainMethod(methodType[i]);
-        if (methodType[i] == ml::ANN_MLP::ANNEAL)
+        x->setTrainMethod(methodType);
+        if (methodType == ml::ANN_MLP::ANNEAL)
         {
             x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff)));
             x->setAnnealInitialT(12);
@@ -313,28 +316,50 @@ TEST(ML_ANN, Method)
         }
         x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
         x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
-        ASSERT_TRUE(x->isTrained()) << "Could not train networks with  " << methodName[i];
+        ASSERT_TRUE(x->isTrained()) << "Could not train networks with  " << methodName;
+        string filename = dataname + ".yml.gz";
+        Mat r_gold;
 #ifdef  GENERATE_TESTDATA
-        x->save(dataname + methodName[i] + ".yml.gz");
+        x->save(filename);
+        x->predict(testSamples, r_gold);
+        {
+            FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
+            fs_response << "response" << r_gold;
+        }
+#else
+        {
+            FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::READ);
+            fs_response["response"] >> r_gold;
+        }
 #endif
-        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + methodName[i] + ".yml.gz");
-        ASSERT_TRUE(y != NULL) << "Could not load   " << dataname + methodName[i] + ".yml";
-        Mat testSamples = tdata->getTestSamples();
-        Mat rx, ry, dst;
+        ASSERT_FALSE(r_gold.empty());
+        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(filename);
+        ASSERT_TRUE(y != NULL) << "Could not load   " << filename;
+        Mat rx, ry;
         for (int j = 0; j < 4; j++)
         {
             rx = x->getWeights(j);
             ry = y->getWeights(j);
             double n = cvtest::norm(rx, ry, NORM_INF);
-            EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i] << " layer : " << j;
+            EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for layer: " << j;
         }
         x->predict(testSamples, rx);
         y->predict(testSamples, ry);
-        double n = cvtest::norm(rx, ry, NORM_INF);
-        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i];
+        double n = cvtest::norm(ry, rx, NORM_INF);
+        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to result of the saved model";
+        n = cvtest::norm(r_gold, rx, NORM_INF);
+        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to 'gold' response";
     }
 }
 
+INSTANTIATE_TEST_CASE_P(/*none*/, ML_ANN_METHOD,
+    testing::Values(
+        make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::RPROP, "rprop", 5000),
+        make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::ANNEAL, "anneal", 1000)
+        //make_pair<ANN_MLP_METHOD, string>(ml::ANN_MLP::BACKPROP, "backprop", 5000); -----> NO BACKPROP TEST
+    )
+);
+
 
 // 6. dtree
 // 7. boost