SVM_linear.h
Go to the documentation of this file.
00001 #ifndef _LIBLINEAR_H
00002 #define _LIBLINEAR_H
00003
00004 #include "lib/config.h"
00005
00006 #ifdef HAVE_LAPACK
00007 #include "classifier/svm/Tron.h"
00008 #include "features/SparseFeatures.h"
00009
00010 #ifdef __cplusplus
00011 extern "C" {
00012 #endif
00013
/**
 * Training problem handed (by const pointer) to train(), cross_validation(),
 * check_parameter() and to the constructors of the objective-function classes
 * declared in this header.
 *
 * NOTE(review): the per-field meanings below follow the usual liblinear naming
 * and are not established by this header alone; confirm against the
 * implementation file.
 */
00015 struct problem
00016 {
/** presumably the number of training examples -- TODO confirm */
00018 int32_t l;
/** presumably the number of features (dimension of the weight vector) -- TODO confirm */
00020 int32_t n;
/** integer array; presumably one label per training example -- TODO confirm length and value range */
00022 int32_t *y;
/** pointer to the sparse float64 feature object; ownership is not visible from this header */
00024 CSparseFeatures<float64_t>* x;
/** boolean switch; presumably enables an additional bias term -- TODO confirm */
00026 bool use_bias;
00027 };
00028
/**
 * Solver settings. Taken by const pointer in train(), cross_validation() and
 * check_parameter(), released through destroy_param(), and embedded by value
 * in struct model.
 *
 * NOTE(review): field meanings are inferred from liblinear naming; confirm
 * against the implementation file.
 */
00030 struct parameter
00031 {
/** integer code selecting the solver; the valid codes are not defined in this header */
00033 int32_t solver_type;
00034
00035
/** presumably the stopping tolerance of the optimizer -- TODO confirm */
00037 float64_t eps;
/** presumably the regularization / cost constant -- TODO confirm */
00039 float64_t C;
/** presumably the number of entries in weight_label and weight -- TODO confirm */
00041 int32_t nr_weight;
/** integer array; presumably the class labels that receive a custom weight -- TODO confirm */
00043 int32_t *weight_label;
/** float64 array; presumably the per-class weights matching weight_label -- TODO confirm */
00045 float64_t* weight;
00046 };
00047
/**
 * Model object: train() and load_model() return a pointer to one, and the
 * predict*, save_model(), get_* and destroy_model() functions take one.
 *
 * NOTE(review): field meanings are inferred from liblinear naming; confirm
 * against the implementation file.
 */
00049 struct model
00050 {
/** parameter set stored by value inside the model */
00052 struct parameter param;
/** presumably the number of classes (cf. get_nr_class()) -- TODO confirm */
00054 int32_t nr_class;
/** presumably the number of features (cf. get_nr_feature()) -- TODO confirm */
00056 int32_t nr_feature;
/** float64 array; presumably the learned weight vector -- TODO confirm length and layout */
00058 float64_t *w;
/** integer array; presumably the label of each class (cf. get_labels()) -- TODO confirm */
00060 int32_t *label;
/** float64 scalar; presumably the bias value used by the model -- TODO confirm semantics */
00062 float64_t bias;
00063 };
00064
/**
 * Train on a problem using the given parameter set.
 *
 * @param prob  training problem (not modified through this pointer)
 * @param param solver parameters (not modified through this pointer)
 * @return pointer to a model.
 *         NOTE(review): presumably newly allocated and to be released with
 *         destroy_model() -- confirm in the implementation.
 */
00065 struct model* train(const struct problem *prob, const struct parameter *param);
/**
 * Cross-validation entry point.
 *
 * @param prob    problem (not modified through this pointer)
 * @param param   solver parameters (not modified through this pointer)
 * @param nr_fold presumably the number of folds -- TODO confirm valid range
 * @param target  non-const integer array; presumably receives one predicted
 *                label per example -- TODO confirm required length
 */
00066 void cross_validation(
00067 const struct problem *prob, const struct parameter *param, int32_t nr_fold,
00068 int32_t *target);
00069
/**
 * Prediction entry points operating on a model and a feature_node pointer.
 *
 * NOTE(review): `struct feature_node` is not declared anywhere in the visible
 * part of this header (struct problem stores its data as CSparseFeatures
 * instead); confirm that a definition exists, or whether these three
 * prototypes are leftovers without an implementation.
 *
 * predict_values: dec_values is a non-const float64 array, presumably filled
 * with decision values; the int32_t result is presumably the predicted label
 * -- TODO confirm.
 */
00070 int32_t predict_values(
00071 const struct model *model_, const struct feature_node *x,
00072 float64_t* dec_values);
/** presumably returns the predicted label for x -- TODO confirm */
00073 int32_t predict(const struct model *model_, const struct feature_node *x);
/**
 * prob_estimates is a non-const float64 array, presumably filled with
 * per-class probability estimates; the int32_t result is presumably the
 * predicted label -- TODO confirm.
 */
00074 int32_t predict_probability(
00075 const struct model *model_, const struct feature_node *x,
00076 float64_t* prob_estimates);
00077
/**
 * Write a model to the named file.
 *
 * @param model_file_name path of the file
 * @param model_          model (not modified through this pointer)
 * @return int32_t status; NOTE(review): presumably 0 on success and non-zero
 *         on failure -- confirm in the implementation.
 */
00078 int32_t save_model(const char *model_file_name, const struct model *model_);
/**
 * Read a model from the named file.
 *
 * @param model_file_name path of the file
 * @return pointer to a model; NOTE(review): presumably NULL on failure and to
 *         be released with destroy_model() -- confirm in the implementation.
 */
00079 struct model *load_model(const char *model_file_name);
00080
/** Accessor on a model; presumably returns model_->nr_feature -- TODO confirm */
00081 int32_t get_nr_feature(const struct model *model_);
/** Accessor on a model; presumably returns model_->nr_class -- TODO confirm */
00082 int32_t get_nr_class(const struct model *model_);
/**
 * label is a non-const integer array; presumably receives the model's class
 * labels -- TODO confirm the required length (likely nr_class entries).
 */
00083 void get_labels(const struct model *model_, int32_t* label);
00084
/**
 * Release a model.
 * NOTE(review): which members are freed, and whether the struct itself is
 * freed, is not visible from this header -- confirm before pairing with
 * custom allocation.
 */
00085 void destroy_model(struct model *model_);
/**
 * Release resources held by a parameter set.
 * NOTE(review): presumably frees weight_label and weight -- TODO confirm.
 */
00086 void destroy_param(struct parameter *param);
/**
 * Validate a problem/parameter pair.
 *
 * @return C string; NOTE(review): presumably NULL when the parameters are
 *         acceptable and an error message otherwise (liblinear convention)
 *         -- confirm in the implementation.
 */
00087 const char *check_parameter(
00088 const struct problem *prob, const struct parameter *param);
00089
00090 #ifdef __cplusplus
00091 }
00092 #endif
00093
/**
 * Objective-function class publicly derived from `function`.
 *
 * `function` is declared outside this header (presumably in
 * classifier/svm/Tron.h -- TODO confirm). The class sits outside the
 * extern "C" block and refers to `problem` without the `struct` keyword, so
 * this part of the header is C++ only.
 *
 * NOTE(review): by its name this is presumably the L2-loss SVM objective; the
 * member descriptions below are inferred from names -- confirm against the
 * implementation file.
 * NOTE(review): the destructor is not marked virtual here; whether it is
 * virtual depends on the declaration of `function`.
 */
