LPM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2007-2008 Soeren Sonnenburg
00008  * Copyright (C) 2007-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/config.h"
00012 
00013 #ifdef USE_CPLEX
00014 
00015 #include "classifier/LPM.h"
00016 #include "features/Labels.h"
00017 #include "lib/Mathematics.h"
00018 #include "lib/Cplex.h"
00019 
00020 CLPM::CLPM()
00021 : CSparseLinearClassifier(), C1(1), C2(1), use_bias(true), epsilon(1e-3)
00022 {
00023 }
00024 
00025 
00026 CLPM::~CLPM()
00027 {
00028 }
00029 
00030 bool CLPM::train()
00031 {
00032     ASSERT(labels);
00033     ASSERT(features);
00034     int32_t num_train_labels=labels->get_num_labels();
00035     int32_t num_feat=features->get_num_features();
00036     int32_t num_vec=features->get_num_vectors();
00037 
00038     ASSERT(num_vec==num_train_labels);
00039     delete[] w;
00040     w=new float64_t[num_feat];
00041     w_dim=num_feat;
00042 
00043     int32_t num_params=1+2*num_feat+num_vec; //b,w+,w-,xi
00044     float64_t* params=new float64_t[num_params];
00045     memset(params,0,sizeof(float64_t)*num_params);
00046 
00047     CCplex solver;
00048     solver.init(E_LINEAR);
00049     SG_INFO("C=%f\n", C1);
00050     solver.setup_lpm(C1, features, labels, get_bias_enabled());
00051     if (get_max_train_time()>0)
00052         solver.set_time_limit(get_max_train_time());
00053     bool result=solver.optimize(params);
00054     solver.cleanup();
00055 
00056     set_bias(params[0]);
00057     for (int32_t i=0; i<num_feat; i++)
00058         w[i]=params[1+i]-params[1+num_feat+i];
00059 
00060 //#define LPM_DEBUG
00061 #ifdef LPM_DEBUG
00062     CMath::display_vector(params,num_params, "params");
00063     SG_PRINT("bias=%f\n", bias);
00064     CMath::display_vector(w,w_dim, "w");
00065     CMath::display_vector(&params[1],w_dim, "w+");
00066     CMath::display_vector(&params[1+w_dim],w_dim, "w-");
00067 #endif
00068     delete[] params;
00069 
00070     return result;
00071 }
00072 #endif

SHOGUN Machine Learning Toolbox - Documentation