Simulated Annealing for ANN_MLP training method (#10213)
authorLaurentBerger <laurent.berger@univ-lemans.fr>
Fri, 15 Dec 2017 10:57:39 +0000 (11:57 +0100)
committerVadim Pisarevsky <vadim.pisarevsky@gmail.com>
Fri, 15 Dec 2017 10:57:39 +0000 (13:57 +0300)
* Simulated Annealing for ANN_MLP training method

* EXPECT_LT

* just to test new data

* manage RNG

* Try again

* Just run buildbot with new data

* try to understand

* Test layer

* New data- new test

* Force RNG in backprop

* Use Impl to avoid virtual method

* reset all weights

* try to solve ABI

* retry

* ABI solved?

* till problem with dynamic_cast

* Something is wrong

* Solved?

* disable backprop test

* remove ANN_MLP_ANNEALImpl

* Disable weight in varmap

* Add example for SimulatedAnnealing

doc/opencv.bib
modules/ml/include/opencv2/ml.hpp
modules/ml/src/ann_mlp.cpp
modules/ml/test/test_mltests2.cpp
samples/cpp/travelsalesman.cpp [new file with mode: 0644]

index dbb6e25..1cb3d0b 100644 (file)
   number = {3},
   publisher = {Elsevier}
 }
+@ARTICLE{Kirkpatrick83,
+  author = {Kirkpatrick, S. and  Gelatt, C. D.  Jr and Vecchi, M. P. },
+  title = {Optimization by Simulated Annealing},
+  year = {1983},
+  pages = {671--680},
+  journal = {Science},
+  volume = {220},
+  number = {4598},
+  publisher = {American Association for the Advancement of Science}
+}
+
 @INPROCEEDINGS{Kolmogorov03,
   author = {Kim, Junhwan and Kolmogorov, Vladimir and Zabih, Ramin},
   title = {Visual correspondence using energy minimization and mutual information},
index eb2d7ed..35a5a62 100644 (file)
@@ -1406,13 +1406,14 @@ public:
     /** Available training methods */
     enum TrainingMethods {
         BACKPROP=0, //!< The back-propagation algorithm.
-        RPROP=1 //!< The RPROP algorithm. See @cite RPROP93 for details.
+        RPROP = 1, //!< The RPROP algorithm. See @cite RPROP93 for details.
+        ANNEAL = 2 //!< The simulated annealing algorithm. See @cite Kirkpatrick83 for details.
     };
 
     /** Sets training method and common parameters.
     @param method Default value is ANN_MLP::RPROP. See ANN_MLP::TrainingMethods.
-    @param param1 passed to setRpropDW0 for ANN_MLP::RPROP and to setBackpropWeightScale for ANN_MLP::BACKPROP
-    @param param2 passed to setRpropDWMin for ANN_MLP::RPROP and to setBackpropMomentumScale for ANN_MLP::BACKPROP.
+    @param param1 passed to setRpropDW0 for ANN_MLP::RPROP and to setBackpropWeightScale for ANN_MLP::BACKPROP and to initialT for ANN_MLP::ANNEAL.
+    @param param2 passed to setRpropDWMin for ANN_MLP::RPROP and to setBackpropMomentumScale for ANN_MLP::BACKPROP and to finalT for ANN_MLP::ANNEAL.
     */
     CV_WRAP virtual void setTrainMethod(int method, double param1 = 0, double param2 = 0) = 0;
 
@@ -1499,6 +1500,34 @@ public:
     /** @copybrief getRpropDWMax @see getRpropDWMax */
     CV_WRAP virtual void setRpropDWMax(double val) = 0;
 
+    /** ANNEAL: Update initial temperature.
+    It must be \>=0. Default value is 10.*/
+    /** @see setAnnealInitialT */
+    CV_WRAP double getAnnealInitialT() const;
+    /** @copybrief getAnnealInitialT @see getAnnealInitialT */
+    CV_WRAP void setAnnealInitialT(double val);
+
+    /** ANNEAL: Update final temperature.
+    It must be \>=0 and less than initialT. Default value is 0.1.*/
+    /** @see setAnnealFinalT */
+    CV_WRAP double getAnnealFinalT() const;
+    /** @copybrief getAnnealFinalT @see getAnnealFinalT */
+    CV_WRAP void setAnnealFinalT(double val);
+
+    /** ANNEAL: Update cooling ratio.
+    It must be \>0 and less than 1. Default value is 0.95.*/
+    /** @see setAnnealCoolingRatio */
+    CV_WRAP double getAnnealCoolingRatio() const;
+    /** @copybrief getAnnealCoolingRatio @see getAnnealCoolingRatio */
+    CV_WRAP void setAnnealCoolingRatio(double val);
+
+    /** ANNEAL: Update iteration per step.
+    It must be \>0 . Default value is 10.*/
+    /** @see setAnnealItePerStep */
+    CV_WRAP int getAnnealItePerStep() const;
+    /** @copybrief getAnnealItePerStep @see getAnnealItePerStep */
+    CV_WRAP void setAnnealItePerStep(int val);
+
     /** possible activation functions */
     enum ActivationFunctions {
         /** Identity function: \f$f(x)=x\f$ */
@@ -1838,6 +1867,111 @@ CV_EXPORTS void randMVNormal( InputArray mean, InputArray cov, int nsamples, Out
 CV_EXPORTS void createConcentricSpheresTestSet( int nsamples, int nfeatures, int nclasses,
                                                 OutputArray samples, OutputArray responses);
 
+/** @brief Artificial Neural Networks - Multi-Layer Perceptrons.
+
+@sa @ref ml_intro_ann
+*/
+class CV_EXPORTS_W ANN_MLP_ANNEAL : public ANN_MLP
+{
+public:
+    /** @see setAnnealInitialT */
+    CV_WRAP virtual double getAnnealInitialT() const;
+    /** @copybrief getAnnealInitialT @see getAnnealInitialT */
+    CV_WRAP virtual void setAnnealInitialT(double val);
+
+    /** ANNEAL: Update final temperature.
+    It must be \>=0 and less than initialT. Default value is 0.1.*/
+    /** @see setAnnealFinalT */
+    CV_WRAP  virtual double getAnnealFinalT() const;
+    /** @copybrief getAnnealFinalT @see getAnnealFinalT */
+    CV_WRAP  virtual void setAnnealFinalT(double val);
+
+    /** ANNEAL: Update cooling ratio.
+    It must be \>0 and less than 1. Default value is 0.95.*/
+    /** @see setAnnealCoolingRatio */
+    CV_WRAP  virtual double getAnnealCoolingRatio() const;
+    /** @copybrief getAnnealCoolingRatio @see getAnnealCoolingRatio */
+    CV_WRAP  virtual void setAnnealCoolingRatio(double val);
+
+    /** ANNEAL: Update iteration per step.
+    It must be \>0 . Default value is 10.*/
+    /** @see setAnnealItePerStep */
+    CV_WRAP virtual int getAnnealItePerStep() const;
+    /** @copybrief getAnnealItePerStep @see getAnnealItePerStep */
+    CV_WRAP virtual  void setAnnealItePerStep(int val);
+
+
+    /** @brief Creates empty model
+
+    Use StatModel::train to train the model, Algorithm::load\<ANN_MLP\>(filename) to load the pre-trained model.
+    Note that the train method has optional flags: ANN_MLP::TrainFlags.
+    */
+//    CV_WRAP static Ptr<ANN_MLP> create();
+
+};
+
+/****************************************************************************************\
+*                                   Simulated annealing solver                             *
+\****************************************************************************************/
+
+/** @brief The class implements simulated annealing for optimization.
+@cite Kirkpatrick83 for details
+*/
+class CV_EXPORTS SimulatedAnnealingSolver : public Algorithm
+{
+public:
+    SimulatedAnnealingSolver() { init(); };
+    ~SimulatedAnnealingSolver();
+    /** Give energy value for  a state of system.*/
+    virtual double energy() =0;
+    /** Function which change the state of system (random pertubation).*/
+    virtual void changedState() = 0;
+    /** Function to reverse to the previous state.*/
+    virtual void reverseChangedState() = 0;
+    /** Simulated annealing procedure.  */
+    int run();
+    /** Set intial temperature of simulated annealing procedure.
+    *@param x new initial temperature. x\>0
+    */
+    void setInitialTemperature(double x);
+    /** Set final temperature of simulated annealing procedure.
+    *@param x new final temperature value. 0\<x\<initial temperature
+    */
+    void setFinalTemperature(double x);
+    double getFinalTemperature();
+    /** Set setCoolingRatio of simulated annealing procedure : T(t) = coolingRatio * T(t-1).
+    * @param x new cooling ratio value. 0\<x\<1
+    */
+    void setCoolingRatio(double x);
+    /** Set number iteration per temperature step.
+    * @param ite number of iteration per temperature step ite \> 0
+    */
+    void setIterPerStep(int ite);
+    struct Impl;
+protected :
+    void init();
+    Impl* impl;
+};
+struct SimulatedAnnealingSolver::Impl
+{
+    RNG rEnergy;
+    double coolingRatio;
+    double initialT;
+    double finalT;
+    int iterPerStep;
+    Impl()
+    {
+        initialT = 2;
+        finalT = 0.1;
+        coolingRatio = 0.95;
+        iterPerStep = 100;
+        refcount = 1;
+    }
+    int refcount;
+    ~Impl() { refcount--;CV_Assert(refcount==0); }
+};
+
+
 //! @} ml
 
 }
index 4a47b3d..fddc0ae 100644 (file)
@@ -42,6 +42,7 @@
 
 namespace cv { namespace ml {
 
+
 struct AnnParams
 {
     AnnParams()
@@ -51,6 +52,8 @@ struct AnnParams
         bpDWScale = bpMomentScale = 0.1;
         rpDW0 = 0.1; rpDWPlus = 1.2; rpDWMinus = 0.5;
         rpDWMin = FLT_EPSILON; rpDWMax = 50.;
+        initialT=10;finalT=0.1,coolingRatio=0.95;itePerStep=10;
+
     }
 
     TermCriteria termCrit;
@@ -64,6 +67,11 @@ struct AnnParams
     double rpDWMinus;
     double rpDWMin;
     double rpDWMax;
+
+    double initialT;
+    double finalT;
+    double coolingRatio;
+    int itePerStep;
 };
 
 template <typename T>
@@ -72,13 +80,208 @@ inline T inBounds(T val, T min_val, T max_val)
     return std::min(std::max(val, min_val), max_val);
 }
 
-class ANN_MLPImpl : public ANN_MLP
+SimulatedAnnealingSolver::~SimulatedAnnealingSolver()
+{
+    if (impl) delete impl;
+}
+
+void SimulatedAnnealingSolver::init()
+{
+    impl = new SimulatedAnnealingSolver::Impl();
+}
+
+void SimulatedAnnealingSolver::setIterPerStep(int ite)
+{
+    CV_Assert(ite>0);
+    impl->iterPerStep = ite;
+}
+
+int SimulatedAnnealingSolver::run()
+{
+    CV_Assert(impl->initialT>impl->finalT);
+    double Ti = impl->initialT;
+    double previousEnergy = energy();
+    int exchange = 0;
+    while (Ti > impl->finalT)
+    {
+        for (int i = 0; i < impl->iterPerStep; i++)
+        {
+            changedState();
+            double newEnergy = energy();
+            if (newEnergy < previousEnergy)
+            {
+                previousEnergy = newEnergy;
+            }
+            else
+            {
+                double r = impl->rEnergy.uniform(double(0.0), double(1.0));
+                if (r < exp(-(newEnergy - previousEnergy) / Ti))
+                {
+                    previousEnergy = newEnergy;
+                    exchange++;
+                }
+                else
+                    reverseChangedState();
+            }
+
+        }
+        Ti *= impl->coolingRatio;
+    }
+    impl->finalT = Ti;
+    return exchange;
+}
+
+void SimulatedAnnealingSolver::setInitialTemperature(double x)
+{
+    CV_Assert(x>0);
+    impl->initialT = x;
+};
+
+void SimulatedAnnealingSolver::setFinalTemperature(double x)
+{
+    CV_Assert(x>0);
+    impl->finalT = x;
+};
+
+double SimulatedAnnealingSolver::getFinalTemperature()
+{
+    return impl->finalT;
+};
+
+void SimulatedAnnealingSolver::setCoolingRatio(double x)
+{
+    CV_Assert(x>0 && x<1);
+    impl->coolingRatio = x;
+};
+
+class SimulatedAnnealingANN_MLP : public ml::SimulatedAnnealingSolver
+{
+public:
+    ml::ANN_MLP *nn;
+    Ptr<ml::TrainData> data;
+    int nbVariables;
+    vector<double*> adrVariables;
+    RNG rVar;
+    RNG rIndex;
+    double varTmp;
+    int index;
+
+    SimulatedAnnealingANN_MLP(ml::ANN_MLP *x, Ptr<ml::TrainData> d) : nn(x), data(d)
+    {
+        initVarMap();
+    };
+    void changedState()
+    {
+        index = rIndex.uniform(0, nbVariables);
+        double dv = rVar.uniform(-1.0, 1.0);
+        varTmp = *adrVariables[index];
+        *adrVariables[index] = dv;
+    };
+    void reverseChangedState()
+    {
+        *adrVariables[index] = varTmp;
+    };
+    double energy() { return nn->calcError(data, false, noArray()); }
+protected:
+    void initVarMap()
+    {
+        Mat l = nn->getLayerSizes();
+        nbVariables = 0;
+        adrVariables.clear();
+        for (int i = 1; i < l.rows-1; i++)
+        {
+            Mat w = nn->getWeights(i);
+            for (int j = 0; j < w.rows; j++)
+            {
+                for (int k = 0; k < w.cols; k++, nbVariables++)
+                {
+                    if (j == w.rows - 1)
+                    {
+                        adrVariables.push_back(&w.at<double>(w.rows - 1, k));
+                    }
+                    else
+                    {
+                        adrVariables.push_back(&w.at<double>(j, k));
+                    }
+                }
+            }
+        }
+    }
+
+};
+
+double ANN_MLP::getAnnealInitialT() const
+{
+    const ANN_MLP_ANNEAL* this_ = dynamic_cast<const ANN_MLP_ANNEAL*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    return this_->getAnnealInitialT();
+}
+
+void ANN_MLP::setAnnealInitialT(double val)
+{
+    ANN_MLP_ANNEAL* this_ = dynamic_cast<ANN_MLP_ANNEAL*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    this_->setAnnealInitialT(val);
+}
+
+double ANN_MLP::getAnnealFinalT() const
+{
+    const ANN_MLP_ANNEAL* this_ = dynamic_cast<const ANN_MLP_ANNEAL*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    return this_->getAnnealFinalT();
+}
+
+void ANN_MLP::setAnnealFinalT(double val)
+{
+    ANN_MLP_ANNEAL* this_ = dynamic_cast<ANN_MLP_ANNEAL*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    this_->setAnnealFinalT(val);
+}
+
+double ANN_MLP::getAnnealCoolingRatio() const
+{
+    const ANN_MLP_ANNEAL* this_ = dynamic_cast<const ANN_MLP_ANNEAL*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    return this_->getAnnealCoolingRatio();
+}
+
+void ANN_MLP::setAnnealCoolingRatio(double val)
+{
+    ANN_MLP_ANNEAL* this_ = dynamic_cast<ANN_MLP_ANNEAL*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    this_->setAnnealCoolingRatio(val);
+}
+
+int ANN_MLP::getAnnealItePerStep() const
+{
+    const ANN_MLP_ANNEAL* this_ = dynamic_cast<const ANN_MLP_ANNEAL*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    return this_->getAnnealItePerStep();
+}
+
+void ANN_MLP::setAnnealItePerStep(int val)
+{
+    ANN_MLP_ANNEAL* this_ = dynamic_cast<ANN_MLP_ANNEAL*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    this_->setAnnealItePerStep(val);
+}
+
+
+class ANN_MLPImpl : public ANN_MLP_ANNEAL
 {
 public:
     ANN_MLPImpl()
     {
         clear();
-        setActivationFunction( SIGMOID_SYM, 0, 0 );
+        setActivationFunction( SIGMOID_SYM, 0, 0);
         setLayerSizes(Mat());
         setTrainMethod(ANN_MLP::RPROP, 0.1, FLT_EPSILON);
     }
@@ -93,6 +296,10 @@ public:
     CV_IMPL_PROPERTY(double, RpropDWMinus, params.rpDWMinus)
     CV_IMPL_PROPERTY(double, RpropDWMin, params.rpDWMin)
     CV_IMPL_PROPERTY(double, RpropDWMax, params.rpDWMax)
+    CV_IMPL_PROPERTY(double, AnnealInitialT, params.initialT)
+    CV_IMPL_PROPERTY(double, AnnealFinalT, params.finalT)
+    CV_IMPL_PROPERTY(double, AnnealCoolingRatio, params.coolingRatio)
+    CV_IMPL_PROPERTY(int, AnnealItePerStep, params.itePerStep)
 
     void clear()
     {
@@ -107,7 +314,7 @@ public:
 
     void setTrainMethod(int method, double param1, double param2)
     {
-        if (method != ANN_MLP::RPROP && method != ANN_MLP::BACKPROP)
+        if (method != ANN_MLP::RPROP && method != ANN_MLP::BACKPROP && method != ANN_MLP::ANNEAL)
             method = ANN_MLP::RPROP;
         params.trainMethod = method;
         if(method == ANN_MLP::RPROP )
@@ -117,15 +324,30 @@ public:
             params.rpDW0 = param1;
             params.rpDWMin = std::max( param2, 0. );
         }
-        else if(method == ANN_MLP::BACKPROP )
+        else if (method == ANN_MLP::BACKPROP)
         {
-            if( param1 <= 0 )
+            if (param1 <= 0)
                 param1 = 0.1;
             params.bpDWScale = inBounds<double>(param1, 1e-3, 1.);
-            if( param2 < 0 )
+            if (param2 < 0)
                 param2 = 0.1;
-            params.bpMomentScale = std::min( param2, 1. );
+            params.bpMomentScale = std::min(param2, 1.);
         }
+/*        else if (method == ANN_MLP::ANNEAL)
+        {
+            if (param1 <= 0)
+                param1 = 10;
+            if (param2 <= 0 || param2>param1)
+                param2 = 0.1;
+            if (param3 <= 0 || param3 >=1)
+                param3 = 0.95;
+            if (param4 <= 0)
+                param4 = 10;
+            params.initialT = param1;
+            params.finalT = param2;
+            params.coolingRatio = param3;
+            params.itePerStep = param4;
+        }*/
     }
 
     int getTrainMethod() const
@@ -133,7 +355,7 @@ public:
         return params.trainMethod;
     }
 
-    void setActivationFunction(int _activ_func, double _f_param1, double _f_param2 )
+    void setActivationFunction(int _activ_func, double _f_param1, double _f_param2)
     {
         if( _activ_func < 0 || _activ_func > LEAKYRELU)
             CV_Error( CV_StsOutOfRange, "Unknown activation function" );
@@ -779,13 +1001,33 @@ public:
         termcrit.maxCount = std::max((params.termCrit.type & CV_TERMCRIT_ITER ? params.termCrit.maxCount : MAX_ITER), 1);
         termcrit.epsilon = std::max((params.termCrit.type & CV_TERMCRIT_EPS ? params.termCrit.epsilon : DEFAULT_EPSILON), DBL_EPSILON);
 
-        int iter = params.trainMethod == ANN_MLP::BACKPROP ?
-            train_backprop( inputs, outputs, sw, termcrit ) :
-            train_rprop( inputs, outputs, sw, termcrit );
-
+        int iter = 0;
+        switch(params.trainMethod){
+        case ANN_MLP::BACKPROP:
+            iter = train_backprop(inputs, outputs, sw, termcrit);
+            break;
+        case ANN_MLP::RPROP:
+            iter = train_rprop(inputs, outputs, sw, termcrit);
+            break;
+        case ANN_MLP::ANNEAL:
+            iter = train_anneal(trainData);
+            break;
+        }
         trained = iter > 0;
         return trained;
     }
+    int train_anneal(const Ptr<TrainData>& trainData)
+    {
+        SimulatedAnnealingANN_MLP t(this, trainData);
+        t.setFinalTemperature(params.finalT);
+        t.setInitialTemperature(params.initialT);
+        t.setCoolingRatio(params.coolingRatio);
+        t.setIterPerStep(params.itePerStep);
+        trained = true; // Enable call to CalcError
+        int iter =  t.run();
+        trained =false;
+        return iter;
+    }
 
     int train_backprop( const Mat& inputs, const Mat& outputs, const Mat& _sw, TermCriteria termCrit )
     {
@@ -849,7 +1091,7 @@ public:
                 E = 0;
 
                 // shuffle indices
-                for( i = 0; i < count; i++ )
+                for( i = 0; i <count; i++ )
                 {
                     j = rng.uniform(0, count);
                     k = rng.uniform(0, count);
@@ -1200,7 +1442,7 @@ public:
             fs << "dw_scale" << params.bpDWScale;
             fs << "moment_scale" << params.bpMomentScale;
         }
-        else if( params.trainMethod == ANN_MLP::RPROP )
+        else if (params.trainMethod == ANN_MLP::RPROP)
         {
             fs << "train_method" << "RPROP";
             fs << "dw0" << params.rpDW0;
@@ -1209,6 +1451,14 @@ public:
             fs << "dw_min" << params.rpDWMin;
             fs << "dw_max" << params.rpDWMax;
         }
+        else if (params.trainMethod == ANN_MLP::ANNEAL)
+        {
+            fs << "train_method" << "ANNEAL";
+            fs << "initialT" << params.initialT;
+            fs << "finalT" << params.finalT;
+            fs << "coolingRatio" << params.coolingRatio;
+            fs << "itePerStep" << params.itePerStep;
+        }
         else
             CV_Error(CV_StsError, "Unknown training method");
 
@@ -1270,7 +1520,7 @@ public:
         f_param1 = (double)fn["f_param1"];
         f_param2 = (double)fn["f_param2"];
 
-        setActivationFunction( activ_func, f_param1, f_param2 );
+        setActivationFunction( activ_func, f_param1, f_param2);
 
         min_val = (double)fn["min_val"];
         max_val = (double)fn["max_val"];
@@ -1290,7 +1540,7 @@ public:
                 params.bpDWScale = (double)tpn["dw_scale"];
                 params.bpMomentScale = (double)tpn["moment_scale"];
             }
-            else if( tmethod_name == "RPROP" )
+            else if (tmethod_name == "RPROP")
             {
                 params.trainMethod = ANN_MLP::RPROP;
                 params.rpDW0 = (double)tpn["dw0"];
@@ -1299,6 +1549,14 @@ public:
                 params.rpDWMin = (double)tpn["dw_min"];
                 params.rpDWMax = (double)tpn["dw_max"];
             }
+            else if (tmethod_name == "ANNEAL")
+            {
+                params.trainMethod = ANN_MLP::ANNEAL;
+                params.initialT = (double)tpn["initialT"];
+                params.finalT = (double)tpn["finalT"];
+                params.coolingRatio = (double)tpn["coolingRatio"];
+                params.itePerStep = tpn["itePerStep"];
+            }
             else
                 CV_Error(CV_StsParseError, "Unknown training method (should be BACKPROP or RPROP)");
 
@@ -1390,6 +1648,8 @@ public:
 };
 
 
+
+
 Ptr<ANN_MLP> ANN_MLP::create()
 {
     return makePtr<ANN_MLPImpl>();
@@ -1401,12 +1661,74 @@ Ptr<ANN_MLP> ANN_MLP::load(const String& filepath)
     fs.open(filepath, FileStorage::READ);
     CV_Assert(fs.isOpened());
     Ptr<ANN_MLP> ann = makePtr<ANN_MLPImpl>();
-
     ((ANN_MLPImpl*)ann.get())->read(fs.getFirstTopLevelNode());
     return ann;
 }
 
+double ANN_MLP_ANNEAL::getAnnealInitialT() const
+{
+    const ANN_MLPImpl* this_ = dynamic_cast<const ANN_MLPImpl*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    return this_->getAnnealInitialT();
+}
+
+void ANN_MLP_ANNEAL::setAnnealInitialT(double val)
+{
+    ANN_MLPImpl* this_ = dynamic_cast< ANN_MLPImpl*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    this_->setAnnealInitialT(val);
+}
+
+double ANN_MLP_ANNEAL::getAnnealFinalT() const
+{
+    const ANN_MLPImpl* this_ = dynamic_cast<const ANN_MLPImpl*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    return this_->getAnnealFinalT();
+}
+
+void ANN_MLP_ANNEAL::setAnnealFinalT(double val)
+{
+    ANN_MLPImpl* this_ = dynamic_cast<ANN_MLPImpl*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    this_->setAnnealFinalT(val);
+}
+
+double ANN_MLP_ANNEAL::getAnnealCoolingRatio() const
+{
+    const ANN_MLPImpl* this_ = dynamic_cast<const ANN_MLPImpl*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    return this_->getAnnealCoolingRatio();
+}
+
+void ANN_MLP_ANNEAL::setAnnealCoolingRatio(double val)
+{
+    ANN_MLPImpl* this_ = dynamic_cast< ANN_MLPImpl*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    this_->setAnnealInitialT(val);
+}
+
+int ANN_MLP_ANNEAL::getAnnealItePerStep() const
+{
+    const ANN_MLPImpl* this_ = dynamic_cast<const ANN_MLPImpl*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    return this_->getAnnealItePerStep();
+}
+
+void ANN_MLP_ANNEAL::setAnnealItePerStep(int val)
+{
+    ANN_MLPImpl* this_ = dynamic_cast<ANN_MLPImpl*>(this);
+    if (!this_)
+        CV_Error(Error::StsNotImplemented, "the class is not ANN_MLP_ANNEAL");
+    this_->setAnnealInitialT(val);
+}
 
-    }}
+}}
 
 /* End of file. */
