00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/io.h"
00012 #include "classifier/svm/GMNPSVM.h"
00013 #include "classifier/svm/gmnplib.h"
00014
/* Linear offset of element (ROW,COL): COL*DIM + ROW. */
#define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW))

/* Integer sentinels mapped onto the limits of int. */
#define MINUS_INF INT_MIN
#define PLUS_INF INT_MAX

/* Kronecker delta: evaluates to 1 when A equals B and to 0 otherwise.
 * Both arguments are parenthesized so that expressions containing operators
 * with lower precedence than == (e.g. a & b) are compared as a whole. */
#define KDELTA(A,B) ((A)==(B))

/* Evaluates to 1 when at least two of the four arguments are equal, else 0.
 * Arguments are parenthesized for the same reason as in KDELTA. */
#define KDELTA4(A1,A2,A3,A4) \
	(((A1)==(A2))||((A1)==(A3))||((A1)==(A4))||((A2)==(A3))||((A2)==(A4))||((A3)==(A4)))
00020
00021 CGMNPSVM::CGMNPSVM()
00022 : CMultiClassSVM(ONE_VS_REST)
00023 {
00024 }
00025
00026 CGMNPSVM::CGMNPSVM(DREAL C, CKernel* k, CLabels* lab)
00027 : CMultiClassSVM(ONE_VS_REST, C, k, lab)
00028 {
00029 }
00030
00031 CGMNPSVM::~CGMNPSVM()
00032 {
00033 }
00034
00035 bool CGMNPSVM::train()
00036 {
00037 ASSERT(kernel);
00038 ASSERT(labels && labels->get_num_labels());
00039
00040 INT num_data = labels->get_num_labels();
00041 INT num_classes = labels->get_num_classes();
00042 INT num_virtual_data= num_data*(num_classes-1);
00043
00044 SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);
00045
00046 DREAL* vector_y = new double[num_data];
00047 for (int i=0; i<num_data; i++)
00048 vector_y[i]= labels->get_label(i)+1;
00049
00050 DREAL C = get_C1();
00051 INT tmax = 1000000000;
00052 DREAL tolabs = 0;
00053 DREAL tolrel = epsilon;
00054
00055 DREAL reg_const=0;
00056 if( C!=0 )
00057 reg_const = 1/(2*C);
00058
00059
00060 DREAL* alpha = new DREAL[num_virtual_data];
00061 DREAL* vector_c = new DREAL[num_virtual_data];
00062 memset(vector_c, 0, num_virtual_data*sizeof(DREAL));
00063
00064 DREAL thlb = 10000000000.0;
00065 INT t = 0;
00066 DREAL* History = NULL;
00067 INT verb = 0;
00068
00069 CGMNPLib mnp(vector_y,kernel,num_data, num_virtual_data, num_classes, reg_const);
00070
00071 mnp.gmnp_imdm(vector_c, num_virtual_data, tmax,
00072 tolabs, tolrel, thlb, alpha, &t, &History, verb );
00073
00074
00075 DREAL* all_alphas= new DREAL[num_classes*num_data];
00076 memset(all_alphas,0,num_classes*num_data*sizeof(DREAL));
00077
00078
00079 DREAL* all_bs=new DREAL[num_classes];
00080 memset(all_bs,0,num_classes*sizeof(DREAL));
00081
00082
00083 for(INT i=0; i < num_classes; i++ )
00084 {
00085 for(INT j=0; j < num_virtual_data; j++ )
00086 {
00087 INT inx1=0;
00088 INT inx2=0;
00089
00090 mnp.get_indices2( &inx1, &inx2, j );
00091
00092 all_alphas[(inx1*num_classes)+i] +=
00093 alpha[j]*(KDELTA(vector_y[inx1],i+1)+KDELTA(i+1,inx2));
00094 all_bs[i] += alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
00095 }
00096 }
00097
00098 create_multiclass_svm(num_classes);
00099
00100 for (INT i=0; i<num_classes; i++)
00101 {
00102 INT num_sv=0;
00103 for (INT j=0; j<num_data; j++)
00104 {
00105 if (all_alphas[j*num_classes+i] != 0)
00106 num_sv++;
00107 }
00108 ASSERT(num_sv>0);
00109 SG_DEBUG("svm[%d] has %d sv, b=%f\n", i, num_sv, all_bs[i]);
00110
00111 CSVM* svm=new CSVM(num_sv);
00112
00113 INT k=0;
00114 for (INT j=0; j<num_data; j++)
00115 {
00116 if (all_alphas[j*num_classes+i] != 0)
00117 {
00118 if (i==vector_y[j]-1)
00119 svm->set_alpha(k, all_alphas[j*num_classes+i]);
00120 else
00121 svm->set_alpha(k, -all_alphas[j*num_classes+i]);
00122
00123 svm->set_support_vector(k, j);
00124 k++;
00125 }
00126 }
00127
00128 svm->set_bias(all_bs[i]);
00129 set_svm(i, svm);
00130 }
00131
00132 delete[] vector_c;
00133 delete[] alpha;
00134 delete[] all_alphas;
00135 delete[] all_bs;
00136 delete[] vector_y;
00137 delete[] History;
00138
00139 return true;
00140 }