LinearClassifier.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _LINEARCLASSIFIER_H__
00012 #define _LINEARCLASSIFIER_H__
00013 
00014 #include "lib/common.h"
00015 #include "features/Labels.h"
00016 #include "features/RealFeatures.h"
00017 #include "classifier/Classifier.h"
00018 
00019 #include <stdio.h>
00020 
00022 class CLinearClassifier : public CClassifier
00023 {
00024     public:
00026         CLinearClassifier();
00027         virtual ~CLinearClassifier();
00028 
00030         virtual inline float64_t classify_example(int32_t vec_idx)
00031         {
00032             int32_t vlen;
00033             bool vfree;
00034             float64_t* vec=features->get_feature_vector(vec_idx, vlen, vfree);
00035             float64_t result=CMath::dot(w,vec,vlen);
00036             features->free_feature_vector(vec, vec_idx, vfree);
00037 
00038             return result+bias;
00039         }
00040 
00046         inline void get_w(float64_t** dst_w, int32_t* dst_dims)
00047         {
00048             ASSERT(dst_w && dst_dims);
00049             ASSERT(w && features);
00050             *dst_dims=features->get_num_features();
00051             *dst_w=(float64_t*) malloc(sizeof(float64_t)*(*dst_dims));
00052             ASSERT(*dst_w);
00053             memcpy(*dst_w, w, sizeof(float64_t) * (*dst_dims));
00054         }
00055 
00061         inline void set_w(float64_t* src_w, int32_t src_w_dim)
00062         {
00063             w=src_w;
00064             w_dim=src_w_dim;
00065         }
00066 
00071         inline void set_bias(float64_t b)
00072         {
00073             bias=b;
00074         }
00075 
00080         inline float64_t get_bias()
00081         {
00082             return bias;
00083         }
00084 
00090         virtual bool load(FILE* srcfile);
00091 
00097         virtual bool save(FILE* dstfile);
00098 
00104         virtual CLabels* classify(CLabels* output=NULL);
00105 
00110         virtual inline void set_features(CRealFeatures* feat)
00111         {
00112             SG_UNREF(features);
00113             SG_REF(feat);
00114             features=feat;
00115         }
00116 
00121         virtual CRealFeatures* get_features() { SG_REF(features); return features; }
00122 
00123     protected:
00125         int32_t w_dim;
00127         float64_t* w;
00129         float64_t bias;
00131         CRealFeatures* features;
00132 };
00133 #endif

SHOGUN Machine Learning Toolbox - Documentation