SVM.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _SVM_H___
00012 #define _SVM_H___
00013 
00014 #include "lib/common.h"
00015 #include "features/Features.h"
00016 #include "kernel/Kernel.h"
00017 #include "kernel/KernelMachine.h"
00018 
00019 class CKernelMachine;
00020 
00043 class CSVM : public CKernelMachine
00044 {
00045     public:
00049         CSVM(INT num_sv=0);
00050 
00058         CSVM(DREAL C, CKernel* k, CLabels* lab);
00059         virtual ~CSVM();
00060 
00063         void set_defaults(INT num_sv=0);
00064 
00068         bool load(FILE* svm_file);
00069 
00073         bool save(FILE* svm_file);
00074 
00079         inline void set_nu(DREAL nue) { nu=nue; }
00080 
00086         inline void set_C(DREAL c1, DREAL c2) { C1=c1; C2=c2; }
00087 
00092         inline void set_weight_epsilon(DREAL eps) { weight_epsilon=eps; }
00093 
00098         inline void set_epsilon(DREAL eps) { epsilon=eps; }
00099 
00104         inline void set_tube_epsilon(DREAL eps) { tube_epsilon=eps; }
00105 
00110         inline void set_C_mkl(DREAL C) { C_mkl = C; }
00111 
00116         inline void set_qpsize(INT qps) { qpsize=qps; }
00117 
00122         inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00123 
00128         inline bool get_bias_enabled() { return use_bias; }
00129 
00134         inline DREAL get_weight_epsilon() { return weight_epsilon; }
00135 
00140         inline DREAL get_epsilon() { return epsilon; }
00141 
00146         inline DREAL get_nu() { return nu; }
00147 
00152         inline DREAL get_C1() { return C1; }
00153 
00158         inline DREAL get_C2() { return C2; }
00159 
00164         inline int get_qpsize() { return qpsize; }
00165 
00171         inline int get_support_vector(INT idx)
00172         {
00173             ASSERT(svm_model.svs && idx<svm_model.num_svs);
00174             return svm_model.svs[idx];
00175         }
00176 
00182         inline DREAL get_alpha(INT idx)
00183         {
00184             ASSERT(svm_model.alpha && idx<svm_model.num_svs);
00185             return svm_model.alpha[idx];
00186         }
00187 
00194         inline bool set_support_vector(INT idx, INT val)
00195         {
00196             if (svm_model.svs && idx<svm_model.num_svs)
00197                 svm_model.svs[idx]=val;
00198             else
00199                 return false;
00200 
00201             return true;
00202         }
00203 
00210         inline bool set_alpha(INT idx, DREAL val)
00211         {
00212             if (svm_model.alpha && idx<svm_model.num_svs)
00213                 svm_model.alpha[idx]=val;
00214             else
00215                 return false;
00216 
00217             return true;
00218         }
00219 
00224         inline DREAL get_bias()
00225         {
00226             return svm_model.b;
00227         }
00228 
00233         inline void set_bias(DREAL bias)
00234         {
00235             svm_model.b=bias;
00236         }
00237 
00242         inline int get_num_support_vectors()
00243         {
00244             return svm_model.num_svs;
00245         }
00246 
00252         void set_alphas(DREAL* alphas, INT d)
00253         {
00254             ASSERT(alphas);
00255             ASSERT(d==svm_model.num_svs);
00256 
00257             for(int i=0; i<d; i++)
00258                 svm_model.alpha[i]=alphas[i];
00259         }
00260 
00266         void set_support_vectors(INT* svs, INT d)
00267         {
00268             ASSERT(svs);
00269             ASSERT(d==svm_model.num_svs);
00270 
00271             for(int i=0; i<d; i++)
00272                 svm_model.svs[i]=svs[i];
00273         }
00274 
00280         void get_support_vectors(INT** svs, INT* num)
00281         {
00282             int nsv = get_num_support_vectors();
00283 
00284             ASSERT(svs && num);
00285             *svs=NULL;
00286             *num=nsv;
00287 
00288             if (nsv>0)
00289             {
00290                 *svs = (INT*) malloc(sizeof(INT)*nsv);
00291                 for(int i=0; i<nsv; i++)
00292                     (*svs)[i] = get_support_vector(i);
00293             }
00294         }
00295 
00301         void get_alphas(DREAL** alphas, INT* d1)
00302         {
00303             int nsv = get_num_support_vectors();
00304 
00305             ASSERT(alphas && d1);
00306             *alphas=NULL;
00307             *d1=nsv;
00308 
00309             if (nsv>0)
00310             {
00311                 *alphas = (DREAL*) malloc(nsv*sizeof(DREAL));
00312                 for(int i=0; i<nsv; i++)
00313                     (*alphas)[i] = get_alpha(i);
00314             }
00315         }
00316 
00321         inline bool create_new_model(INT num)
00322         {
00323             delete[] svm_model.alpha;
00324             delete[] svm_model.svs;
00325 
00326             svm_model.b=0;
00327             svm_model.num_svs=num;
00328 
00329             if (num>0)
00330             {
00331                 svm_model.alpha= new double[num];
00332                 svm_model.svs= new int[num];
00333                 return (svm_model.alpha!=NULL && svm_model.svs!=NULL);
00334             }
00335             else
00336             {
00337                 svm_model.alpha= NULL;
00338                 svm_model.svs=NULL;
00339                 return true;
00340             }
00341         }
00342 
00347         inline void set_shrinking_enabled(bool enable)
00348         {
00349             use_shrinking=enable;
00350         }
00351 
00356         inline bool get_shrinking_enabled()
00357         {
00358             return use_shrinking;
00359         }
00360 
00365         inline void set_mkl_enabled(bool enable)
00366         {
00367             use_mkl=enable;
00368         }
00369 
00374         inline bool get_mkl_enabled()
00375         {
00376             return use_mkl;
00377         }
00378 
00383         DREAL compute_objective();
00384 
00389         inline void set_objective(DREAL v)
00390         {
00391             objective=v;
00392         }
00393 
00398         inline DREAL get_objective()
00399         {
00400             return objective ;
00401         }
00402 
00407         bool init_kernel_optimization();
00408 
00414         virtual CLabels* classify(CLabels* lab=NULL);
00415 
00421         virtual DREAL classify_example(INT num);
00422 
00428         static void* classify_example_helper(void* p);
00429 
00434         void set_precomputed_subkernels_enabled(bool flag)
00435         {
00436             use_precomputed_subkernels=flag;
00437         }
00438 
00439     protected:
00442         struct TModel
00443         {
00445             DREAL b;
00447             DREAL* alpha;
00449             INT* svs;
00451             INT num_svs;
00452         };
00453 
00455         TModel svm_model;
00457         bool svm_loaded;
00459         DREAL weight_epsilon;
00461         DREAL epsilon;
00463         DREAL tube_epsilon;
00465         DREAL nu;
00467         DREAL C1;
00469         DREAL C2;
00471         DREAL C_mkl;
00473         DREAL objective;
00475         int qpsize;
00477         bool use_bias;
00479         bool use_shrinking;
00481         bool use_mkl;
00483         bool use_precomputed_subkernels;
00484 };
00485 #endif

SHOGUN Machine Learning Toolbox - Documentation