SVMLin.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2006-2008 Soeren Sonnenburg
00008  * Copyright (C) 2006-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/svm/SVMLin.h"
00012 #include "features/Labels.h"
00013 #include "lib/Mathematics.h"
00014 #include "classifier/svm/ssl.h"
00015 #include "classifier/SparseLinearClassifier.h"
00016 #include "features/SparseFeatures.h"
00017 #include "features/Labels.h"
00018 
00019 CSVMLin::CSVMLin()
00020 : CSparseLinearClassifier(), C1(1), C2(1), epsilon(1e-5), use_bias(true)
00021 {
00022 }
00023 
00024 CSVMLin::CSVMLin(DREAL C, CSparseFeatures<DREAL>* traindat, CLabels* trainlab)
00025 : CSparseLinearClassifier(), C1(C), C2(C), epsilon(1e-5), use_bias(true)
00026 {
00027     set_features(traindat);
00028     set_labels(trainlab);
00029 }
00030 
00031 
00032 CSVMLin::~CSVMLin()
00033 {
00034 }
00035 
00036 bool CSVMLin::train()
00037 {
00038     ASSERT(labels);
00039     ASSERT(get_features());
00040 
00041     INT num_train_labels=0;
00042     DREAL* train_labels=labels->get_labels(num_train_labels);
00043     INT num_feat=features->get_num_features();
00044     INT num_vec=features->get_num_vectors();
00045 
00046     ASSERT(num_vec==num_train_labels);
00047     delete[] w;
00048 
00049     struct options Options;
00050     struct data Data;
00051     struct vector_double Weights;
00052     struct vector_double Outputs;
00053 
00054     Data.l=num_vec;
00055     Data.m=num_vec;
00056     Data.u=0; 
00057     Data.n=num_feat+1;
00058     Data.nz=num_feat+1;
00059     Data.Y=train_labels;
00060     Data.features=get_features();
00061     Data.C = new double[Data.l];
00062 
00063     Options.algo = SVM;
00064     Options.lambda=1/(2*get_C1());
00065     Options.lambda_u=1/(2*get_C1());
00066     Options.S=10000;
00067     Options.R=0.5;
00068     Options.epsilon = get_epsilon();
00069     Options.cgitermax=10000;
00070     Options.mfnitermax=50;
00071     Options.Cp = get_C2()/get_C1();
00072     Options.Cn = 1;
00073     
00074     if (use_bias)
00075         Options.bias=1.0;
00076     else
00077         Options.bias=0.0;
00078 
00079     for(int i=0;i<num_vec;i++)
00080     {
00081         if(train_labels[i]>0) 
00082             Data.C[i]=Options.Cp;
00083         else 
00084             Data.C[i]=Options.Cn;
00085     }
00086     ssl_train(&Data, &Options, &Weights, &Outputs);
00087     ASSERT(Weights.vec && Weights.d==num_feat+1);
00088 
00089     DREAL sgn=train_labels[0];
00090     for (INT i=0; i<num_feat+1; i++)
00091         Weights.vec[i]*=sgn;
00092 
00093     CSparseLinearClassifier::set_w(Weights.vec, num_feat);
00094     CSparseLinearClassifier::set_bias(Weights.vec[num_feat]);
00095 
00096     delete[] Data.C;
00097     delete[] train_labels;
00098     delete[] Outputs.vec;
00099     return true;
00100 }

SHOGUN Machine Learning Toolbox - Documentation