#include "lib/io.h"
#include "classifier/svm/GMNPSVM.h"
#include "classifier/svm/gmnplib.h"

#include <limits.h>
#include <string.h>

#define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW))
#define MINUS_INF INT_MIN
#define PLUS_INF INT_MAX
#define KDELTA(A,B) (A==B)
#define KDELTA4(A1,A2,A3,A4) ((A1==A2)||(A1==A3)||(A1==A4)||(A2==A3)||(A2==A4)||(A3==A4))

CGMNPSVM::CGMNPSVM()
: CMultiClassSVM(ONE_VS_REST)
{
}

CGMNPSVM::CGMNPSVM(float64_t C, CKernel* k, CLabels* lab)
: CMultiClassSVM(ONE_VS_REST, C, k, lab)
{
}

CGMNPSVM::~CGMNPSVM()
{
}

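// Train the multi-class SVM: the problem over num_data examples and
// num_classes classes is posed as a single QP over num_data*(num_classes-1)
// "virtual" examples, solved with the GMNP solver, and the resulting dual
// coefficients are then split into one binary CSVM per class.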
bool CGMNPSVM::train()
{
    ASSERT(kernel);
    ASSERT(labels && labels->get_num_labels());

    int32_t num_data = labels->get_num_labels();
    int32_t num_classes = labels->get_num_classes();
    int32_t num_virtual_data = num_data*(num_classes-1);

    SG_INFO("%d trainlabels, %d classes\n", num_data, num_classes);

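    // shift the 0-based training labels to {1, ..., num_classes}, as used by
    // the GMNP solver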
    float64_t* vector_y = new float64_t[num_data];
    for (int32_t i=0; i<num_data; i++)
        vector_y[i] = labels->get_label(i)+1;

    float64_t C = get_C1();
    int32_t tmax = 1000000000;
    float64_t tolabs = 0;
    float64_t tolrel = epsilon;

    float64_t reg_const = 0;
    if (C != 0)
        reg_const = 1/(2*C);

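    // solver buffers: alpha and History are filled in by the solver,
    // vector_c stays all zero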
    float64_t* alpha = new float64_t[num_virtual_data];
    float64_t* vector_c = new float64_t[num_virtual_data];
    memset(vector_c, 0, num_virtual_data*sizeof(float64_t));

    float64_t thlb = 10000000000.0;
    int32_t t = 0;
    float64_t* History = NULL;
    int32_t verb = 0;

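    // set up the GMNP library on the kernel and run the gmnp_imdm solver;
    // on return, alpha holds the dual coefficients of the virtual examples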
    CGMNPLib mnp(vector_y, kernel, num_data, num_virtual_data, num_classes, reg_const);

    mnp.gmnp_imdm(vector_c, num_virtual_data, tmax,
                  tolabs, tolrel, thlb, alpha, &t, &History, verb);

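    // fold the virtual-example solution back into per-class quantities:
    // all_alphas[example*num_classes + class] collects the coefficient of each
    // training example for each class, all_bs the per-class biases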
    float64_t* all_alphas = new float64_t[num_classes*num_data];
    memset(all_alphas, 0, num_classes*num_data*sizeof(float64_t));

    float64_t* all_bs = new float64_t[num_classes];
    memset(all_bs, 0, num_classes*sizeof(float64_t));

    for (int32_t i=0; i<num_classes; i++)
    {
        for (int32_t j=0; j<num_virtual_data; j++)
        {
            int32_t inx1=0;
            int32_t inx2=0;

            mnp.get_indices2(&inx1, &inx2, j);

            all_alphas[(inx1*num_classes)+i] +=
                alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
            all_bs[i] += alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
        }
    }

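    // build one CSVM per class: only examples with a non-zero coefficient for
    // that class become support vectors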
    create_multiclass_svm(num_classes);

    for (int32_t i=0; i<num_classes; i++)
    {
        int32_t num_sv=0;
        for (int32_t j=0; j<num_data; j++)
        {
            if (all_alphas[j*num_classes+i] != 0)
                num_sv++;
        }
        ASSERT(num_sv>0);
        SG_DEBUG("svm[%d] has %d sv, b=%f\n", i, num_sv, all_bs[i]);

        CSVM* svm = new CSVM(num_sv);

        int32_t k=0;
        for (int32_t j=0; j<num_data; j++)
        {
            if (all_alphas[j*num_classes+i] != 0)
            {
                svm->set_alpha(k, all_alphas[j*num_classes+i]);
                svm->set_support_vector(k, j);
                k++;
            }
        }

        svm->set_bias(all_bs[i]);
        set_svm(i, svm);
    }

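    // free the temporary buffers; History was allocated by the solver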
    delete[] vector_c;
    delete[] alpha;
    delete[] all_alphas;
    delete[] all_bs;
    delete[] vector_y;
    delete[] History;

    return true;
}