index f6b9bb7..a193606 100644 (file)
@@ -79,8 +79,10 @@ int str_to_ann_train_method( String& str )
 {
     if( !str.compare("BACKPROP") )
         return ANN_MLP::BACKPROP;
-    if( !str.compare("RPROP") )
+    if (!str.compare("RPROP"))
         return ANN_MLP::RPROP;
+    if (!str.compare("ANNEAL"))
+        return ANN_MLP::ANNEAL;
     CV_Error( CV_StsBadArg, "incorrect ann train method string" );
     return -1;
 }
@@ -241,13 +243,92 @@ TEST(ML_ANN, ActivationFunction)
         Mat rx, ry, dst;
         x->predict(testSamples, rx);
         y->predict(testSamples, ry);
-        absdiff(rx, ry, dst);
-        double minVal, maxVal;
-        minMaxLoc(dst, &minVal, &maxVal);
-        ASSERT_TRUE(maxVal<FLT_EPSILON) << "Predict are not equal for " << dataname + activationName[i] + ".yml and " << activationName[i];
+        double n = cvtest::norm(rx, ry, NORM_INF);
+        EXPECT_LT(n,FLT_EPSILON) << "Predict are not equal for " << dataname + activationName[i] + ".yml and " << activationName[i];
 #endif
     }
 }
