00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2006 Christian Gehl 00008 * Written (W) 1999-2008 Soeren Sonnenburg 00009 * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _KNN_H__ 00013 #define _KNN_H__ 00014 00015 #include <stdio.h> 00016 #include "lib/common.h" 00017 #include "lib/io.h" 00018 #include "features/Features.h" 00019 #include "distance/Distance.h" 00020 #include "distance/DistanceMachine.h" 00021 00022 class CDistanceMachine; 00023 00025 class CKNN : public CDistanceMachine 00026 { 00027 public: 00029 CKNN(); 00030 00037 CKNN(int32_t k, CDistance* d, CLabels* trainlab); 00038 virtual ~CKNN(); 00039 00044 virtual inline EClassifierType get_classifier_type() { return CT_KNN; } 00045 //inline EDistanceType get_distance_type() { return DT_KNN;} 00046 00051 virtual bool train(); 00052 00058 virtual CLabels* classify(CLabels* output=NULL); 00059 00061 virtual float64_t classify_example(int32_t vec_idx) 00062 { 00063 SG_ERROR( "for performance reasons use classify() instead of classify_example\n"); 00064 return 0; 00065 } 00066 00072 virtual bool load(FILE* srcfile); 00073 00079 virtual bool save(FILE* dstfile); 00080 00085 inline void set_k(float64_t p_k) 00086 { 00087 ASSERT(p_k>0); 00088 this->k=p_k; 00089 } 00090 00095 inline float64_t get_k() 00096 { 00097 return k; 00098 } 00099 00100 protected: 00102 float64_t k; 00103 00105 int32_t num_classes; 00106 00108 int32_t min_label; 00109 00111 int32_t num_train_labels; 00112 00114 int32_t* train_labels; 00115 }; 00116 #endif 00117