#include "ClassMeasurer.h"
#include "CmdLine.h"
#include "ConnectedMachine.h"
#include "FileDataSet.h"
#include "GMTrainer.h"
#include "Linear.h"
#include "MseCriterion.h"
#include "MseMeasurer.h"
#include "OneHotClassFormat.h"
#include "StochasticGradient.h"
#include "Tanh.h"
#include "TwoClassFormat.h"

#include <cstring>
int main(int argc, char **argv)
{
char *model_file, *test_model_file;
char *valid_file;
char *file;
int n_inputs;
int n_targets;
int n_hu;
int max_load;
real accuracy;
real learning_rate;
real decay;
int max_iter;
bool regression;
int k_fold;
int the_seed;
//=================== The command-line ==========================
// Construct the command line
CmdLine cmd;
// Put the help line at the beginning
cmd.info(help);
// Ask for arguments
cmd.addText("\nArguments:");
cmd.addSCmdArg("file", &file, "the train or test file");
cmd.addICmdArg("n_inputs", &n_inputs, "input dimension
of the data");
cmd.addICmdArg("n_targets", &n_targets, "output dimension
of the data");
// Propose some options
cmd.addText("\nModel Options:");
cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden
units");
cmd.addBCmdOption("-rm", ®ression, false, "regression
mode");
cmd.addText("\nLearning Options:");
cmd.addICmdOption("-iter", &max_iter, 25, "max number
of iterations");
cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning
rate");
cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
cmd.addRCmdOption("-lrd", &decay, 0, "learning rate
decay");
cmd.addText("\nMisc Options:");
cmd.addICmdOption("-seed", &the_seed, -1, "the random
seed");
cmd.addICmdOption("-Kfold", &k_fold, -1, "number of
subsets for K-fold cross-validation");
cmd.addICmdOption("-load", &max_load, -1, "max number
of examples to load");
cmd.addSCmdOption("-valid", &valid_file, "", "validation
file, if you want it");
cmd.addSCmdOption("-sm", &model_file, "", "file to save
the model");
cmd.addSCmdOption("-test", &test_model_file, "", "model
file to test");
// Read the command line
cmd.read(argc, argv);
// If the user didn't give any random
seed,
// generate a random random seed...
if(the_seed == -1)
seed();
else
manual_seed((long)the_seed);
//=================== Create the MLP...
=========================
ConnectedMachine MLP;
// Create the layers of the MLP
Linear hidden_linear(n_inputs, n_hu);
Tanh hidden_nlinear(n_hu);
Linear output_linear(n_hu, n_targets);
Tanh output_nlinear(n_targets);
// Initialize the layers
hidden_linear.init();
hidden_nlinear.init();
output_linear.init();
output_nlinear.init();
// Add the layers (Full Connected Layers)
to the MLP
MLP.addFCL(&hidden_linear);
MLP.addFCL(&hidden_nlinear);
MLP.addFCL(&output_linear);
// If regression, don't add the tanh
output layer
if(!regression)
MLP.addFCL(&output_nlinear);
// Initialize the MLP
MLP.init();
//=================== DataSets & Measurers... ===================
// Create the training dataset (normalize
inputs)
FileDataSet data(file, n_inputs, n_targets, false, max_load);
data.setBOption("normalize inputs", true);
data.init();
// The list of measurers...
List *measurers = NULL;
// The class format
ClassFormat *class_format = NULL;
if(!regression)
{
if(n_targets == 1)
class_format = new TwoClassFormat(&data);
else
class_format = new OneHotClassFormat(&data);
}
// The validation set...
FileDataSet *valid_data = NULL;
MseMeasurer *valid_mse_meas = NULL;
ClassMeasurer *valid_class_meas = NULL;
// Create a validation set, if any
if(strcmp(valid_file, ""))
{
// Load the validation
set and normalize it with the
// values in the train
dataset
valid_data = new FileDataSet(valid_file, n_inputs,
n_targets);
valid_data->init();
valid_data->normalizeUsingDataSet(&data);
// Create a MSE measurer
and an error class measurer
// on the validation dataset
(if we are not in regression)
valid_mse_meas = new MseMeasurer(MLP.outputs,
valid_data, "the_valid_mse");
valid_mse_meas->init();
addToList(&measurers, 1, valid_mse_meas);
if(!regression)
{
valid_class_meas = new ClassMeasurer(MLP.outputs,
valid_data, class_format, "the_valid_class_err");
valid_class_meas->init();
addToList(&measurers, 1, valid_class_meas);
}
}
// Measurers on the training dataset
MseMeasurer *mse_meas = new MseMeasurer(MLP.outputs, &data,
"the_mse");
mse_meas->init();
addToList(&measurers, 1, mse_meas);
ClassMeasurer *class_meas = NULL;
if(!regression)
{
class_meas = new ClassMeasurer(MLP.outputs,
&data, class_format, "the_class_err");
class_meas->init();
addToList(&measurers, 1, class_meas);
}
//=================== The Trainer ===============================
// The criterion for the GMTrainer
(MSE criterion)
MseCriterion mse(n_targets);
mse.init();
// The optimizer for the GMTrainer
StochasticGradient opt;
opt.setIOption("max iter", max_iter);
opt.setROption("end accuracy", accuracy);
opt.setROption("learning rate", learning_rate);
opt.setROption("learning rate decay", decay);
// The Gradient Machine Trainer
GMTrainer trainer(&MLP, &data, &mse, &opt);
//=================== Let's go... ===============================
// Print the number of parameter of
the MLP (just for fun)
message("Number of parameters: %d", MLP.n_params);
// If the user provides a previously
trained model,
// test it...
if( strcmp(test_model_file, "") )
{
trainer.load(test_model_file);
trainer.test(measurers);
}
// ...else...
else
{
// If the user provides
a number for the K-fold validation,
// do a K-fold validation
if(k_fold > 0)
trainer.crossValidate(k_fold, NULL,
measurers);
// Else, train the model
else
trainer.train(measurers);
// Save the model if the
user provides a name for that
if( strcmp(model_file, "") )
trainer.save(model_file);
}
//=================== Quit... ===================================
if(strcmp(valid_file, ""))
{
delete valid_data;
delete valid_mse_meas;
if(!regression)
delete valid_class_meas;
}
delete mse_meas;
if(!regression)
{
delete class_meas;
delete class_format;
}
freeList(&measurers);
return(0);
}