SVM_linear.h
Go to the documentation of this file.00001 #ifndef _LIBLINEAR_H
00002 #define _LIBLINEAR_H
00003
00004 #include "lib/config.h"
00005
00006 #ifdef HAVE_LAPACK
00007 #include "classifier/svm/Tron.h"
00008 #include "features/SparseFeatures.h"
00009
00010 #ifdef __cplusplus
00011 extern "C" {
00012 #endif
00013
00015 struct problem
00016 {
00018 INT l;
00020 INT n;
00022 INT *y;
00024 CSparseFeatures<DREAL>* x;
00026 bool use_bias;
00027 };
00028
/**
 * Solver settings, passed to train(), cross_validation() and validated by
 * check_parameter(); also embedded in struct model as `param`.
 *
 * NOTE(review): semantics below follow LIBLINEAR naming -- confirm against
 * the implementation.
 */
struct parameter
{
	int     solver_type;    /**< selects the solver; valid values are defined elsewhere */

	double  eps;            /**< presumably the stopping tolerance */
	double  C;              /**< penalty / regularisation constant */
	int     nr_weight;      /**< number of entries in weight_label and weight */
	int    *weight_label;   /**< labels whose penalty is rescaled; nr_weight entries (assumed) */
	double *weight;         /**< rescaling factor per entry of weight_label (assumed) */
};
00047
00049 struct model
00050 {
00052 struct parameter param;
00054 int nr_class;
00056 int nr_feature;
00058 double *w;
00060 int *label;
00062 double bias;
00063 };
00064
/*
 * LIBLINEAR-style public API. Only prototypes are visible here; the notes
 * below follow LIBLINEAR naming and should be confirmed against the
 * implementation file.
 */

/*
 * feature_node is used by the predict_* prototypes but is not declared
 * anywhere in this header. Declare it explicitly at file scope so all three
 * prototypes name the same type, rather than relying on the implicit
 * declaration made by an elaborated type specifier inside a parameter list
 * (in C that would even be a distinct, prototype-scoped type per function).
 * NOTE(review): struct problem stores CSparseFeatures, not feature_node --
 * confirm these predict_* functions are actually defined.
 */
struct feature_node;

/// Train on @p prob with settings @p param; the returned model is presumably
/// released with destroy_model().
struct model* train(const struct problem *prob, const struct parameter *param);

/// @p nr_fold-fold cross-validation; @p target is assumed to receive one
/// prediction per training example (prob->l entries) -- confirm.
void cross_validation(const struct problem *prob, const struct parameter *param, int nr_fold, int *target);

/// Predict for @p x, writing decision values to @p dec_values.
int predict_values(const struct model *model_, const struct feature_node *x, double* dec_values);
/// Predict the label for @p x.
int predict(const struct model *model_, const struct feature_node *x);
/// Predict for @p x, writing per-class estimates to @p prob_estimates.
int predict_probability(const struct model *model_, const struct feature_node *x, double* prob_estimates);

/// Write @p model_ to @p model_file_name; the int return is presumably a status code.
int save_model(const char *model_file_name, const struct model *model_);
/// Read a model from @p model_file_name.
struct model *load_model(const char *model_file_name);

/// Accessors for struct model fields.
int get_nr_feature(const struct model *model_);
int get_nr_class(const struct model *model_);
/// Copies the class labels into @p label (caller-provided buffer, assumed >= nr_class entries).
void get_labels(const struct model *model_, int* label);

/// Release a model / the arrays held by a parameter.
void destroy_model(struct model *model_);
void destroy_param(struct parameter *param);
/// Validate @p param against @p prob; presumably returns NULL when valid,
/// otherwise an error message.
const char *check_parameter(const struct problem *prob, const struct parameter *param);
00082
00083 #ifdef __cplusplus
00084 }
00085 #endif
00086
00088 class l2loss_svm_fun : public function
00089 {
00090 public:
00097 l2loss_svm_fun(const problem *prob, double Cp, double Cn);
00098 ~l2loss_svm_fun();
00099
00105 double fun(double *w);
00106
00112 void grad(double *w, double *g);
00113
00119 void Hv(double *s, double *Hs);
00120
00125 int get_nr_variable(void);
00126
00127 private:
00128 void Xv(double *v, double *Xv);
00129 void subXv(double *v, double *Xv);
00130 void subXTv(double *v, double *XTv);
00131
00132 double *C;
00133 double *z;
00134 double *D;
00135 int *I;
00136 int sizeI;
00137 const problem *prob;
00138 };
00139
00141 class l2_lr_fun : public function
00142 {
00143 public:
00150 l2_lr_fun(const problem *prob, double Cp, double Cn);
00151 ~l2_lr_fun();
00152
00158 double fun(double *w);
00159
00165 void grad(double *w, double *g);
00166
00172 void Hv(double *s, double *Hs);
00173
00174 int get_nr_variable(void);
00175
00176 private:
00177 void Xv(double *v, double *Xv);
00178 void XTv(double *v, double *XTv);
00179
00180 double *C;
00181 double *z;
00182 double *D;
00183 const problem *prob;
00184 };
00185 #endif //HAVE_LAPACK
00186 #endif //_LIBLINEAR_H