+//#define GENERATE_TESTDATA
+TEST(ML_ANN, Method)
+{
+    String folder = string(cvtest::TS::ptr()->get_data_path());
+    String original_path = folder + "waveform.data";
+    String dataname = folder + "waveform";
+
+    Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
+    Mat responses(tdata2->getResponses().rows, 3, CV_32FC1, Scalar(0));
+    for (int i = 0; i<tdata2->getResponses().rows; i++)
+        responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
+    Ptr<TrainData> tdata = TrainData::create(tdata2->getSamples(), ml::ROW_SAMPLE, responses);
+
+    ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
+    RNG& rng = theRNG();
+    rng.state = 0;
+    tdata->setTrainTestSplitRatio(0.8);
+
+    vector<int> methodType;
+    methodType.push_back(ml::ANN_MLP::RPROP);
+    methodType.push_back(ml::ANN_MLP::ANNEAL);
+//    methodType.push_back(ml::ANN_MLP::BACKPROP); -----> NO BACKPROP TEST
+    vector<String> methodName;
+    methodName.push_back("_rprop");
+    methodName.push_back("_anneal");
+//    methodName.push_back("_backprop"); -----> NO BACKPROP TEST
+#ifdef GENERATE_TESTDATA
+    Ptr<ml::ANN_MLP> xx = ml::ANN_MLP_ANNEAL::create();
+    Mat_<int> layerSizesXX(1, 4);
+    layerSizesXX(0, 0) = tdata->getNVars();
+    layerSizesXX(0, 1) = 30;
+    layerSizesXX(0, 2) = 30;
+    layerSizesXX(0, 3) = tdata->getResponses().cols;
+    xx->setLayerSizes(layerSizesXX);
+    xx->setActivationFunction(ml::ANN_MLP::SIGMOID_SYM);
+    xx->setTrainMethod(ml::ANN_MLP::RPROP);
+    xx->setTermCriteria(TermCriteria(TermCriteria::COUNT, 1, 0.01));
+    xx->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE);
+    FileStorage fs;
+    fs.open(dataname + "_init_weight.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
+    xx->write(fs);
+    fs.release();
+#endif
+    for (size_t i = 0; i < methodType.size(); i++)
+    {
+        FileStorage fs;
+        fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ + FileStorage::BASE64);
+        Ptr<ml::ANN_MLP> x = ml::ANN_MLP_ANNEAL::create();
+        x->read(fs.root());
+        x->setTrainMethod(methodType[i]);
+        if (methodType[i] == ml::ANN_MLP::ANNEAL)
+        {
+            x->setAnnealInitialT(12);
+            x->setAnnealFinalT(0.15);
+            x->setAnnealCoolingRatio(0.96);
+            x->setAnnealItePerStep(11);
+        }
+        x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
+        x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
+        ASSERT_TRUE(x->isTrained()) << "Could not train networks with  " << methodName[i];
+#ifdef  GENERATE_TESTDATA
+        x->save(dataname + methodName[i] + ".yml.gz");
+#endif
+        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + methodName[i] + ".yml.gz");
+        ASSERT_TRUE(y != NULL) << "Could not load   " << dataname + methodName[i] + ".yml";
+        Mat testSamples = tdata->getTestSamples();
+        Mat rx, ry, dst;
+        for (int j = 0; j < 4; j++)
+        {
+            rx = x->getWeights(j);
+            ry = y->getWeights(j);
+            double n = cvtest::norm(rx, ry, NORM_INF);
+            EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i] << " layer : " << j;
+        }
+        x->predict(testSamples, rx);
+        y->predict(testSamples, ry);
+        double n = cvtest::norm(rx, ry, NORM_INF);
+        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i];
+    }
+}
+
 
 // 6. dtree
 // 7. boost
