SVM.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _SVM_H___
00012 #define _SVM_H___
00013 
00014 #include "lib/common.h"
00015 #include "features/Features.h"
00016 #include "kernel/Kernel.h"
00017 #include "kernel/KernelMachine.h"
00018 
00019 class CKernelMachine;
00020 
00043 class CSVM : public CKernelMachine
00044 {
00045     public:
00049         CSVM(int32_t num_sv=0);
00050 
00058         CSVM(float64_t C, CKernel* k, CLabels* lab);
00059         virtual ~CSVM();
00060 
00063         void set_defaults(int32_t num_sv=0);
00064 
00068         bool load(FILE* svm_file);
00069 
00073         bool save(FILE* svm_file);
00074 
00079         inline void set_nu(float64_t nue) { nu=nue; }
00080 
00089         inline void set_C(float64_t c1, float64_t c2) { C1=c1; C2=c2; }
00090 
00095         inline void set_weight_epsilon(float64_t eps) { weight_epsilon=eps; }
00096 
00101         inline void set_epsilon(float64_t eps) { epsilon=eps; }
00102 
00107         inline void set_tube_epsilon(float64_t eps) { tube_epsilon=eps; }
00108 
00113         inline void set_C_mkl(float64_t C) { C_mkl = C; }
00114 
00119         inline void set_mkl_norm(int32_t norm)
00120         {
00121             if (norm!=1 && norm!=2)
00122                 SG_ERROR("Only 1-and 2-norm supported\n");
00123             mkl_norm = norm;
00124         }
00125 
00130         inline void set_qpsize(int32_t qps) { qpsize=qps; }
00131 
00136         inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00137 
00142         inline bool get_bias_enabled() { return use_bias; }
00143 
00148         inline float64_t get_weight_epsilon() { return weight_epsilon; }
00149 
00154         inline float64_t get_epsilon() { return epsilon; }
00155 
00160         inline float64_t get_nu() { return nu; }
00161 
00166         inline float64_t get_C1() { return C1; }
00167 
00172         inline float64_t get_C2() { return C2; }
00173 
00178         inline int32_t get_qpsize() { return qpsize; }
00179 
00185         inline int32_t get_support_vector(int32_t idx)
00186         {
00187             ASSERT(svm_model.svs && idx<svm_model.num_svs);
00188             return svm_model.svs[idx];
00189         }
00190 
00196         inline float64_t get_alpha(int32_t idx)
00197         {
00198             ASSERT(svm_model.alpha && idx<svm_model.num_svs);
00199             return svm_model.alpha[idx];
00200         }
00201 
00208         inline bool set_support_vector(int32_t idx, int32_t val)
00209         {
00210             if (svm_model.svs && idx<svm_model.num_svs)
00211                 svm_model.svs[idx]=val;
00212             else
00213                 return false;
00214 
00215             return true;
00216         }
00217 
00224         inline bool set_alpha(int32_t idx, float64_t val)
00225         {
00226             if (svm_model.alpha && idx<svm_model.num_svs)
00227                 svm_model.alpha[idx]=val;
00228             else
00229                 return false;
00230 
00231             return true;
00232         }
00233 
00238         inline float64_t get_bias()
00239         {
00240             return svm_model.b;
00241         }
00242 
00247         inline void set_bias(float64_t bias)
00248         {
00249             svm_model.b=bias;
00250         }
00251 
00256         inline int32_t get_num_support_vectors()
00257         {
00258             return svm_model.num_svs;
00259         }
00260 
00266         void set_alphas(float64_t* alphas, int32_t d)
00267         {
00268             ASSERT(alphas);
00269             ASSERT(d==svm_model.num_svs);
00270 
00271             for(int32_t i=0; i<d; i++)
00272                 svm_model.alpha[i]=alphas[i];
00273         }
00274 
00280         void set_support_vectors(int32_t* svs, int32_t d)
00281         {
00282             ASSERT(svs);
00283             ASSERT(d==svm_model.num_svs);
00284 
00285             for(int32_t i=0; i<d; i++)
00286                 svm_model.svs[i]=svs[i];
00287         }
00288 
00294         void get_support_vectors(int32_t** svs, int32_t* num)
00295         {
00296             int32_t nsv = get_num_support_vectors();
00297 
00298             ASSERT(svs && num);
00299             *svs=NULL;
00300             *num=nsv;
00301 
00302             if (nsv>0)
00303             {
00304                 *svs = (int32_t*) malloc(sizeof(int32_t)*nsv);
00305                 for(int32_t i=0; i<nsv; i++)
00306                     (*svs)[i] = get_support_vector(i);
00307             }
00308         }
00309 
00315         void get_alphas(float64_t** alphas, int32_t* d1)
00316         {
00317             int32_t nsv = get_num_support_vectors();
00318 
00319             ASSERT(alphas && d1);
00320             *alphas=NULL;
00321             *d1=nsv;
00322 
00323             if (nsv>0)
00324             {
00325                 *alphas = (float64_t*) malloc(nsv*sizeof(float64_t));
00326                 for(int32_t i=0; i<nsv; i++)
00327                     (*alphas)[i] = get_alpha(i);
00328             }
00329         }
00330 
00335         inline bool create_new_model(int32_t num)
00336         {
00337             delete[] svm_model.alpha;
00338             delete[] svm_model.svs;
00339 
00340             svm_model.b=0;
00341             svm_model.num_svs=num;
00342 
00343             if (num>0)
00344             {
00345                 svm_model.alpha= new float64_t[num];
00346                 svm_model.svs= new int32_t[num];
00347                 return (svm_model.alpha!=NULL && svm_model.svs!=NULL);
00348             }
00349             else
00350             {
00351                 svm_model.alpha= NULL;
00352                 svm_model.svs=NULL;
00353                 return true;
00354             }
00355         }
00356 
00361         inline void set_shrinking_enabled(bool enable)
00362         {
00363             use_shrinking=enable;
00364         }
00365 
00370         inline bool get_shrinking_enabled()
00371         {
00372             return use_shrinking;
00373         }
00374 
00379         inline void set_mkl_enabled(bool enable)
00380         {
00381             use_mkl=enable;
00382         }
00383 
00388         inline bool get_mkl_enabled()
00389         {
00390             return use_mkl;
00391         }
00392 
00397         float64_t compute_objective();
00398 
00403         inline void set_objective(float64_t v)
00404         {
00405             objective=v;
00406         }
00407 
00412         inline float64_t get_objective()
00413         {
00414             return objective;
00415         }
00416 
00421         bool init_kernel_optimization();
00422 
00428         virtual CLabels* classify(CLabels* lab=NULL);
00429 
00435         virtual float64_t classify_example(int32_t num);
00436 
00442         static void* classify_example_helper(void* p);
00443 
00444     protected:
00447         struct TModel
00448         {
00450             float64_t b;
00452             float64_t* alpha;
00454             int32_t* svs;
00456             int32_t num_svs;
00457         };
00458 
00460         TModel svm_model;
00462         bool svm_loaded;
00464         float64_t weight_epsilon;
00466         float64_t epsilon;
00468         float64_t tube_epsilon;
00470         float64_t nu;
00472         float64_t C1;
00474         float64_t C2;
00476         int32_t  mkl_norm;
00478         float64_t C_mkl;
00480         float64_t objective;
00482         int32_t qpsize;
00484         bool use_bias;
00486         bool use_shrinking;
00488         bool use_mkl;
00489 };
00490 #endif

SHOGUN Machine Learning Toolbox - Documentation