SVM_linear.h

Go to the documentation of this file.
00001 #ifndef _LIBLINEAR_H
00002 #define _LIBLINEAR_H
00003 
00004 #include "lib/config.h"
00005 
00006 #ifdef HAVE_LAPACK
00007 #include "classifier/svm/Tron.h"
00008 #include "features/SparseFeatures.h"
00009 
00010 #ifdef __cplusplus
00011 extern "C" {
00012 #endif
00013 
00014 /// Training problem passed (by const pointer) to train(), cross_validation() and check_parameter().
00015 struct problem
00016 {
00017     /// presumably the number of training examples (liblinear convention) -- TODO confirm
00018     INT l;
00019     /// presumably the number of features / dimensionality -- TODO confirm
00020     INT n;
00021     /// label array; presumably one entry per example (l entries) -- TODO confirm
00022     INT *y;
00023     /// sparse feature matrix with DREAL values; ownership not defined in this header -- TODO confirm
00024     CSparseFeatures<DREAL>* x;
00025     /// flag presumably enabling a bias term (cf. model::bias, which is a double) -- TODO confirm
00026     bool use_bias;
00027 };
00028 
00029 /// Solver/training settings passed to train(), cross_validation() and check_parameter(); also held by value inside struct model.
00030 struct parameter
00031 {
00032     /// selects the solver; the valid values are not defined in this header -- TODO confirm
00033     int solver_type;
00034 
00035     /* these are for training only */
00036     /// presumably the stopping tolerance of the solver -- TODO confirm
00037     double eps;
00038     /// presumably the regularization cost (cf. the Cp/Cn constructor arguments of the loss-function classes) -- TODO confirm
00039     double C;
00040     /// presumably the number of entries in weight_label and weight -- TODO confirm
00041     int nr_weight;
00042     /// presumably the class labels whose cost is reweighted (nr_weight entries); presumably released by destroy_param() -- confirm ownership
00043     int *weight_label;
00044     /// presumably the per-class weights matching weight_label (nr_weight entries) -- TODO confirm
00045     double* weight;
00046 };
00047 
00048 /// Trained model: returned by train() and load_model(), written by save_model(), and handed to destroy_model() (presumably to free it -- TODO confirm).
00049 struct model
00050 {
00051     /// copy (held by value) of a parameter set; presumably the one used for training -- TODO confirm
00052     struct parameter param;
00053     /// number of classes; presumably what get_nr_class() returns -- TODO confirm
00054     int nr_class;
00055     /// number of features; presumably what get_nr_feature() returns -- TODO confirm
00056     int nr_feature;
00057     /// weight vector; its length and layout (per class, bias slot) are not defined in this header -- TODO confirm
00058     double *w;
00059     /// class labels; presumably nr_class entries, copied out by get_labels() -- TODO confirm
00060     int *label;
00061     /// bias value; note struct problem only carries a bool use_bias -- semantics TODO confirm
00062     double bias;
00063 };
00064 // Training: train() returns a struct model* built from prob and param; cross_validation() takes a fold count nr_fold and writes into the caller's target buffer (presumably one predicted label per example, prob->l entries -- TODO confirm).
00065 struct model* train(const struct problem *prob, const struct parameter *param);
00066 void cross_validation(const struct problem *prob, const struct parameter *param, int nr_fold, int *target);
00067 // Prediction on a single example. NOTE(review): struct feature_node is never defined in this header (struct problem stores CSparseFeatures instead) -- confirm these liblinear-style prototypes are still defined and used.
00068 int predict_values(const struct model *model_, const struct feature_node *x, double* dec_values);
00069 int predict(const struct model *model_, const struct feature_node *x);
00070 int predict_probability(const struct model *model_, const struct feature_node *x, double* prob_estimates);
00071 // Model persistence: save_model() returns an int status whose convention is not defined here; load_model() returns a struct model* (presumably NULL on failure -- TODO confirm).
00072 int save_model(const char *model_file_name, const struct model *model_);
00073 struct model *load_model(const char *model_file_name);
00074 // Accessors; get_labels() writes into a caller-provided buffer (presumably get_nr_class() ints -- TODO confirm).
00075 int get_nr_feature(const struct model *model_);
00076 int get_nr_class(const struct model *model_);
00077 void get_labels(const struct model *model_, int* label);
00078 // Cleanup and validation: destroy_model()/destroy_param() presumably release memory owned by their argument; check_parameter() returns a const char* (presumably NULL when valid, an error message otherwise -- TODO confirm).
00079 void destroy_model(struct model *model_);
00080 void destroy_param(struct parameter *param);
00081 const char *check_parameter(const struct problem *prob, const struct parameter *param);
00082 
00083 #ifdef __cplusplus
00084 }
00085 #endif
00086 
00087 /// Objective for an L2-loss SVM (going by the name); presumably consumed by the TRON solver through the `function` base class from classifier/svm/Tron.h -- TODO confirm
00088 class l2loss_svm_fun : public function
00089 {
00090 public:
00091     /** constructor
00092      * @param prob training problem; presumably retained in member prob
00093      *        (a pointer), so it must outlive this object -- TODO confirm
00094      * @param Cp presumably the cost for positive examples -- TODO confirm
00095      * @param Cn presumably the cost for negative examples -- TODO confirm
00096      */
00097     l2loss_svm_fun(const problem *prob, double Cp, double Cn);
00098     ~l2loss_svm_fun();
00099     
00100     /** presumably evaluates the objective
00101      *
00102      * @param w weight vector (presumably get_nr_variable() entries)
00103      * @return presumably the objective value at w -- TODO confirm
00104      */
00105     double fun(double *w);
00106     
00107     /** presumably computes the gradient
00108      *
00109      * @param w weight vector
00110      * @param g output buffer; presumably receives the gradient at w
00111      */
00112     void grad(double *w, double *g);
00113 
00114     /** presumably a Hessian-vector product (going by the name Hv)
00115      *
00116      * @param s input vector
00117      * @param Hs output buffer; presumably receives H*s -- TODO confirm
00118      */
00119     void Hv(double *s, double *Hs);
00120 
00121     /** presumably the number of optimization variables
00122      *
00123      * @return presumably the length of the w/g/s/Hs vectors -- TODO confirm
00124      */
00125     int get_nr_variable(void);
00126 
00127 private:
00128     void Xv(double *v, double *Xv); // presumably writes X*v into Xv
00129     void subXv(double *v, double *Xv); // presumably X*v restricted to the rows listed in I
00130     void subXTv(double *v, double *XTv); // presumably X^T*v restricted to the rows listed in I
00131     // NOTE(review): raw pointer members plus a user-declared destructor but no copy constructor/assignment; if ~l2loss_svm_fun() frees these buffers, copying an instance would double-free -- consider declaring them private (non-copyable).
00132     double *C;
00133     double *z;
00134     double *D;
00135     int *I; // presumably an index set holding sizeI entries -- TODO confirm
00136     int sizeI;
00137     const problem *prob;
00138 };
00139 
00140 /// Objective for L2-regularized logistic regression (going by the name); presumably consumed by the TRON solver through the `function` base class from classifier/svm/Tron.h -- TODO confirm
00141 class l2_lr_fun : public function
00142 {
00143 public:
00144     /** constructor
00145      * @param prob training problem; presumably retained in member prob
00146      *        (a pointer), so it must outlive this object -- TODO confirm
00147      * @param Cp presumably the cost for positive examples -- TODO confirm
00148      * @param Cn presumably the cost for negative examples -- TODO confirm
00149      */
00150     l2_lr_fun(const problem *prob, double Cp, double Cn);
00151     ~l2_lr_fun();
00152 
00153     /** presumably evaluates the objective
00154      *
00155      * @param w weight vector
00156      * @return presumably the objective value at w -- TODO confirm
00157      */
00158     double fun(double *w);
00159     
00160     /** presumably computes the gradient
00161      *
00162      * @param w weight vector
00163      * @param g output buffer; presumably receives the gradient at w
00164      */
00165     void grad(double *w, double *g);
00166 
00167     /** presumably a Hessian-vector product (going by the name Hv)
00168      *
00169      * @param s input vector
00170      * @param Hs output buffer; presumably receives H*s -- TODO confirm
00171      */
00172     void Hv(double *s, double *Hs);
00173 
00174     int get_nr_variable(void); // presumably the number of optimization variables (length of w) -- TODO confirm
00175 
00176 private:
00177     void Xv(double *v, double *Xv); // presumably writes X*v into Xv
00178     void XTv(double *v, double *XTv); // presumably writes X^T*v into XTv
00179     // NOTE(review): raw pointer members plus a user-declared destructor but no copy constructor/assignment; if ~l2_lr_fun() frees these buffers, copying an instance would double-free -- consider declaring them private (non-copyable).
00180     double *C;
00181     double *z;
00182     double *D;
00183     const problem *prob;
00184 };
00185 #endif //HAVE_LAPACK
00186 #endif //_LIBLINEAR_H

SHOGUN Machine Learning Toolbox - Documentation