19 using namespace shogun;
21 CPrimalMosekSOSVM::CPrimalMosekSOSVM()
28 CPrimalMosekSOSVM::CPrimalMosekSOSVM(
37 void CPrimalMosekSOSVM::init()
44 m_regularization = 1.0;
48 CPrimalMosekSOSVM::~CPrimalMosekSOSVM()
52 bool CPrimalMosekSOSVM::train_machine(
CFeatures* data)
54 SG_DEBUG(
"Entering CPrimalMosekSOSVM::train_machine.\n");
58 CFeatures* model_features = get_features();
60 m_model->init_training();
62 m_model->check_training_setup();
63 SG_DEBUG(
"The training setup is correct.\n");
66 int32_t M = m_model->get_dim();
68 int32_t num_aux = m_model->get_num_aux();
70 int32_t num_aux_con = m_model->get_num_aux_con();
74 SG_DEBUG(
"M=%d, N =%d, num_aux=%d, num_aux_con=%d.\n", M, N, num_aux, num_aux_con);
77 CMosek* mosek =
new CMosek(0, M+num_aux+N);
79 REQUIRE(mosek->get_rescode() == MSK_RES_OK,
"Mosek object could not be properly created in PrimalMosekSOSVM training.\n");
84 m_model->init_primal_opt(m_regularization, A, a, B, b, lb, ub, C);
86 SG_DEBUG(
"Regularization used in PrimalMosekSOSVM equal to %.2f.\n", m_regularization);
89 REQUIRE(mosek->init_sosvm(M, N, num_aux, num_aux_con, C, lb, ub, A, b) == MSK_RES_OK,
90 "Mosek error in PrimalMosekSOSVM initializing SO-SVM.\n")
104 for ( int32_t i = 0 ; i < N ; ++i )
111 int32_t num_con = num_aux_con;
112 int32_t old_num_con = num_con;
113 bool exception =
false;
123 SG_DEBUG(
"Iteration #%d: Cutting plane training with num_con=%d and old_num_con=%d.\n",
124 iteration, num_con, old_num_con);
126 old_num_con = num_con;
128 for ( int32_t i = 0 ; i < N ; ++i )
145 while ( cur_res != NULL )
153 if ( slack > max_slack + m_epsilon )
157 if ( ! insert_result(cur_list, result) )
163 add_constraint(mosek, result, num_con, i);
170 if ( ! insert_result(cur_list, result) )
176 add_constraint(mosek, result, num_con, i);
185 SG_DEBUG(
"Entering Mosek QP solver.\n");
187 mosek->optimize(sol);
188 for ( int32_t i = 0 ; i < M+num_aux+N ; ++i )
192 else if ( i < M+num_aux )
195 m_slacks[i-M-num_aux] = sol[i];
198 SG_DEBUG(
"QP solved. The primal objective value is %.4f.\n", mosek->get_primal_objective_value());
202 }
while ( old_num_con != num_con && ! exception );
204 po_value = mosek->get_primal_objective_value();
216 int32_t M = m_w.vlen;
223 bool CPrimalMosekSOSVM::insert_result(
CList* result_list,
CResultSet* result)
const
229 SG_PRINT(
"ResultSet could not be inserted in the list..."
230 "aborting training of PrimalMosekSOSVM\n");
236 bool CPrimalMosekSOSVM::add_constraint(
242 int32_t M = m_model->get_dim();
245 for (
int i = 0 ; i < M ; ++i )
248 return ( mosek->add_constraint_sosvm(dPsi, con_idx, train_idx,
249 m_model->get_num_aux(), -result->
delta) == MSK_RES_OK );
253 float64_t CPrimalMosekSOSVM::compute_primal_objective()
const
263 void CPrimalMosekSOSVM::set_regularization(
float64_t C)
265 m_regularization = C;
SGVector< float64_t > psi_truth
float64_t loss(float64_t prediction, float64_t label)
static float64_t dot(const bool *v1, const bool *v2, int32_t n)
Compute the dot product between v1 and v2 (BLAS-optimized).
Base class of the labels used in Structured Output (SO) problems.
virtual bool init(CFeatures *features)=0
CSGObject * get_next_element()
static const float64_t INFTY
infinity
virtual int32_t get_num_vectors() const =0
static const float64_t epsilon
CSGObject * get_first_element()
int32_t get_num_elements()
static T max(T a, T b)
Return the maximum of two values.
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an array.
Class CStructuredModel that represents the application-specific model and contains most of the application-dependent logic to solve structured output (SO) problems.
The class Features is the base class of all feature objects.
SGVector< float64_t > psi_pred
CSGObject * get_element(int32_t index) const
CHingeLoss implements the hinge loss function.
void push_back(CSGObject *e)
void set_epsilon(float *begin, float max)
Class List implements a doubly linked list for low-level objects.
bool insert_element(CSGObject *data)