KNN.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  *
00008  * Written (W) 2006 Christian Gehl
00009  * Written (W) 2006-2008 Soeren Sonnenburg
00010  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00011  */
00012 
00013 #include "classifier/KNN.h"
00014 #include "features/Labels.h"
00015 #include "lib/Mathematics.h"
00016 
00017 CKNN::CKNN()
00018 : CDistanceMachine(), k(3), num_classes(0), num_train_labels(0), train_labels(NULL)
00019 {
00020 }
00021 
00022 CKNN::CKNN(int32_t k_, CDistance* d, CLabels* trainlab)
00023 : CDistanceMachine(), k(k_), num_classes(0), train_labels(NULL)
00024 {
00025     set_distance(d);
00026     set_labels(trainlab);
00027     num_train_labels=trainlab->get_num_labels();
00028 }
00029 
00030 
00031 CKNN::~CKNN()
00032 {
00033     delete[] train_labels;
00034 }
00035 
00036 bool CKNN::train()
00037 {
00038     ASSERT(labels);
00039 
00040     train_labels=labels->get_int_labels(num_train_labels);
00041     ASSERT(train_labels);
00042     ASSERT(num_train_labels>0);
00043 
00044     int32_t max_class=train_labels[0];
00045     int32_t min_class=train_labels[0];
00046 
00047     int32_t i;
00048     for (i=1; i<num_train_labels; i++)
00049     {
00050         max_class=CMath::max(max_class, train_labels[i]);
00051         min_class=CMath::min(min_class, train_labels[i]);
00052     }
00053 
00054     for (i=0; i<num_train_labels; i++)
00055         train_labels[i]-=min_class;
00056 
00057     min_label=min_class;
00058     num_classes=max_class-min_class+1;
00059 
00060     SG_INFO( "num_classes: %d (%+d to %+d) num_train: %d\n", num_classes, min_class, max_class, num_train_labels);
00061     return true;
00062 }
00063 
00064 CLabels* CKNN::classify(CLabels* output)
00065 {
00066     ASSERT(num_classes>0);
00067     ASSERT(distance);
00068     ASSERT(labels);
00069     ASSERT(labels->get_num_labels());
00070 
00071     int32_t num_lab=labels->get_num_labels();
00072     ASSERT(k<=num_lab);
00073 
00074     //distances to train data and working buffer of train_labels
00075     float64_t* dists=new float64_t[num_train_labels];
00076     int32_t* train_lab=new int32_t[num_train_labels];
00077 
00079     int32_t* classes=new int32_t[num_classes];
00080     if (!output)
00081         output=new CLabels(num_lab);
00082 
00083     ASSERT(dists);
00084     ASSERT(train_lab);
00085     ASSERT(output);
00086     ASSERT(classes);
00087 
00088     SG_INFO( "%d test examples\n", num_lab);
00089     for (int32_t i=0; i<num_lab; i++)
00090     {
00091         if ((i%(num_lab/10+1))== 0)
00092             SG_PROGRESS(i, 0, num_lab);
00093 
00094         int32_t j;
00095         for (j=0; j<num_train_labels; j++)
00096         {
00097             //copy back train labels and compute distance
00098             train_lab[j]=train_labels[j];
00099             
00100             dists[j]=distance->distance(j,i);
00101         }
00102 
00103         //sort the distance vector for test example j to all train examples
00104         //classes[1..k] then holds the classes for minimum distance
00105         CMath::qsort_index(dists, train_lab, num_train_labels);
00106 
00107         //compute histogram of class outputs of the first k nearest neighbours
00108         for (j=0; j<num_classes; j++)
00109             classes[j]=0;
00110 
00111         for (j=0; j<k; j++)
00112             classes[train_lab[j]]++;
00113 
00114         //choose the class that got 'outputted' most often
00115         int32_t out_idx=0;
00116         int32_t out_max=0;
00117 
00118         for (j=0; j<num_classes; j++)
00119         {
00120             if (out_max< classes[j])
00121             {
00122                 out_idx= j;
00123                 out_max= classes[j];
00124             }
00125         }
00126 
00127         output->set_label(i, out_idx+min_label);
00128     }
00129 
00130     delete[] dists;
00131     delete[] train_lab;
00132     delete[] classes;
00133 
00134     return output;
00135 }
00136 
00137 bool CKNN::load(FILE* srcfile)
00138 {
00139     return false;
00140 }
00141 
00142 bool CKNN::save(FILE* dstfile)
00143 {
00144     return false;
00145 }

SHOGUN Machine Learning Toolbox - Documentation