// LibSVMMultiClass.cpp
// One-vs-one multiclass SVM whose pairwise machines are trained with LibSVM.
00011 #include "classifier/svm/LibSVMMultiClass.h"
00012 #include "lib/io.h"
00013
00014 CLibSVMMultiClass::CLibSVMMultiClass()
00015 : CMultiClassSVM(ONE_VS_ONE), model(NULL)
00016 {
00017 }
00018
00019 CLibSVMMultiClass::CLibSVMMultiClass(float64_t C, CKernel* k, CLabels* lab)
00020 : CMultiClassSVM(ONE_VS_ONE, C, k, lab), model(NULL)
00021 {
00022 }
00023
00024 CLibSVMMultiClass::~CLibSVMMultiClass()
00025 {
00026
00027 }
00028
00029 bool CLibSVMMultiClass::train()
00030 {
00031 struct svm_node* x_space;
00032
00033 ASSERT(labels && labels->get_num_labels());
00034 int32_t num_classes = labels->get_num_classes();
00035 problem.l=labels->get_num_labels();
00036 SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);
00037
00038 problem.y=new float64_t[problem.l];
00039 problem.x=new struct svm_node*[problem.l];
00040 x_space=new struct svm_node[2*problem.l];
00041
00042 for (int32_t i=0; i<problem.l; i++)
00043 {
00044 problem.y[i]=labels->get_label(i);
00045 problem.x[i]=&x_space[2*i];
00046 x_space[2*i].index=i;
00047 x_space[2*i+1].index=-1;
00048 }
00049
00050 ASSERT(kernel);
00051
00052 param.svm_type=C_SVC;
00053 param.kernel_type = LINEAR;
00054 param.degree = 3;
00055 param.gamma = 0;
00056 param.coef0 = 0;
00057 param.nu = 0.5;
00058 param.kernel=kernel;
00059 param.cache_size = kernel->get_cache_size();
00060 param.C = get_C1();
00061 param.eps = epsilon;
00062 param.p = 0.1;
00063 param.shrinking = 1;
00064 param.nr_weight = 0;
00065 param.weight_label = NULL;
00066 param.weight = NULL;
00067
00068 const char* error_msg = svm_check_parameter(&problem,¶m);
00069
00070 if(error_msg)
00071 {
00072 fprintf(stderr,"Error: %s\n",error_msg);
00073 exit(1);
00074 }
00075
00076 model = svm_train(&problem, ¶m);
00077
00078 if (model)
00079 {
00080 ASSERT(model->nr_class==num_classes);
00081 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
00082 create_multiclass_svm(num_classes);
00083
00084 int32_t* offsets=new int32_t[num_classes];
00085 offsets[0]=0;
00086
00087 for (int32_t i=1; i<num_classes; i++)
00088 offsets[i] = offsets[i-1]+model->nSV[i-1];
00089
00090 int32_t s=0;
00091 for (int32_t i=0; i<num_classes; i++)
00092 {
00093 for (int32_t j=i+1; j<num_classes; j++)
00094 {
00095 int32_t k, l;
00096
00097 float64_t sgn=1;
00098 if (model->label[i]>model->label[j])
00099 sgn=-1;
00100
00101 int32_t num_sv=model->nSV[i]+model->nSV[j];
00102 float64_t bias=-model->rho[s];
00103
00104 ASSERT(num_sv>0);
00105 ASSERT(model->sv_coef[i] && model->sv_coef[j-1]);
00106
00107 CSVM* svm=new CSVM(num_sv);
00108
00109 svm->set_bias(sgn*bias);
00110
00111 int32_t sv_idx=0;
00112 for (k=0; k<model->nSV[i]; k++)
00113 {
00114 svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index);
00115 svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]);
00116 sv_idx++;
00117 }
00118
00119 for (k=0; k<model->nSV[j]; k++)
00120 {
00121 svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index);
00122 svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]);
00123 sv_idx++;
00124 }
00125
00126 int32_t idx=0;
00127
00128 if (sgn>0)
00129 {
00130 for (k=0; k<model->label[i]; k++)
00131 idx+=num_classes-k-1;
00132
00133 for (l=model->label[i]+1; l<model->label[j]; l++)
00134 idx++;
00135 }
00136 else
00137 {
00138 for (k=0; k<model->label[j]; k++)
00139 idx+=num_classes-k-1;
00140
00141 for (l=model->label[j]+1; l<model->label[i]; l++)
00142 idx++;
00143 }
00144
00145
00146
00147
00148
00149
00150
00151 SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f label:(%d,%d) -> svm[%d]\n", s, num_sv, model->l, bias, model->label[i], model->label[j], idx);
00152
00153 set_svm(idx, svm);
00154 s++;
00155 }
00156 }
00157
00158 CSVM::set_objective(model->objective);
00159
00160 delete[] offsets;
00161 delete[] problem.x;
00162 delete[] problem.y;
00163 delete[] x_space;
00164
00165 svm_destroy_model(model);
00166 model=NULL;
00167
00168 return true;
00169 }
00170 else
00171 return false;
00172 }
00173