MultiClassSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014 
00015 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00016 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00017 {
00018 }
00019 
00020 CMultiClassSVM::CMultiClassSVM(
00021     EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00023 {
00024 }
00025 
00026 CMultiClassSVM::~CMultiClassSVM()
00027 {
00028     cleanup();
00029 }
00030 
00031 void CMultiClassSVM::cleanup()
00032 {
00033     for (int32_t i=0; i<m_num_svms; i++)
00034         delete m_svms[i];
00035     delete[] m_svms;
00036 
00037     m_num_svms=0;
00038     m_svms=NULL;
00039 }
00040 
00041 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00042 {
00043     if (num_classes>0)
00044     {
00045         m_num_classes=num_classes;
00046 
00047         if (multiclass_type==ONE_VS_REST)
00048             m_num_svms=num_classes;
00049         else if (multiclass_type==ONE_VS_ONE)
00050             m_num_svms=num_classes*(num_classes-1)/2;
00051         else
00052             SG_ERROR("unknown multiclass type\n");
00053 
00054         m_svms=new CSVM*[m_num_svms];
00055         if (m_svms)
00056         {
00057             memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00058             return true;
00059         }
00060     }
00061     return false;
00062 }
00063 
00064 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00065 {
00066     if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00067     {
00068         m_svms[num]=svm;
00069         return true;
00070     }
00071     return false;
00072 }
00073 
00074 CLabels* CMultiClassSVM::classify(CLabels* result)
00075 {
00076     if (multiclass_type==ONE_VS_REST)
00077         return classify_one_vs_rest(result);
00078     else if (multiclass_type==ONE_VS_ONE)
00079         return classify_one_vs_one(result);
00080     else
00081         SG_ERROR("unknown multiclass type\n");
00082 
00083     return NULL;
00084 }
00085 
00086 CLabels* CMultiClassSVM::classify_one_vs_one(CLabels* result)
00087 {
00088     ASSERT(m_num_svms>0);
00089     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00090 
00091     if (!kernel)
00092     {
00093         SG_ERROR( "SVM can not proceed without kernel!\n");
00094         return false ;
00095     }
00096 
00097     if ( kernel && kernel->has_features() && kernel->get_num_vec_rhs())
00098     {
00099         int32_t num_vectors=kernel->get_num_vec_rhs();
00100 
00101         if (!result)
00102             result=new CLabels(num_vectors);
00103 
00104         ASSERT(num_vectors==result->get_num_labels());
00105         CLabels** outputs=new CLabels*[m_num_svms];
00106 
00107         for (int32_t i=0; i<m_num_svms; i++)
00108         {
00109             SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00110             ASSERT(m_svms[i]);
00111             m_svms[i]->set_kernel(kernel);
00112             m_svms[i]->set_labels(labels);
00113             outputs[i]=m_svms[i]->classify();
00114         }
00115 
00116         int32_t* votes=new int32_t[m_num_classes];
00117         for (int32_t v=0; v<num_vectors; v++)
00118         {
00119             int32_t s=0;
00120             memset(votes, 0, sizeof(int32_t)*m_num_classes);
00121 
00122             for (int32_t i=0; i<m_num_classes; i++)
00123             {
00124                 for (int32_t j=i+1; j<m_num_classes; j++)
00125                 {
00126                     if (outputs[s++]->get_label(v)>0)
00127                         votes[i]++;
00128                     else
00129                         votes[j]++;
00130                 }
00131             }
00132 
00133             int32_t winner=0;
00134             int32_t max_votes=votes[0];
00135 
00136             for (int32_t i=1; i<m_num_classes; i++)
00137             {
00138                 if (votes[i]>max_votes)
00139                 {
00140                     max_votes=votes[i];
00141                     winner=i;
00142                 }
00143             }
00144 
00145             result->set_label(v, winner);
00146         }
00147 
00148         delete[] votes;
00149 
00150         for (int32_t i=0; i<m_num_svms; i++)
00151             delete outputs[i];
00152         delete[] outputs;
00153     }
00154 
00155     return result;
00156 }
00157 
00158 CLabels* CMultiClassSVM::classify_one_vs_rest(CLabels* result)
00159 {
00160     ASSERT(m_num_svms>0);
00161 
00162     if (!kernel)
00163     {
00164         SG_ERROR( "SVM can not proceed without kernel!\n");
00165         return false ;
00166     }
00167 
00168     if ( kernel && kernel->has_features() && kernel->get_num_vec_rhs())
00169     {
00170         int32_t num_vectors=kernel->get_num_vec_rhs();
00171 
00172         if (!result)
00173             result=new CLabels(num_vectors);
00174 
00175         ASSERT(num_vectors==result->get_num_labels());
00176         CLabels** outputs=new CLabels*[m_num_svms];
00177 
00178         for (int32_t i=0; i<m_num_svms; i++)
00179         {
00180             ASSERT(m_svms[i]);
00181             m_svms[i]->set_kernel(kernel);
00182             m_svms[i]->set_labels(get_labels());
00183             outputs[i]=m_svms[i]->classify();
00184         }
00185 
00186         for (int32_t i=0; i<num_vectors; i++)
00187         {
00188             int32_t winner=0;
00189             float64_t max_out=outputs[0]->get_label(i);
00190 
00191             for (int32_t j=1; j<m_num_svms; j++)
00192             {
00193                 float64_t out=outputs[j]->get_label(i);
00194 
00195                 if (out>max_out)
00196                 {
00197                     winner=j;
00198                     max_out=out;
00199                 }
00200             }
00201 
00202             result->set_label(i, winner);
00203         }
00204 
00205         for (int32_t i=0; i<m_num_svms; i++)
00206             delete outputs[i];
00207         delete[] outputs;
00208     }
00209 
00210     return result;
00211 }
00212 
00213 float64_t CMultiClassSVM::classify_example(int32_t num)
00214 {
00215     if (multiclass_type==ONE_VS_REST)
00216         return classify_example_one_vs_rest(num);
00217     else if (multiclass_type==ONE_VS_ONE)
00218         return classify_example_one_vs_one(num);
00219     else
00220         SG_ERROR("unknown multiclass type\n");
00221 
00222     return 0;
00223 }
00224 
00225 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00226 {
00227     ASSERT(m_num_svms>0);
00228     float64_t* outputs=new float64_t[m_num_svms];
00229     int32_t winner=0;
00230     float64_t max_out=m_svms[0]->classify_example(num);
00231 
00232     for (int32_t i=1; i<m_num_svms; i++)
00233     {
00234         outputs[i]=m_svms[i]->classify_example(num);
00235         if (outputs[i]>max_out)
00236         {
00237             winner=i;
00238             max_out=outputs[i];
00239         }
00240     }
00241     delete[] outputs;
00242 
00243     return winner;
00244 }
00245 
00246 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00247 {
00248     ASSERT(m_num_svms>0);
00249     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00250 
00251     int32_t* votes=new int32_t[m_num_classes];
00252     int32_t s=0;
00253 
00254     for (int32_t i=0; i<m_num_classes; i++)
00255     {
00256         for (int32_t j=i+1; j<m_num_classes; j++)
00257         {
00258             if (m_svms[s++]->classify_example(num)>0)
00259                 votes[i]++;
00260             else
00261                 votes[j]++;
00262         }
00263     }
00264 
00265     int32_t winner=0;
00266     int32_t max_votes=votes[0];
00267 
00268     for (int32_t i=1; i<m_num_classes; i++)
00269     {
00270         if (votes[i]>max_votes)
00271         {
00272             max_votes=votes[i];
00273             winner=i;
00274         }
00275     }
00276 
00277     delete[] votes;
00278 
00279     return winner;
00280 }
00281 
00282 bool CMultiClassSVM::load(FILE* modelfl)
00283 {
00284     bool result=true;
00285     char char_buffer[1024];
00286     int32_t int_buffer;
00287     float64_t double_buffer;
00288     int32_t line_number=1;
00289     int32_t svm_idx=-1;
00290 
00291     if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00292         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00293     else
00294     {
00295         char_buffer[15]='\0';
00296         if (strcmp("%MultiClassSVM", char_buffer)!=0)
00297             SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00298 
00299         line_number++;
00300     }
00301 
00302     int_buffer=0;
00303     if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00304         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00305 
00306     if (!feof(modelfl))
00307         line_number++;
00308 
00309     if (int_buffer != multiclass_type)
00310         SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00311 
00312     int_buffer=0;
00313     if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00314         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00315 
00316     if (!feof(modelfl))
00317         line_number++;
00318 
00319     if (int_buffer < 2)
00320         SG_ERROR("less than 2 classes - how is this multiclass?\n");
00321 
00322     create_multiclass_svm(int_buffer);
00323 
00324     int_buffer=0;
00325     if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00326         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00327 
00328     if (!feof(modelfl))
00329         line_number++;
00330 
00331     if (m_num_svms != int_buffer)
00332         SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00333 
00334     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00335         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00336 
00337     if (!feof(modelfl))
00338         line_number++;
00339 
00340     for (int32_t n=0; n<m_num_svms; n++)
00341     {
00342         svm_idx=-1;
00343         if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00344         {
00345             result=false;
00346             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00347         }
00348         else
00349         {
00350             char_buffer[4]='\0';
00351             if (strncmp("%SVM", char_buffer, 4)!=0)
00352             {
00353                 result=false;
00354                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00355             }
00356 
00357             if (svm_idx != n)
00358                 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00359 
00360             line_number++;
00361         }
00362 
00363         int_buffer=0;
00364         if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00365             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00366 
00367         if (svm_idx != n)
00368             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00369 
00370         if (!feof(modelfl))
00371             line_number++;
00372 
00373         SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00374         CSVM* svm=new CSVM(int_buffer);
00375 
00376         double_buffer=0;
00377 
00378         if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00379             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00380 
00381         if (svm_idx != n)
00382             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00383 
00384         if (!feof(modelfl))
00385             line_number++;
00386 
00387         svm->set_bias(double_buffer);
00388 
00389         if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00390             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00391 
00392         if (svm_idx != n)
00393             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00394 
00395         if (!feof(modelfl))
00396             line_number++;
00397 
00398         for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00399         {
00400             double_buffer=0;
00401             int_buffer=0;
00402 
00403             if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00404                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00405 
00406             if (!feof(modelfl))
00407                 line_number++;
00408 
00409             svm->set_support_vector(i, int_buffer);
00410             svm->set_alpha(i, double_buffer);
00411         }
00412 
00413         if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00414         {
00415             result=false;
00416             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00417         }
00418         else
00419         {
00420             char_buffer[3]='\0';
00421             if (strcmp("];", char_buffer)!=0)
00422             {
00423                 result=false;
00424                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00425             }
00426             line_number++;
00427         }
00428 
00429         set_svm(n, svm);
00430     }
00431 
00432     svm_loaded=result;
00433     return result;
00434 }
00435 
00436 bool CMultiClassSVM::save(FILE* modelfl)
00437 {
00438     if (!kernel)
00439         SG_ERROR("Kernel not defined!\n");
00440 
00441     if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00442         SG_ERROR("Multiclass SVM not trained!\n");
00443 
00444     SG_INFO( "Writing model file...");
00445     fprintf(modelfl,"%%MultiClassSVM\n");
00446     fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00447     fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00448     fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00449     fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00450 
00451     for (int32_t i=0; i<m_num_svms; i++)
00452     {
00453         CSVM* svm=m_svms[i];
00454         ASSERT(svm);
00455         fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00456         fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00457         fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00458 
00459         fprintf(modelfl, "alphas%d=[\n", i);
00460 
00461         for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00462         {
00463             fprintf(modelfl,"\t[%+10.16e,%d];\n",
00464                     svm->get_alpha(j), svm->get_support_vector(j));
00465         }
00466 
00467         fprintf(modelfl, "];\n");
00468     }
00469 
00470     SG_DONE();
00471     return true ;
00472 } 

SHOGUN Machine Learning Toolbox - Documentation