GPBTSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/svm/GPBTSVM.h"
00012 #include "classifier/svm/gpdt.h"
00013 #include "classifier/svm/gpdtsolve.h"
00014 #include "lib/io.h"
00015 
00016 CGPBTSVM::CGPBTSVM()
00017 : CSVM(), model(NULL)
00018 {
00019 }
00020 
00021 CGPBTSVM::CGPBTSVM(float64_t C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab), model(NULL)
00023 {
00024 }
00025 
00026 CGPBTSVM::~CGPBTSVM()
00027 {
00028     free(model);
00029 }
00030 
00031 bool CGPBTSVM::train()
00032 {
00033     float64_t     *solution;                     /* store the solution found       */
00034     QPproblem  prob;                          /* object containing the solvers  */
00035 
00036     ASSERT(kernel);
00037     ASSERT(labels && labels->get_num_labels());
00038     ASSERT(labels->is_two_class_labeling());
00039 
00040     int32_t num_lab = 0;
00041     int32_t* lab=get_labels()->get_int_labels(num_lab);
00042     prob.KER=new sKernel(kernel, num_lab);
00043     prob.y=lab;
00044     prob.ell=get_labels()->get_num_labels();
00045     SG_INFO( "%d trainlabels\n", prob.ell);
00046 
00047     //  /*** set options defaults ***/
00048     prob.delta = epsilon;
00049     prob.maxmw = kernel->get_cache_size();
00050     prob.verbosity       = 0;
00051     prob.preprocess_size = -1;
00052     prob.projection_projector = -1;
00053     prob.c_const = get_C1();
00054     prob.chunk_size = get_qpsize();
00055     prob.linadd = get_linadd_enabled();
00056 
00057     if (prob.chunk_size < 2)      prob.chunk_size = 2;
00058     if (prob.q <= 0)              prob.q = prob.chunk_size / 3;
00059     if (prob.q < 2)               prob.q = 2;
00060     if (prob.q > prob.chunk_size) prob.q = prob.chunk_size;
00061     prob.q = prob.q & (~1);
00062     if (prob.maxmw < 5)
00063         prob.maxmw = 5;
00064 
00065     /*** set the problem description for final report ***/
00066     SG_INFO( "\nTRAINING PARAMETERS:\n");
00067     SG_INFO( "\tNumber of training documents: %d\n", prob.ell);
00068     SG_INFO( "\tq: %d\n", prob.chunk_size);
00069     SG_INFO( "\tn: %d\n", prob.q);
00070     SG_INFO( "\tC: %lf\n", prob.c_const);
00071     SG_INFO( "\tkernel type: %d\n", prob.ker_type);
00072     SG_INFO( "\tcache size: %dMb\n", prob.maxmw);
00073     SG_INFO( "\tStopping tolerance: %lf\n", prob.delta);
00074 
00075     //  /*** compute the number of cache rows up to maxmw Mb. ***/
00076     if (prob.preprocess_size == -1)
00077         prob.preprocess_size = (int32_t) ( (float64_t)prob.chunk_size * 1.5 );
00078 
00079     if (prob.projection_projector == -1)
00080     {
00081         if (prob.chunk_size <= 20) prob.projection_projector = 0;
00082         else prob.projection_projector = 1;
00083     }
00084 
00085     /*** compute the problem solution *******************************************/
00086     solution = new float64_t[prob.ell];
00087     prob.gpdtsolve(solution);
00088     /****************************************************************************/
00089 
00090     CSVM::set_objective(prob.objective_value);
00091 
00092     int32_t num_sv=0;
00093     int32_t bsv=0;
00094     int32_t i=0;
00095     int32_t k=0;
00096 
00097     for (i = 0; i < prob.ell; i++)
00098     {
00099         if (solution[i] > prob.DELTAsv)
00100         {
00101             num_sv++;
00102             if (solution[i] > (prob.c_const - prob.DELTAsv)) bsv++;
00103         }
00104     }
00105 
00106     create_new_model(num_sv);
00107     set_bias(prob.bee);
00108 
00109     SG_INFO("SV: %d BSV = %d\n", num_sv, bsv);
00110 
00111     for (i = 0; i < prob.ell; i++)
00112     {
00113         if (solution[i] > prob.DELTAsv)
00114         {
00115             set_support_vector(k, i);
00116             set_alpha(k++, solution[i]*get_labels()->get_label(i));
00117         }
00118     }
00119 
00120     delete prob.KER;
00121     delete[] prob.y;
00122     delete[] solution;
00123 
00124     return true;
00125 }

SHOGUN Machine Learning Toolbox - Documentation