MultiClassSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014 
00015 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00016 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00017 {
00018 }
00019 
00020 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type, DREAL C, CKernel* k, CLabels* lab)
00021 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00022 {
00023 }
00024 
00025 CMultiClassSVM::~CMultiClassSVM()
00026 {
00027     cleanup();
00028 }
00029 
00030 void CMultiClassSVM::cleanup()
00031 {
00032     for (INT i=0; i<m_num_svms; i++)
00033         delete m_svms[i];
00034     delete[] m_svms;
00035 
00036     m_num_svms=0;
00037     m_svms=NULL;
00038 }
00039 
00040 bool CMultiClassSVM::create_multiclass_svm(INT num_classes)
00041 {
00042     if (num_classes>0)
00043     {
00044         m_num_classes=num_classes;
00045 
00046         if (multiclass_type==ONE_VS_REST)
00047             m_num_svms=num_classes;
00048         else if (multiclass_type==ONE_VS_ONE)
00049             m_num_svms=num_classes*(num_classes-1)/2;
00050         else
00051             SG_ERROR("unknown multiclass type\n");
00052 
00053         m_svms=new CSVM*[m_num_svms];
00054         if (m_svms)
00055         {
00056             memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00057             return true;
00058         }
00059     }
00060     return false;
00061 }
00062 
00063 bool CMultiClassSVM::set_svm(INT num, CSVM* svm)
00064 {
00065     if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00066     {
00067         m_svms[num]=svm;
00068         return true;
00069     }
00070     return false;
00071 }
00072 
00073 CLabels* CMultiClassSVM::classify(CLabels* result)
00074 {
00075     if (multiclass_type==ONE_VS_REST)
00076         return classify_one_vs_rest(result);
00077     else if (multiclass_type==ONE_VS_ONE)
00078         return classify_one_vs_one(result);
00079     else
00080         SG_ERROR("unknown multiclass type\n");
00081 
00082     return NULL;
00083 }
00084 
00085 CLabels* CMultiClassSVM::classify_one_vs_one(CLabels* result)
00086 {
00087     ASSERT(m_num_svms>0);
00088     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00089 
00090     if (!kernel)
00091     {
00092         SG_ERROR( "SVM can not proceed without kernel!\n");
00093         return false ;
00094     }
00095 
00096     if ( kernel && kernel->has_features() && kernel->get_num_vec_rhs())
00097     {
00098         INT num_vectors=kernel->get_num_vec_rhs();
00099 
00100         if (!result)
00101             result=new CLabels(num_vectors);
00102 
00103         ASSERT(num_vectors==result->get_num_labels());
00104         CLabels** outputs=new CLabels*[m_num_svms];
00105 
00106         for (INT i=0; i<m_num_svms; i++)
00107         {
00108             SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00109             ASSERT(m_svms[i]);
00110             m_svms[i]->set_kernel(kernel);
00111             m_svms[i]->set_labels(labels);
00112             outputs[i]=m_svms[i]->classify();
00113         }
00114 
00115         INT* votes=new INT[m_num_classes];
00116         for (INT v=0; v<num_vectors; v++)
00117         {
00118             INT s=0;
00119             memset(votes, 0, sizeof(INT)*m_num_classes);
00120 
00121             for (INT i=0; i<m_num_classes; i++)
00122             {
00123                 for (INT j=i+1; j<m_num_classes; j++)
00124                 {
00125                     if (outputs[s++]->get_label(v)>0)
00126                         votes[i]++;
00127                     else
00128                         votes[j]++;
00129                 }
00130             }
00131 
00132             INT winner=0;
00133             INT max_votes=votes[0];
00134 
00135             for (INT i=1; i<m_num_classes; i++)
00136             {
00137                 if (votes[i]>max_votes)
00138                 {
00139                     max_votes=votes[i];
00140                     winner=i;
00141                 }
00142             }
00143 
00144             result->set_label(v, winner);
00145         }
00146 
00147         delete[] votes;
00148 
00149         for (INT i=0; i<m_num_svms; i++)
00150             delete outputs[i];
00151         delete[] outputs;
00152     }
00153 
00154     return result;
00155 }
00156 
00157 CLabels* CMultiClassSVM::classify_one_vs_rest(CLabels* result)
00158 {
00159     ASSERT(m_num_svms>0);
00160 
00161     if (!kernel)
00162     {
00163         SG_ERROR( "SVM can not proceed without kernel!\n");
00164         return false ;
00165     }
00166 
00167     if ( kernel && kernel->has_features() && kernel->get_num_vec_rhs())
00168     {
00169         INT num_vectors=kernel->get_num_vec_rhs();
00170 
00171         if (!result)
00172             result=new CLabels(num_vectors);
00173 
00174         ASSERT(num_vectors==result->get_num_labels());
00175         CLabels** outputs=new CLabels*[m_num_svms];
00176 
00177         for (INT i=0; i<m_num_svms; i++)
00178         {
00179             ASSERT(m_svms[i]);
00180             m_svms[i]->set_kernel(kernel);
00181             m_svms[i]->set_labels(get_labels());
00182             outputs[i]=m_svms[i]->classify();
00183         }
00184 
00185         for (INT i=0; i<num_vectors; i++)
00186         {
00187             INT winner=0;
00188             DREAL max_out=outputs[0]->get_label(i);
00189 
00190             for (INT j=1; j<m_num_svms; j++)
00191             {
00192                 DREAL out=outputs[j]->get_label(i);
00193 
00194                 if (out>max_out)
00195                 {
00196                     winner=j;
00197                     max_out=out;
00198                 }
00199             }
00200 
00201             result->set_label(i, winner);
00202         }
00203 
00204         for (INT i=0; i<m_num_svms; i++)
00205             delete outputs[i];
00206         delete[] outputs;
00207     }
00208 
00209     return result;
00210 }
00211 
00212 DREAL CMultiClassSVM::classify_example(INT num)
00213 {
00214     if (multiclass_type==ONE_VS_REST)
00215         return classify_example_one_vs_rest(num);
00216     else if (multiclass_type==ONE_VS_ONE)
00217         return classify_example_one_vs_one(num);
00218     else
00219         SG_ERROR("unknown multiclass type\n");
00220 
00221     return 0;
00222 }
00223 
00224 DREAL CMultiClassSVM::classify_example_one_vs_rest(INT num)
00225 {
00226     ASSERT(m_num_svms>0);
00227     DREAL* outputs=new DREAL[m_num_svms];
00228     INT winner=0;
00229     DREAL max_out=m_svms[0]->classify_example(num);
00230 
00231     for (INT i=1; i<m_num_svms; i++)
00232     {
00233         outputs[i]=m_svms[i]->classify_example(num);
00234         if (outputs[i]>max_out)
00235         {
00236             winner=i;
00237             max_out=outputs[i];
00238         }
00239     }
00240     delete[] outputs;
00241 
00242     return winner;
00243 }
00244 
00245 DREAL CMultiClassSVM::classify_example_one_vs_one(INT num)
00246 {
00247     ASSERT(m_num_svms>0);
00248     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00249 
00250     INT* votes=new INT[m_num_classes];
00251     INT s=0;
00252 
00253     for (INT i=0; i<m_num_classes; i++)
00254     {
00255         for (INT j=i+1; j<m_num_classes; j++)
00256         {
00257             if (m_svms[s++]->classify_example(num)>0)
00258                 votes[i]++;
00259             else
00260                 votes[j]++;
00261         }
00262     }
00263 
00264     INT winner=0;
00265     INT max_votes=votes[0];
00266 
00267     for (INT i=1; i<m_num_classes; i++)
00268     {
00269         if (votes[i]>max_votes)
00270         {
00271             max_votes=votes[i];
00272             winner=i;
00273         }
00274     }
00275 
00276     delete[] votes;
00277 
00278     return winner;
00279 }

SHOGUN Machine Learning Toolbox - Documentation