int regularized;
int train_method;
int mini_batch_size;
- CvTermCriteria term_crit;
+ cv::TermCriteria term_crit;
LogisticRegressionParams();
- LogisticRegressionParams(double alpha, int num_iters, int norm, int regularized, int train_method, int minbatchsize);
+ LogisticRegressionParams(double learning_rate, int iters, int train_method, int normlization, int reg, int mini_batch_size);
+
};
class CV_EXPORTS LogisticRegression
{
public:
-
- LogisticRegression();
+ LogisticRegression( const LogisticRegressionParams& params);
LogisticRegression(cv::InputArray data_ip, cv::InputArray labels_ip, const LogisticRegressionParams& params);
virtual ~LogisticRegression();
virtual bool train(cv::InputArray data_ip, cv::InputArray label_ip);
virtual void predict( cv::InputArray data, cv::OutputArray predicted_labels ) const;
- virtual void save(std::string filepath) const;
- virtual void load(const std::string filepath);
+ virtual void write(FileStorage& fs) const;
+ virtual void read(const FileNode& fn);
- cv::Mat get_learnt_thetas() const;
+ const cv::Mat get_learnt_thetas() const;
+ virtual void clear();
protected:
virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
virtual bool set_label_map(const cv::Mat& labels);
static cv::Mat remap_labels(const cv::Mat& labels, const std::map<int, int>& lmap);
-
- virtual void write(FileStorage& fs) const;
- virtual void read(const FileNode& fn);
- virtual void clear();
-
};
}// namespace cv