Perceptron.cpp
Go to the documentation of this file.
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "classifier/Perceptron.h"
00012 #include "features/Labels.h"
00013 #include "lib/Mathematics.h"
00014
00015 CPerceptron::CPerceptron()
00016 : CLinearClassifier(), learn_rate(0.1), max_iter(1000)
00017 {
00018 }
00019
00020 CPerceptron::CPerceptron(CRealFeatures* traindat, CLabels* trainlab)
00021 : CLinearClassifier(), learn_rate(.1), max_iter(1000)
00022 {
00023 set_features(traindat);
00024 set_labels(trainlab);
00025 }
00026
00027 CPerceptron::~CPerceptron()
00028 {
00029 }
00030
00031 bool CPerceptron::train()
00032 {
00033 ASSERT(labels);
00034 ASSERT(features);
00035 bool converged=false;
00036 int32_t iter=0;
00037 int32_t num_train_labels=0;
00038 int32_t* train_labels=labels->get_int_labels(num_train_labels);
00039 int32_t num_feat=features->get_num_features();
00040 int32_t num_vec=features->get_num_vectors();
00041
00042 ASSERT(num_vec==num_train_labels);
00043 delete[] w;
00044 w_dim=num_feat;
00045 w=new float64_t[num_feat];
00046 float64_t* output=new float64_t[num_vec];
00047
00048
00049 bias=0;
00050 for (int32_t i=0; i<num_feat; i++)
00051 w[i]=1.0/num_feat;
00052
00053
00054
00055 while (!converged && iter<max_iter)
00056 {
00057 converged=true;
00058 for (int32_t i=0; i<num_vec; i++)
00059 output[i]=classify_example(i);
00060
00061 for (int32_t i=0; i<num_vec; i++)
00062 {
00063 if (CMath::sign<float64_t>(output[i]) != train_labels[i])
00064 {
00065 converged=false;
00066 int32_t vlen;
00067 bool vfree;
00068 float64_t* vec=features->get_feature_vector(i, vlen, vfree);
00069
00070 bias+=learn_rate*train_labels[i];
00071 for (int32_t j=0; j<num_feat; j++)
00072 w[j]+= learn_rate*train_labels[i]*vec[j];
00073
00074 features->free_feature_vector(vec, i, vfree);
00075 }
00076 }
00077
00078 iter++;
00079 }
00080
00081 if (converged)
00082 SG_INFO("Perceptron algorithm converged after %d iterations.\n", iter);
00083 else
00084 SG_WARNING("Perceptron algorithm did not converge after %d iterations.\n", max_iter);
00085
00086 delete[] output;
00087 delete[] train_labels;
00088
00089 return converged;
00090 }