GMNPSVM.cpp

Go to the documentation of this file.
00001 /* 
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Vojtech Franc, xfrancv@cmp.felk.cvut.cz
00008  * Copyright (C) 1999-2008 Center for Machine Perception, CTU FEL Prague 
00009  */
00010 
#include <vector>

#include "lib/io.h"
#include "classifier/svm/GMNPSVM.h"
#include "classifier/svm/gmnplib.h"
00014 
/* Linear offset of element (ROW,COL) in a column-major matrix with DIM rows. */
#define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW))
/* Integer sentinels; INT_MIN/INT_MAX come from <climits>. */
#define MINUS_INF INT_MIN
#define PLUS_INF  INT_MAX
/* Kronecker delta: true (1) when A equals B, false (0) otherwise.
 * Every argument is parenthesized so expressions such as KDELTA(a&b, c)
 * compare the whole argument rather than binding to '=='. */
#define KDELTA(A,B) ((A)==(B))
/* True (1) when any two of the four arguments are equal. */
#define KDELTA4(A1,A2,A3,A4) (((A1)==(A2))||((A1)==(A3))||((A1)==(A4))||((A2)==(A3))||((A2)==(A4))||((A3)==(A4)))
00020 
00021 CGMNPSVM::CGMNPSVM()
00022 : CMultiClassSVM(ONE_VS_REST)
00023 {
00024 }
00025 
00026 CGMNPSVM::CGMNPSVM(DREAL C, CKernel* k, CLabels* lab)
00027 : CMultiClassSVM(ONE_VS_REST, C, k, lab)
00028 {
00029 }
00030 
00031 CGMNPSVM::~CGMNPSVM()
00032 {
00033 }
00034 
00035 bool CGMNPSVM::train()
00036 {
00037     ASSERT(kernel);
00038     ASSERT(labels && labels->get_num_labels());
00039 
00040     INT num_data = labels->get_num_labels();
00041     INT num_classes = labels->get_num_classes();
00042     INT num_virtual_data= num_data*(num_classes-1);
00043 
00044     SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);
00045 
00046     DREAL* vector_y = new double[num_data];
00047     for (int i=0; i<num_data; i++)
00048         vector_y[i]= labels->get_label(i)+1;
00049 
00050     DREAL C = get_C1();
00051     INT tmax = 1000000000;
00052     DREAL tolabs = 0;
00053     DREAL tolrel = epsilon;
00054 
00055     DREAL reg_const=0;
00056     if( C!=0 )
00057         reg_const = 1/(2*C);
00058 
00059 
00060     DREAL* alpha = new DREAL[num_virtual_data];
00061     DREAL* vector_c = new DREAL[num_virtual_data];
00062     memset(vector_c, 0, num_virtual_data*sizeof(DREAL));
00063 
00064     DREAL thlb = 10000000000.0;
00065     INT t = 0;
00066     DREAL* History = NULL;
00067     INT verb = 0;
00068 
00069     CGMNPLib mnp(vector_y,kernel,num_data, num_virtual_data, num_classes, reg_const);
00070 
00071     mnp.gmnp_imdm(vector_c, num_virtual_data, tmax,
00072             tolabs, tolrel, thlb, alpha, &t, &History, verb );
00073 
00074     /* matrix alpha [num_classes x num_data] */
00075     DREAL* all_alphas= new DREAL[num_classes*num_data];
00076     memset(all_alphas,0,num_classes*num_data*sizeof(DREAL));
00077 
00078     /* bias vector b [num_classes x 1] */
00079     DREAL* all_bs=new DREAL[num_classes];
00080     memset(all_bs,0,num_classes*sizeof(DREAL));
00081 
00082     /* compute alpha/b from virt_data */
00083     for(INT i=0; i < num_classes; i++ )
00084     {
00085         for(INT j=0; j < num_virtual_data; j++ )
00086         {
00087             INT inx1=0;
00088             INT inx2=0;
00089 
00090             mnp.get_indices2( &inx1, &inx2, j );
00091 
00092             all_alphas[(inx1*num_classes)+i] += 
00093                 alpha[j]*(KDELTA(vector_y[inx1],i+1)+KDELTA(i+1,inx2));
00094             all_bs[i] += alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
00095         }
00096     }
00097 
00098     create_multiclass_svm(num_classes);
00099 
00100     for (INT i=0; i<num_classes; i++)
00101     {
00102         INT num_sv=0;
00103         for (INT j=0; j<num_data; j++)
00104         {
00105             if (all_alphas[j*num_classes+i] != 0)
00106                 num_sv++;
00107         }
00108         ASSERT(num_sv>0);
00109         SG_DEBUG("svm[%d] has %d sv, b=%f\n", i, num_sv, all_bs[i]);
00110 
00111         CSVM* svm=new CSVM(num_sv);
00112 
00113         INT k=0;
00114         for (INT j=0; j<num_data; j++)
00115         {
00116             if (all_alphas[j*num_classes+i] != 0)
00117             {
00118                 if (i==vector_y[j]-1)
00119                     svm->set_alpha(k, all_alphas[j*num_classes+i]);
00120                 else
00121                     svm->set_alpha(k, -all_alphas[j*num_classes+i]);
00122 
00123                 svm->set_support_vector(k, j);
00124                 k++;
00125             }
00126         }
00127 
00128         svm->set_bias(all_bs[i]);
00129         set_svm(i, svm);
00130     }
00131 
00132     delete[] vector_c;
00133     delete[] alpha;
00134     delete[] all_alphas;
00135     delete[] all_bs;
00136     delete[] vector_y;
00137     delete[] History;
00138 
00139     return true;
00140 }

SHOGUN Machine Learning Toolbox - Documentation