diff --git a/samples/cpp/travelsalesman.cpp b/samples/cpp/travelsalesman.cpp
new file mode 100644 (file)
index 0000000..138f61d
--- /dev/null
@@ -0,0 +1,111 @@
+#include <opencv2/opencv.hpp>
+
+using namespace std;
+using namespace cv;
+
+void DrawTravelMap(Mat &img, vector<Point> &p, vector<int> &n);
+
+class TravelSalesman : public ml::SimulatedAnnealingSolver
+{
+private :
+    vector<Point> &posCity;
+    vector<int> &next;
+    RNG rng;
+    int d0,d1,d2,d3;
+
+public:
+
+    TravelSalesman(vector<Point> &p,vector<int> &n):posCity(p),next(n)
+    {
+        rng = theRNG();
+    };
+    /** Give energy value for  a state of system.*/
+    virtual double energy();
+    /** Function which change the state of system (random pertubation).*/
+    virtual void changedState();
+    /** Function to reverse to the previous state.*/
+    virtual void reverseChangedState();
+
+};
+
+void TravelSalesman::changedState()
+{
+    d0 = rng.uniform(0,static_cast<int>(posCity.size()));
+    d1 = next[d0];
+    d2 = next[d1];
+    d3 = next[d2];
+    int d0Tmp = d0;
+    int d1Tmp = d1;
+    int d2Tmp = d2;
+
+    next[d0Tmp] = d2;
+    next[d2Tmp] = d1;
+    next[d1Tmp] = d3;
+}
+
+
+void TravelSalesman::reverseChangedState()
+{
+    next[d0] = d1;
+    next[d1] = d2;
+    next[d2] = d3;
+}
+
+double TravelSalesman::energy()
+{
+    double e=0;
+    for (size_t i = 0; i < next.size(); i++)
+    {
+        e +=  norm(posCity[i]-posCity[next[i]]);
+    }
+    return e;
+}
+
+
+void DrawTravelMap(Mat &img, vector<Point> &p, vector<int> &n)
+{
+    for (size_t i = 0; i < n.size(); i++)
+    {
+        circle(img,p[i],5,Scalar(0,0,255),2);
+        line(img,p[i],p[n[i]],Scalar(0,255,0),2);
+    }
+}
+int main(void)
+{
+    int nbCity=40;
+    Mat img(500,500,CV_8UC3,Scalar::all(0));
+    RNG &rng=theRNG();
+    int radius=static_cast<int>(img.cols*0.45);
+    Point center(img.cols/2,img.rows/2);
+
+    vector<Point> posCity(nbCity);
+    vector<int> next(nbCity);
+    for (size_t i = 0; i < posCity.size(); i++)
+    {
+        double theta = rng.uniform(0., 2 * CV_PI);
+        posCity[i].x = static_cast<int>(radius*cos(theta)) + center.x;
+        posCity[i].y = static_cast<int>(radius*sin(theta)) + center.y;
+        next[i]=(i+1)%nbCity;
+    }
+    TravelSalesman ts(posCity,next);
+    ts.setCoolingRatio(0.99);
+    ts.setInitialTemperature(100);
+    ts.setIterPerStep(10000*nbCity);
+    ts.setFinalTemperature(100*0.97);
+    DrawTravelMap(img,posCity,next);
+    imshow("Map",img);
+    waitKey(10);
+    for (int i = 0; i < 100; i++)
+    {
+        ts.run();
+        img = Mat::zeros(img.size(),CV_8UC3);
+        DrawTravelMap(img, posCity, next);
+        imshow("Map", img);
+        waitKey(10);
+        double ti=ts.getFinalTemperature();
+        cout<<ti <<"  -> "<<ts.energy()<<"\n";
+        ts.setInitialTemperature(ti);
+        ts.setFinalTemperature(ti*0.97);
+    }
+    return 0;
+}