MultiClassSVM.cpp
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014
00015 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00016 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00017 {
00018 }
00019
00020 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type, DREAL C, CKernel* k, CLabels* lab)
00021 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00022 {
00023 }
00024
00025 CMultiClassSVM::~CMultiClassSVM()
00026 {
00027 cleanup();
00028 }
00029
00030 void CMultiClassSVM::cleanup()
00031 {
00032 for (INT i=0; i<m_num_svms; i++)
00033 delete m_svms[i];
00034 delete[] m_svms;
00035
00036 m_num_svms=0;
00037 m_svms=NULL;
00038 }
00039
00040 bool CMultiClassSVM::create_multiclass_svm(INT num_classes)
00041 {
00042 if (num_classes>0)
00043 {
00044 m_num_classes=num_classes;
00045
00046 if (multiclass_type==ONE_VS_REST)
00047 m_num_svms=num_classes;
00048 else if (multiclass_type==ONE_VS_ONE)
00049 m_num_svms=num_classes*(num_classes-1)/2;
00050 else
00051 SG_ERROR("unknown multiclass type\n");
00052
00053 m_svms=new CSVM*[m_num_svms];
00054 if (m_svms)
00055 {
00056 memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00057 return true;
00058 }
00059 }
00060 return false;
00061 }
00062
00063 bool CMultiClassSVM::set_svm(INT num, CSVM* svm)
00064 {
00065 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00066 {
00067 m_svms[num]=svm;
00068 return true;
00069 }
00070 return false;
00071 }
00072
00073 CLabels* CMultiClassSVM::classify(CLabels* result)
00074 {
00075 if (multiclass_type==ONE_VS_REST)
00076 return classify_one_vs_rest(result);
00077 else if (multiclass_type==ONE_VS_ONE)
00078 return classify_one_vs_one(result);
00079 else
00080 SG_ERROR("unknown multiclass type\n");
00081
00082 return NULL;
00083 }
00084
00085 CLabels* CMultiClassSVM::classify_one_vs_one(CLabels* result)
00086 {
00087 ASSERT(m_num_svms>0);
00088 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00089
00090 if (!kernel)
00091 {
00092 SG_ERROR( "SVM can not proceed without kernel!\n");
00093 return false ;
00094 }
00095
00096 if ( kernel && kernel->has_features() && kernel->get_num_vec_rhs())
00097 {
00098 INT num_vectors=kernel->get_num_vec_rhs();
00099
00100 if (!result)
00101 result=new CLabels(num_vectors);
00102
00103 ASSERT(num_vectors==result->get_num_labels());
00104 CLabels** outputs=new CLabels*[m_num_svms];
00105
00106 for (INT i=0; i<m_num_svms; i++)
00107 {
00108 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00109 ASSERT(m_svms[i]);
00110 m_svms[i]->set_kernel(kernel);
00111 m_svms[i]->set_labels(labels);
00112 outputs[i]=m_svms[i]->classify();
00113 }
00114
00115 INT* votes=new INT[m_num_classes];
00116 for (INT v=0; v<num_vectors; v++)
00117 {
00118 INT s=0;
00119 memset(votes, 0, sizeof(INT)*m_num_classes);
00120
00121 for (INT i=0; i<m_num_classes; i++)
00122 {
00123 for (INT j=i+1; j<m_num_classes; j++)
00124 {
00125 if (outputs[s++]->get_label(v)>0)
00126 votes[i]++;
00127 else
00128 votes[j]++;
00129 }
00130 }
00131
00132 INT winner=0;
00133 INT max_votes=votes[0];
00134
00135 for (INT i=1; i<m_num_classes; i++)
00136 {
00137 if (votes[i]>max_votes)
00138 {
00139 max_votes=votes[i];
00140 winner=i;
00141 }
00142 }
00143
00144 result->set_label(v, winner);
00145 }
00146
00147 delete[] votes;
00148
00149 for (INT i=0; i<m_num_svms; i++)
00150 delete outputs[i];
00151 delete[] outputs;
00152 }
00153
00154 return result;
00155 }
00156
00157 CLabels* CMultiClassSVM::classify_one_vs_rest(CLabels* result)
00158 {
00159 ASSERT(m_num_svms>0);
00160
00161 if (!kernel)
00162 {
00163 SG_ERROR( "SVM can not proceed without kernel!\n");
00164 return false ;
00165 }
00166
00167 if ( kernel && kernel->has_features() && kernel->get_num_vec_rhs())
00168 {
00169 INT num_vectors=kernel->get_num_vec_rhs();
00170
00171 if (!result)
00172 result=new CLabels(num_vectors);
00173
00174 ASSERT(num_vectors==result->get_num_labels());
00175 CLabels** outputs=new CLabels*[m_num_svms];
00176
00177 for (INT i=0; i<m_num_svms; i++)
00178 {
00179 ASSERT(m_svms[i]);
00180 m_svms[i]->set_kernel(kernel);
00181 m_svms[i]->set_labels(get_labels());
00182 outputs[i]=m_svms[i]->classify();
00183 }
00184
00185 for (INT i=0; i<num_vectors; i++)
00186 {
00187 INT winner=0;
00188 DREAL max_out=outputs[0]->get_label(i);
00189
00190 for (INT j=1; j<m_num_svms; j++)
00191 {
00192 DREAL out=outputs[j]->get_label(i);
00193
00194 if (out>max_out)
00195 {
00196 winner=j;
00197 max_out=out;
00198 }
00199 }
00200
00201 result->set_label(i, winner);
00202 }
00203
00204 for (INT i=0; i<m_num_svms; i++)
00205 delete outputs[i];
00206 delete[] outputs;
00207 }
00208
00209 return result;
00210 }
00211
00212 DREAL CMultiClassSVM::classify_example(INT num)
00213 {
00214 if (multiclass_type==ONE_VS_REST)
00215 return classify_example_one_vs_rest(num);
00216 else if (multiclass_type==ONE_VS_ONE)
00217 return classify_example_one_vs_one(num);
00218 else
00219 SG_ERROR("unknown multiclass type\n");
00220
00221 return 0;
00222 }
00223
00224 DREAL CMultiClassSVM::classify_example_one_vs_rest(INT num)
00225 {
00226 ASSERT(m_num_svms>0);
00227 DREAL* outputs=new DREAL[m_num_svms];
00228 INT winner=0;
00229 DREAL max_out=m_svms[0]->classify_example(num);
00230
00231 for (INT i=1; i<m_num_svms; i++)
00232 {
00233 outputs[i]=m_svms[i]->classify_example(num);
00234 if (outputs[i]>max_out)
00235 {
00236 winner=i;
00237 max_out=outputs[i];
00238 }
00239 }
00240 delete[] outputs;
00241
00242 return winner;
00243 }
00244
00245 DREAL CMultiClassSVM::classify_example_one_vs_one(INT num)
00246 {
00247 ASSERT(m_num_svms>0);
00248 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00249
00250 INT* votes=new INT[m_num_classes];
00251 INT s=0;
00252
00253 for (INT i=0; i<m_num_classes; i++)
00254 {
00255 for (INT j=i+1; j<m_num_classes; j++)
00256 {
00257 if (m_svms[s++]->classify_example(num)>0)
00258 votes[i]++;
00259 else
00260 votes[j]++;
00261 }
00262 }
00263
00264 INT winner=0;
00265 INT max_votes=votes[0];
00266
00267 for (INT i=1; i<m_num_classes; i++)
00268 {
00269 if (votes[i]>max_votes)
00270 {
00271 max_votes=votes[i];
00272 winner=i;
00273 }
00274 }
00275
00276 delete[] votes;
00277
00278 return winner;
00279 }