KMeans.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _KMEANS_H__
00013 #define _KMEANS_H__
00014
00015 #include <stdio.h>
00016 #include "lib/common.h"
00017 #include "lib/io.h"
00018 #include "features/RealFeatures.h"
00019 #include "distance/Distance.h"
00020 #include "distance/DistanceMachine.h"
00021
00022 class CDistanceMachine;
00023
00036 class CKMeans : public CDistanceMachine
00037 {
00038 public:
00040 CKMeans();
00041
00047 CKMeans(int32_t k, CDistance* d);
00048 virtual ~CKMeans();
00049
00054 virtual inline EClassifierType get_classifier_type() { return CT_KMEANS; }
00055
00060 virtual bool train();
00061
00067 virtual bool load(FILE* srcfile);
00068
00074 virtual bool save(FILE* dstfile);
00075
00080 inline void set_k(int32_t p_k)
00081 {
00082 ASSERT(p_k>0);
00083 this->k=p_k;
00084 }
00085
00090 inline int32_t get_k()
00091 {
00092 return k;
00093 }
00094
00099 inline void set_max_iter(int32_t iter)
00100 {
00101 ASSERT(iter>0);
00102 max_iter=iter;
00103 }
00104
00109 inline float64_t get_max_iter()
00110 {
00111 return max_iter;
00112 }
00113
00119 inline void get_radi(float64_t*& radi, int32_t& num)
00120 {
00121 radi=R;
00122 num=k;
00123 }
00124
00131 inline void get_centers(float64_t*& centers, int32_t& dim, int32_t& num)
00132 {
00133 centers=mus;
00134 dim=dimensions;
00135 num=k;
00136 }
00137
00143 inline void get_radiuses(float64_t** radii, int32_t* num)
00144 {
00145 size_t sz=sizeof(*R)*k;
00146 *radii=(float64_t*) malloc(sz);
00147 ASSERT(*radii);
00148
00149 memcpy(*radii, R, sz);
00150 *num=k;
00151 }
00152
00159 inline void get_cluster_centers(
00160 float64_t** centers, int32_t* dim, int32_t* num)
00161 {
00162 size_t sz=sizeof(*mus)*dimensions*k;
00163 *centers=(float64_t*) malloc(sz);
00164 ASSERT(*centers);
00165
00166 memcpy(*centers, mus, sz);
00167 *dim=dimensions;
00168 *num=k;
00169 }
00170
00175 inline int32_t get_dimensions()
00176 {
00177 return dimensions;
00178 }
00179
00180
00181 protected:
00192 void sqdist(
00193 float64_t* x, CRealFeatures* y, float64_t *z, int32_t n1,
00194 int32_t offs, int32_t n2, int32_t m);
00195
00201 void clustknb(bool use_old_mus, float64_t *mus_start);
00202
00203 protected:
00205 int32_t max_iter;
00206
00208 int32_t k;
00209
00211 int32_t dimensions;
00212
00214 float64_t* R;
00215
00217 float64_t* mus;
00218
00219 private:
00221 float64_t* Weights;
00222 };
00223 #endif
00224