KNN.cpp
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #include "classifier/KNN.h"
00014 #include "features/Labels.h"
00015 #include "lib/Mathematics.h"
00016
00017 CKNN::CKNN()
00018 : CDistanceMachine(), k(3), num_classes(0), num_train_labels(0), train_labels(NULL)
00019 {
00020 }
00021
00022 CKNN::CKNN(INT k_, CDistance* d, CLabels* trainlab)
00023 : CDistanceMachine(), k(k_), num_classes(0), train_labels(NULL)
00024 {
00025 set_distance(d);
00026 set_labels(trainlab);
00027 num_train_labels=trainlab->get_num_labels();
00028 }
00029
00030
00031 CKNN::~CKNN()
00032 {
00033 delete[] train_labels;
00034 }
00035
00036 bool CKNN::train()
00037 {
00038 ASSERT(labels);
00039
00040 train_labels=labels->get_int_labels(num_train_labels);
00041 ASSERT(train_labels);
00042 ASSERT(num_train_labels>0);
00043
00044 int max_class=train_labels[0];
00045 int min_class=train_labels[0];
00046
00047 int i;
00048 for (i=1; i<num_train_labels; i++)
00049 {
00050 max_class=CMath::max(max_class, train_labels[i]);
00051 min_class=CMath::min(min_class, train_labels[i]);
00052 }
00053
00054 for (i=0; i<num_train_labels; i++)
00055 train_labels[i]-=min_class;
00056
00057 min_label=min_class;
00058 num_classes=max_class-min_class+1;
00059
00060 SG_INFO( "num_classes: %d (%+d to %+d) num_train: %d\n", num_classes, min_class, max_class, num_train_labels);
00061 return true;
00062 }
00063
00064 CLabels* CKNN::classify(CLabels* output)
00065 {
00066 ASSERT(num_classes>0);
00067 ASSERT(distance);
00068 ASSERT(labels);
00069 ASSERT(labels->get_num_labels());
00070
00071 INT num_lab=labels->get_num_labels();
00072 ASSERT(k<=num_lab);
00073
00074
00075 DREAL* dists=new DREAL[num_train_labels];
00076 INT* train_lab=new INT[num_train_labels];
00077
00079 INT* classes=new INT[num_classes];
00080 if (!output)
00081 output=new CLabels(num_lab);
00082
00083 ASSERT(dists);
00084 ASSERT(train_lab);
00085 ASSERT(output);
00086 ASSERT(classes);
00087
00088 SG_INFO( "%d test examples\n", num_lab);
00089 for (INT i=0; i<num_lab; i++)
00090 {
00091 if ((i%(num_lab/10+1))== 0)
00092 SG_PROGRESS(i, 0, num_lab);
00093
00094 INT j;
00095 for (j=0; j<num_train_labels; j++)
00096 {
00097
00098 train_lab[j]=train_labels[j];
00099
00100 dists[j]=distance->distance(j,i);
00101 }
00102
00103
00104
00105 CMath::qsort_index(dists, train_lab, num_train_labels);
00106
00107
00108 for (j=0; j<num_classes; j++)
00109 classes[j]=0;
00110
00111 for (j=0; j<k; j++)
00112 classes[train_lab[j]]++;
00113
00114
00115 INT out_idx=0;
00116 INT out_max=0;
00117
00118 for (j=0; j<num_classes; j++)
00119 {
00120 if (out_max< classes[j])
00121 {
00122 out_idx= j;
00123 out_max= classes[j];
00124 }
00125 }
00126
00127 output->set_label(i, out_idx+min_label);
00128 }
00129
00130 delete[] dists;
00131 delete[] train_lab;
00132 delete[] classes;
00133
00134 return output;
00135 }
00136
00137 bool CKNN::load(FILE* srcfile)
00138 {
00139 return false;
00140 }
00141
00142 bool CKNN::save(FILE* dstfile)
00143 {
00144 return false;
00145 }