LinearClassifier.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _LINEARCLASSIFIER_H__
00012 #define _LINEARCLASSIFIER_H__
00013
00014 #include "lib/common.h"
00015 #include "features/Labels.h"
00016 #include "features/RealFeatures.h"
00017 #include "classifier/Classifier.h"
00018
00019 #include <stdio.h>
00020
00022 class CLinearClassifier : public CClassifier
00023 {
00024 public:
00026 CLinearClassifier();
00027 virtual ~CLinearClassifier();
00028
00030 virtual inline float64_t classify_example(int32_t vec_idx)
00031 {
00032 int32_t vlen;
00033 bool vfree;
00034 float64_t* vec=features->get_feature_vector(vec_idx, vlen, vfree);
00035 float64_t result=CMath::dot(w,vec,vlen);
00036 features->free_feature_vector(vec, vec_idx, vfree);
00037
00038 return result+bias;
00039 }
00040
00046 inline void get_w(float64_t** dst_w, int32_t* dst_dims)
00047 {
00048 ASSERT(dst_w && dst_dims);
00049 ASSERT(w && features);
00050 *dst_dims=features->get_num_features();
00051 *dst_w=(float64_t*) malloc(sizeof(float64_t)*(*dst_dims));
00052 ASSERT(*dst_w);
00053 memcpy(*dst_w, w, sizeof(float64_t) * (*dst_dims));
00054 }
00055
00061 inline void set_w(float64_t* src_w, int32_t src_w_dim)
00062 {
00063 w=src_w;
00064 w_dim=src_w_dim;
00065 }
00066
00071 inline void set_bias(float64_t b)
00072 {
00073 bias=b;
00074 }
00075
00080 inline float64_t get_bias()
00081 {
00082 return bias;
00083 }
00084
00090 virtual bool load(FILE* srcfile);
00091
00097 virtual bool save(FILE* dstfile);
00098
00104 virtual CLabels* classify(CLabels* output=NULL);
00105
00110 virtual inline void set_features(CRealFeatures* feat)
00111 {
00112 SG_UNREF(features);
00113 SG_REF(feat);
00114 features=feat;
00115 }
00116
00121 virtual CRealFeatures* get_features() { SG_REF(features); return features; }
00122
00123 protected:
00125 int32_t w_dim;
00127 float64_t* w;
00129 float64_t bias;
00131 CRealFeatures* features;
00132 };
00133 #endif