00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014
00015 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00016 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00017 {
00018 }
00019
00020 CMultiClassSVM::CMultiClassSVM(
00021 EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00023 {
00024 }
00025
00026 CMultiClassSVM::~CMultiClassSVM()
00027 {
00028 cleanup();
00029 }
00030
00031 void CMultiClassSVM::cleanup()
00032 {
00033 for (int32_t i=0; i<m_num_svms; i++)
00034 delete m_svms[i];
00035 delete[] m_svms;
00036
00037 m_num_svms=0;
00038 m_svms=NULL;
00039 }
00040
00041 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00042 {
00043 if (num_classes>0)
00044 {
00045 m_num_classes=num_classes;
00046
00047 if (multiclass_type==ONE_VS_REST)
00048 m_num_svms=num_classes;
00049 else if (multiclass_type==ONE_VS_ONE)
00050 m_num_svms=num_classes*(num_classes-1)/2;
00051 else
00052 SG_ERROR("unknown multiclass type\n");
00053
00054 m_svms=new CSVM*[m_num_svms];
00055 if (m_svms)
00056 {
00057 memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00058 return true;
00059 }
00060 }
00061 return false;
00062 }
00063
00064 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00065 {
00066 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00067 {
00068 m_svms[num]=svm;
00069 return true;
00070 }
00071 return false;
00072 }
00073
00074 CLabels* CMultiClassSVM::classify(CLabels* result)
00075 {
00076 if (multiclass_type==ONE_VS_REST)
00077 return classify_one_vs_rest(result);
00078 else if (multiclass_type==ONE_VS_ONE)
00079 return classify_one_vs_one(result);
00080 else
00081 SG_ERROR("unknown multiclass type\n");
00082
00083 return NULL;
00084 }
00085
00086 CLabels* CMultiClassSVM::classify_one_vs_one(CLabels* result)
00087 {
00088 ASSERT(m_num_svms>0);
00089 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00090
00091 if (!kernel)
00092 {
00093 SG_ERROR( "SVM can not proceed without kernel!\n");
00094 return false ;
00095 }
00096
00097 if ( kernel && kernel->has_features() && kernel->get_num_vec_rhs())
00098 {
00099 int32_t num_vectors=kernel->get_num_vec_rhs();
00100
00101 if (!result)
00102 result=new CLabels(num_vectors);
00103
00104 ASSERT(num_vectors==result->get_num_labels());
00105 CLabels** outputs=new CLabels*[m_num_svms];
00106
00107 for (int32_t i=0; i<m_num_svms; i++)
00108 {
00109 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00110 ASSERT(m_svms[i]);
00111 m_svms[i]->set_kernel(kernel);
00112 m_svms[i]->set_labels(labels);
00113 outputs[i]=m_svms[i]->classify();
00114 }
00115
00116 int32_t* votes=new int32_t[m_num_classes];
00117 for (int32_t v=0; v<num_vectors; v++)
00118 {
00119 int32_t s=0;
00120 memset(votes, 0, sizeof(int32_t)*m_num_classes);
00121
00122 for (int32_t i=0; i<m_num_classes; i++)
00123 {
00124 for (int32_t j=i+1; j<m_num_classes; j++)
00125 {
00126 if (outputs[s++]->get_label(v)>0)
00127 votes[i]++;
00128 else
00129 votes[j]++;
00130 }
00131 }
00132
00133 int32_t winner=0;
00134 int32_t max_votes=votes[0];
00135
00136 for (int32_t i=1; i<m_num_classes; i++)
00137 {
00138 if (votes[i]>max_votes)
00139 {
00140 max_votes=votes[i];
00141 winner=i;
00142 }
00143 }
00144
00145 result->set_label(v, winner);
00146 }
00147
00148 delete[] votes;
00149
00150 for (int32_t i=0; i<m_num_svms; i++)
00151 delete outputs[i];
00152 delete[] outputs;
00153 }
00154
00155 return result;
00156 }
00157
00158 CLabels* CMultiClassSVM::classify_one_vs_rest(CLabels* result)
00159 {
00160 ASSERT(m_num_svms>0);
00161
00162 if (!kernel)
00163 {
00164 SG_ERROR( "SVM can not proceed without kernel!\n");
00165 return false ;
00166 }
00167
00168 if ( kernel && kernel->has_features() && kernel->get_num_vec_rhs())
00169 {
00170 int32_t num_vectors=kernel->get_num_vec_rhs();
00171
00172 if (!result)
00173 result=new CLabels(num_vectors);
00174
00175 ASSERT(num_vectors==result->get_num_labels());
00176 CLabels** outputs=new CLabels*[m_num_svms];
00177
00178 for (int32_t i=0; i<m_num_svms; i++)
00179 {
00180 ASSERT(m_svms[i]);
00181 m_svms[i]->set_kernel(kernel);
00182 m_svms[i]->set_labels(get_labels());
00183 outputs[i]=m_svms[i]->classify();
00184 }
00185
00186 for (int32_t i=0; i<num_vectors; i++)
00187 {
00188 int32_t winner=0;
00189 float64_t max_out=outputs[0]->get_label(i);
00190
00191 for (int32_t j=1; j<m_num_svms; j++)
00192 {
00193 float64_t out=outputs[j]->get_label(i);
00194
00195 if (out>max_out)
00196 {
00197 winner=j;
00198 max_out=out;
00199 }
00200 }
00201
00202 result->set_label(i, winner);
00203 }
00204
00205 for (int32_t i=0; i<m_num_svms; i++)
00206 delete outputs[i];
00207 delete[] outputs;
00208 }
00209
00210 return result;
00211 }
00212
00213 float64_t CMultiClassSVM::classify_example(int32_t num)
00214 {
00215 if (multiclass_type==ONE_VS_REST)
00216 return classify_example_one_vs_rest(num);
00217 else if (multiclass_type==ONE_VS_ONE)
00218 return classify_example_one_vs_one(num);
00219 else
00220 SG_ERROR("unknown multiclass type\n");
00221
00222 return 0;
00223 }
00224
00225 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00226 {
00227 ASSERT(m_num_svms>0);
00228 float64_t* outputs=new float64_t[m_num_svms];
00229 int32_t winner=0;
00230 float64_t max_out=m_svms[0]->classify_example(num);
00231
00232 for (int32_t i=1; i<m_num_svms; i++)
00233 {
00234 outputs[i]=m_svms[i]->classify_example(num);
00235 if (outputs[i]>max_out)
00236 {
00237 winner=i;
00238 max_out=outputs[i];
00239 }
00240 }
00241 delete[] outputs;
00242
00243 return winner;
00244 }
00245
00246 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00247 {
00248 ASSERT(m_num_svms>0);
00249 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00250
00251 int32_t* votes=new int32_t[m_num_classes];
00252 int32_t s=0;
00253
00254 for (int32_t i=0; i<m_num_classes; i++)
00255 {
00256 for (int32_t j=i+1; j<m_num_classes; j++)
00257 {
00258 if (m_svms[s++]->classify_example(num)>0)
00259 votes[i]++;
00260 else
00261 votes[j]++;
00262 }
00263 }
00264
00265 int32_t winner=0;
00266 int32_t max_votes=votes[0];
00267
00268 for (int32_t i=1; i<m_num_classes; i++)
00269 {
00270 if (votes[i]>max_votes)
00271 {
00272 max_votes=votes[i];
00273 winner=i;
00274 }
00275 }
00276
00277 delete[] votes;
00278
00279 return winner;
00280 }
00281
00282 bool CMultiClassSVM::load(FILE* modelfl)
00283 {
00284 bool result=true;
00285 char char_buffer[1024];
00286 int32_t int_buffer;
00287 float64_t double_buffer;
00288 int32_t line_number=1;
00289 int32_t svm_idx=-1;
00290
00291 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00292 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00293 else
00294 {
00295 char_buffer[15]='\0';
00296 if (strcmp("%MultiClassSVM", char_buffer)!=0)
00297 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00298
00299 line_number++;
00300 }
00301
00302 int_buffer=0;
00303 if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00304 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00305
00306 if (!feof(modelfl))
00307 line_number++;
00308
00309 if (int_buffer != multiclass_type)
00310 SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00311
00312 int_buffer=0;
00313 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00314 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00315
00316 if (!feof(modelfl))
00317 line_number++;
00318
00319 if (int_buffer < 2)
00320 SG_ERROR("less than 2 classes - how is this multiclass?\n");
00321
00322 create_multiclass_svm(int_buffer);
00323
00324 int_buffer=0;
00325 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00326 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00327
00328 if (!feof(modelfl))
00329 line_number++;
00330
00331 if (m_num_svms != int_buffer)
00332 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00333
00334 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00335 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00336
00337 if (!feof(modelfl))
00338 line_number++;
00339
00340 for (int32_t n=0; n<m_num_svms; n++)
00341 {
00342 svm_idx=-1;
00343 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00344 {
00345 result=false;
00346 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00347 }
00348 else
00349 {
00350 char_buffer[4]='\0';
00351 if (strncmp("%SVM", char_buffer, 4)!=0)
00352 {
00353 result=false;
00354 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00355 }
00356
00357 if (svm_idx != n)
00358 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00359
00360 line_number++;
00361 }
00362
00363 int_buffer=0;
00364 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00365 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00366
00367 if (svm_idx != n)
00368 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00369
00370 if (!feof(modelfl))
00371 line_number++;
00372
00373 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00374 CSVM* svm=new CSVM(int_buffer);
00375
00376 double_buffer=0;
00377
00378 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00379 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00380
00381 if (svm_idx != n)
00382 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00383
00384 if (!feof(modelfl))
00385 line_number++;
00386
00387 svm->set_bias(double_buffer);
00388
00389 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00390 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00391
00392 if (svm_idx != n)
00393 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00394
00395 if (!feof(modelfl))
00396 line_number++;
00397
00398 for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00399 {
00400 double_buffer=0;
00401 int_buffer=0;
00402
00403 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00404 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00405
00406 if (!feof(modelfl))
00407 line_number++;
00408
00409 svm->set_support_vector(i, int_buffer);
00410 svm->set_alpha(i, double_buffer);
00411 }
00412
00413 if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00414 {
00415 result=false;
00416 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00417 }
00418 else
00419 {
00420 char_buffer[3]='\0';
00421 if (strcmp("];", char_buffer)!=0)
00422 {
00423 result=false;
00424 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00425 }
00426 line_number++;
00427 }
00428
00429 set_svm(n, svm);
00430 }
00431
00432 svm_loaded=result;
00433 return result;
00434 }
00435
00436 bool CMultiClassSVM::save(FILE* modelfl)
00437 {
00438 if (!kernel)
00439 SG_ERROR("Kernel not defined!\n");
00440
00441 if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00442 SG_ERROR("Multiclass SVM not trained!\n");
00443
00444 SG_INFO( "Writing model file...");
00445 fprintf(modelfl,"%%MultiClassSVM\n");
00446 fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00447 fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00448 fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00449 fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00450
00451 for (int32_t i=0; i<m_num_svms; i++)
00452 {
00453 CSVM* svm=m_svms[i];
00454 ASSERT(svm);
00455 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00456 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00457 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00458
00459 fprintf(modelfl, "alphas%d=[\n", i);
00460
00461 for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00462 {
00463 fprintf(modelfl,"\t[%+10.16e,%d];\n",
00464 svm->get_alpha(j), svm->get_support_vector(j));
00465 }
00466
00467 fprintf(modelfl, "];\n");
00468 }
00469
00470 SG_DONE();
00471 return true ;
00472 }