LibLinear.cpp
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include "lib/config.h"
00011
00012 #ifdef HAVE_LAPACK
00013 #include "lib/io.h"
00014 #include "classifier/svm/LibLinear.h"
00015 #include "classifier/svm/SVM_linear.h"
00016 #include "classifier/svm/Tron.h"
00017 #include "features/SparseFeatures.h"
00018
00019 CLibLinear::CLibLinear(LIBLINEAR_LOSS l)
00020 : CSparseLinearClassifier()
00021 {
00022 loss=l;
00023 use_bias=false;
00024 C1=1;
00025 C2=1;
00026 }
00027
00028 CLibLinear::CLibLinear(
00029 float64_t C, CSparseFeatures<float64_t>* traindat, CLabels* trainlab)
00030 : CSparseLinearClassifier(), C1(C), C2(C), use_bias(true), epsilon(1e-5)
00031 {
00032 set_features(traindat);
00033 set_labels(trainlab);
00034 loss=LR;
00035 }
00036
00037
00038 CLibLinear::~CLibLinear()
00039 {
00040 }
00041
00042 bool CLibLinear::train()
00043 {
00044 ASSERT(labels);
00045 ASSERT(get_features());
00046 ASSERT(labels->is_two_class_labeling());
00047
00048 CSparseFeatures<float64_t>* sfeat=(CSparseFeatures<float64_t>*) features;
00049
00050 int32_t num_train_labels=labels->get_num_labels();
00051 int32_t num_feat=features->get_num_features();
00052 int32_t num_vec=features->get_num_vectors();
00053
00054 ASSERT(num_vec==num_train_labels);
00055 delete[] w;
00056 if (use_bias)
00057 w=new float64_t[num_feat+1];
00058 else
00059 w=new float64_t[num_feat+0];
00060 w_dim=num_feat;
00061
00062 problem prob;
00063 if (use_bias)
00064 {
00065 prob.n=w_dim+1;
00066 memset(w, 0, sizeof(float64_t)*(w_dim+1));
00067 }
00068 else
00069 {
00070 prob.n=w_dim;
00071 memset(w, 0, sizeof(float64_t)*(w_dim+0));
00072 }
00073 prob.l=num_vec;
00074 prob.x=sfeat;
00075 prob.y=new int[prob.l];
00076 prob.use_bias=use_bias;
00077
00078 for (int32_t i=0; i<prob.l; i++)
00079 prob.y[i]=labels->get_int_label(i);
00080
00081 SG_INFO( "%d training points %d dims\n", prob.l, prob.n);
00082
00083 function *fun_obj=NULL;
00084
00085 switch (loss)
00086 {
00087 case LR:
00088 fun_obj=new l2_lr_fun(&prob, get_C1(), get_C2());
00089 break;
00090 case L2:
00091 fun_obj=new l2loss_svm_fun(&prob, get_C1(), get_C2());
00092 break;
00093 default:
00094 SG_ERROR("unknown loss\n");
00095 break;
00096 }
00097
00098 if (fun_obj)
00099 {
00100 CTron tron_obj(fun_obj, epsilon);
00101 tron_obj.tron(w);
00102 float64_t sgn=prob.y[0];
00103
00104 for (int32_t i=0; i<w_dim; i++)
00105 w[i]*=sgn;
00106
00107 if (use_bias)
00108 set_bias(sgn*w[w_dim]);
00109 else
00110 set_bias(0);
00111
00112 delete fun_obj;
00113 }
00114
00115 delete[] prob.y;
00116
00117 return true;
00118 }
00119 #endif //HAVE_LAPACK