00095 class l2loss_svm_fun : public function
00096 {
00097 public:
/**
 * @param prob problem to operate on (const pointer; the member `prob` has the same type)
 * @param Cp   float64 scalar; presumably the cost for positive examples -- TODO confirm
 * @param Cn   float64 scalar; presumably the cost for negative examples -- TODO confirm
 */
00104 l2loss_svm_fun(const problem *prob, float64_t Cp, float64_t Cn);
00105 ~l2loss_svm_fun();
00106
/** presumably evaluates the objective at weight vector w and returns its value */
00112 float64_t fun(float64_t *w);
00113
/** presumably writes the gradient at w into g -- TODO confirm required length of g */
00119 void grad(float64_t *w, float64_t *g);
00120
/** presumably writes the Hessian-vector product for s into Hs -- TODO confirm */
00126 void Hv(float64_t *s, float64_t *Hs);
00127
/** presumably returns the number of optimization variables -- TODO confirm */
00132 int32_t get_nr_variable(void);
00133
00134 private:
/* Internal helpers taking an input vector v and an output array; presumably
 * products with the data matrix, a row subset of it, and the transposed
 * subset -- TODO confirm. */
00135 void Xv(float64_t *v, float64_t *Xv);
00136 void subXv(float64_t *v, float64_t *Xv);
00137 void subXTv(float64_t *v, float64_t *XTv);
00138
/* Working arrays and state; their sizes, contents and ownership are set up in
 * the implementation file and are not visible here. I/sizeI are presumably an
 * index set and its length -- TODO confirm. */
00139 float64_t *C;
00140 float64_t *z;
00141 float64_t *D;
00142 int32_t *I;
00143 int32_t sizeI;
/* problem pointer; presumably the one passed to the constructor (not owned) -- TODO confirm */
00144 const problem *prob;
00145 };
00146
/**
 * Objective-function class publicly derived from `function`.
 *
 * `function` is declared outside this header (presumably in
 * classifier/svm/Tron.h -- TODO confirm). The class sits outside the
 * extern "C" block and refers to `problem` without the `struct` keyword, so
 * this part of the header is C++ only.
 *
 * NOTE(review): by its name this is presumably the L2-regularized logistic
 * regression objective; the member descriptions below are inferred from names
 * -- confirm against the implementation file.
 * NOTE(review): the destructor is not marked virtual here; whether it is
 * virtual depends on the declaration of `function`.
 */
