LibLinear.cpp
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include "lib/config.h"
00011
00012 #ifdef HAVE_LAPACK
00013 #include "lib/io.h"
00014 #include "classifier/svm/LibLinear.h"
00015 #include "classifier/svm/SVM_linear.h"
00016 #include "classifier/svm/Tron.h"
00017 #include "features/SparseFeatures.h"
00018
00019 CLibLinear::CLibLinear(LIBLINEAR_LOSS l)
00020 : CSparseLinearClassifier()
00021 {
00022 loss=l;
00023 use_bias=false;
00024 C1=1;
00025 C2=1;
00026 }
00027
00028 CLibLinear::CLibLinear(DREAL C, CSparseFeatures<DREAL>* traindat, CLabels* trainlab)
00029 : CSparseLinearClassifier(), C1(C), C2(C), use_bias(true), epsilon(1e-5)
00030 {
00031 set_features(traindat);
00032 set_labels(trainlab);
00033 loss=LR;
00034 }
00035
00036
00037 CLibLinear::~CLibLinear()
00038 {
00039 }
00040
00041 bool CLibLinear::train()
00042 {
00043 ASSERT(labels);
00044 ASSERT(get_features());
00045 ASSERT(labels->is_two_class_labeling());
00046
00047 CSparseFeatures<DREAL>* sfeat=(CSparseFeatures<DREAL>*) features;
00048
00049 INT num_train_labels=labels->get_num_labels();
00050 INT num_feat=features->get_num_features();
00051 INT num_vec=features->get_num_vectors();
00052
00053 ASSERT(num_vec==num_train_labels);
00054 delete[] w;
00055 if (use_bias)
00056 w=new DREAL[num_feat+1];
00057 else
00058 w=new DREAL[num_feat+0];
00059 w_dim=num_feat;
00060
00061 problem prob;
00062 if (use_bias)
00063 {
00064 prob.n=w_dim+1;
00065 memset(w, 0, sizeof(DREAL)*(w_dim+1));
00066 }
00067 else
00068 {
00069 prob.n=w_dim;
00070 memset(w, 0, sizeof(DREAL)*(w_dim+0));
00071 }
00072 prob.l=num_vec;
00073 prob.x=sfeat;
00074 prob.y=new int[prob.l];
00075 prob.use_bias=use_bias;
00076
00077 for (int i=0; i<prob.l; i++)
00078 prob.y[i]=labels->get_int_label(i);
00079
00080 SG_INFO( "%d training points %d dims\n", prob.l, prob.n);
00081
00082 function *fun_obj=NULL;
00083
00084 switch (loss)
00085 {
00086 case LR:
00087 fun_obj=new l2_lr_fun(&prob, get_C1(), get_C2());
00088 break;
00089 case L2:
00090 fun_obj=new l2loss_svm_fun(&prob, get_C1(), get_C2());
00091 break;
00092 default:
00093 SG_ERROR("unknown loss\n");
00094 break;
00095 }
00096
00097 if (fun_obj)
00098 {
00099 CTron tron_obj(fun_obj, epsilon);
00100 tron_obj.tron(w);
00101 DREAL sgn=prob.y[0];
00102
00103 for (INT i=0; i<w_dim; i++)
00104 w[i]*=sgn;
00105
00106 if (use_bias)
00107 set_bias(sgn*w[w_dim]);
00108 else
00109 set_bias(0);
00110
00111 delete fun_obj;
00112 }
00113
00114 delete[] prob.y;
00115
00116 return true;
00117 }
00118 #endif //HAVE_LAPACK