GPBTSVM.cpp
Go to the documentation of this file.
00002
00003
00004
00005
00006
00007
00008
00009
00010
#include "classifier/svm/GPBTSVM.h"
#include "classifier/svm/gpdt.h"
#include "classifier/svm/gpdtsolve.h"
#include "lib/io.h"

#include <vector>
00015
00016 CGPBTSVM::CGPBTSVM()
00017 : CSVM(), model(NULL)
00018 {
00019 }
00020
00021 CGPBTSVM::CGPBTSVM(DREAL C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab), model(NULL)
00023 {
00024 }
00025
00026 CGPBTSVM::~CGPBTSVM()
00027 {
00028 free(model);
00029 }
00030
00031 bool CGPBTSVM::train()
00032 {
00033 double *solution;
00034 QPproblem prob;
00035
00036 ASSERT(kernel);
00037 ASSERT(labels && labels->get_num_labels());
00038 ASSERT(labels->is_two_class_labeling());
00039
00040 int num_lab = 0;
00041 int* lab=get_labels()->get_int_labels(num_lab);
00042 prob.KER=new sKernel(kernel, num_lab);
00043 prob.y=lab;
00044 prob.ell=get_labels()->get_num_labels();
00045 SG_INFO( "%d trainlabels\n", prob.ell);
00046
00047
00048 prob.delta = epsilon;
00049 prob.maxmw = kernel->get_cache_size();
00050 prob.verbosity = 0;
00051 prob.preprocess_size = -1;
00052 prob.projection_projector = -1;
00053 prob.c_const = get_C1();
00054 prob.chunk_size = get_qpsize();
00055 prob.linadd = get_linadd_enabled();
00056
00057 if (prob.chunk_size < 2) prob.chunk_size = 2;
00058 if (prob.q <= 0) prob.q = prob.chunk_size / 3;
00059 if (prob.q < 2) prob.q = 2;
00060 if (prob.q > prob.chunk_size) prob.q = prob.chunk_size;
00061 prob.q = prob.q & (~1);
00062 if (prob.maxmw < 5)
00063 prob.maxmw = 5;
00064
00065
00066 SG_INFO( "\nTRAINING PARAMETERS:\n");
00067 SG_INFO( "\tNumber of training documents: %d\n", prob.ell);
00068 SG_INFO( "\tq: %d\n", prob.chunk_size);
00069 SG_INFO( "\tn: %d\n", prob.q);
00070 SG_INFO( "\tC: %lf\n", prob.c_const);
00071 SG_INFO( "\tkernel type: %d\n", prob.ker_type);
00072 SG_INFO( "\tcache size: %dMb\n", prob.maxmw);
00073 SG_INFO( "\tStopping tolerance: %lf\n", prob.delta);
00074
00075
00076 if (prob.preprocess_size == -1)
00077 prob.preprocess_size = (int) ( (double)prob.chunk_size * 1.5 );
00078
00079 if (prob.projection_projector == -1)
00080 {
00081 if (prob.chunk_size <= 20) prob.projection_projector = 0;
00082 else prob.projection_projector = 1;
00083 }
00084
00085
00086 solution = new double[prob.ell];
00087 prob.gpdtsolve(solution);
00088
00089
00090 CSVM::set_objective(prob.objective_value);
00091
00092 int num_sv=0;
00093 int bsv=0;
00094 int i=0;
00095 int k=0;
00096
00097 for (i = 0; i < prob.ell; i++)
00098 {
00099 if (solution[i] > prob.DELTAsv)
00100 {
00101 num_sv++;
00102 if (solution[i] > (prob.c_const - prob.DELTAsv)) bsv++;
00103 }
00104 }
00105
00106 create_new_model(num_sv);
00107 set_bias(prob.bee);
00108
00109 SG_INFO("SV: %d BSV = %d\n", num_sv, bsv);
00110
00111 for (i = 0; i < prob.ell; i++)
00112 {
00113 if (solution[i] > prob.DELTAsv)
00114 {
00115 set_support_vector(k, i);
00116 set_alpha(k++, solution[i]*get_labels()->get_label(i));
00117 }
00118 }
00119
00120 delete prob.KER;
00121 delete[] prob.y;
00122 delete[] solution;
00123
00124 return true;
00125 }