00148 class l2_lr_fun : public function
00149 {
00150 public:
/**
 * @param prob problem to operate on (const pointer; the member `prob` has the same type)
 * @param Cp   float64 scalar; presumably the cost for positive examples -- TODO confirm
 * @param Cn   float64 scalar; presumably the cost for negative examples -- TODO confirm
 */
00157 l2_lr_fun(const problem *prob, float64_t Cp, float64_t Cn);
00158 ~l2_lr_fun();
00159
/** presumably evaluates the objective at weight vector w and returns its value */
00165 float64_t fun(float64_t *w);
00166
/** presumably writes the gradient at w into g -- TODO confirm required length of g */
00172 void grad(float64_t *w, float64_t *g);
00173
/** presumably writes the Hessian-vector product for s into Hs -- TODO confirm */
00179 void Hv(float64_t *s, float64_t *Hs);
00180
/** presumably returns the number of optimization variables -- TODO confirm */
00181 int32_t get_nr_variable(void);
00182
00183 private:
/* Internal helpers taking an input vector v and an output array; presumably
 * products with the data matrix and with its transpose -- TODO confirm. */
00184 void Xv(float64_t *v, float64_t *Xv);
00185 void XTv(float64_t *v, float64_t *XTv);
00186
/* Working arrays; their sizes, contents and ownership are set up in the
 * implementation file and are not visible here. */
00187 float64_t *C;
00188 float64_t *z;
00189 float64_t *D;
/* problem pointer; presumably the one passed to the constructor (not owned) -- TODO confirm */
00190 const problem *prob;
00191 };
00192 #endif //HAVE_LAPACK
00193 #endif //_LIBLINEAR_H