SVMLin.cpp
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "classifier/svm/SVMLin.h"
00012 #include "features/Labels.h"
00013 #include "lib/Mathematics.h"
00014 #include "classifier/svm/ssl.h"
00015 #include "classifier/SparseLinearClassifier.h"
00016 #include "features/SparseFeatures.h"
00017 #include "features/Labels.h"
00018
00019 CSVMLin::CSVMLin()
00020 : CSparseLinearClassifier(), C1(1), C2(1), epsilon(1e-5), use_bias(true)
00021 {
00022 }
00023
00024 CSVMLin::CSVMLin(DREAL C, CSparseFeatures<DREAL>* traindat, CLabels* trainlab)
00025 : CSparseLinearClassifier(), C1(C), C2(C), epsilon(1e-5), use_bias(true)
00026 {
00027 set_features(traindat);
00028 set_labels(trainlab);
00029 }
00030
00031
00032 CSVMLin::~CSVMLin()
00033 {
00034 }
00035
00036 bool CSVMLin::train()
00037 {
00038 ASSERT(labels);
00039 ASSERT(get_features());
00040
00041 INT num_train_labels=0;
00042 DREAL* train_labels=labels->get_labels(num_train_labels);
00043 INT num_feat=features->get_num_features();
00044 INT num_vec=features->get_num_vectors();
00045
00046 ASSERT(num_vec==num_train_labels);
00047 delete[] w;
00048
00049 struct options Options;
00050 struct data Data;
00051 struct vector_double Weights;
00052 struct vector_double Outputs;
00053
00054 Data.l=num_vec;
00055 Data.m=num_vec;
00056 Data.u=0;
00057 Data.n=num_feat+1;
00058 Data.nz=num_feat+1;
00059 Data.Y=train_labels;
00060 Data.features=get_features();
00061 Data.C = new double[Data.l];
00062
00063 Options.algo = SVM;
00064 Options.lambda=1/(2*get_C1());
00065 Options.lambda_u=1/(2*get_C1());
00066 Options.S=10000;
00067 Options.R=0.5;
00068 Options.epsilon = get_epsilon();
00069 Options.cgitermax=10000;
00070 Options.mfnitermax=50;
00071 Options.Cp = get_C2()/get_C1();
00072 Options.Cn = 1;
00073
00074 if (use_bias)
00075 Options.bias=1.0;
00076 else
00077 Options.bias=0.0;
00078
00079 for(int i=0;i<num_vec;i++)
00080 {
00081 if(train_labels[i]>0)
00082 Data.C[i]=Options.Cp;
00083 else
00084 Data.C[i]=Options.Cn;
00085 }
00086 ssl_train(&Data, &Options, &Weights, &Outputs);
00087 ASSERT(Weights.vec && Weights.d==num_feat+1);
00088
00089 DREAL sgn=train_labels[0];
00090 for (INT i=0; i<num_feat+1; i++)
00091 Weights.vec[i]*=sgn;
00092
00093 CSparseLinearClassifier::set_w(Weights.vec, num_feat);
00094 CSparseLinearClassifier::set_bias(Weights.vec[num_feat]);
00095
00096 delete[] Data.C;
00097 delete[] train_labels;
00098 delete[] Outputs.vec;
00099 return true;
00100 }