LinearHMM.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #ifndef _LINEARHMM_H__
00013 #define _LINEARHMM_H__
00014 
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/Distribution.h"
00018 
00036 class CLinearHMM : public CDistribution
00037 {
00038     public:
00043         CLinearHMM(CStringFeatures<uint16_t>* f);
00044 
00050         CLinearHMM(int32_t p_num_features, int32_t p_num_symbols);
00051         ~CLinearHMM();
00052 
00057         bool train();
00058 
00066         bool train(
00067             const int32_t* indizes, int32_t num_indizes,
00068             float64_t pseudo_count);
00069 
00076         float64_t get_log_likelihood_example(uint16_t* vector, int32_t len);
00077 
00084         float64_t get_likelihood_example(uint16_t* vector, int32_t len);
00085 
00091         virtual float64_t get_log_likelihood_example(int32_t num_example);
00092 
00099         virtual float64_t get_log_derivative(
00100             int32_t num_param, int32_t num_example);
00101 
00108         virtual inline float64_t get_log_derivative_obsolete(
00109             uint16_t obs, int32_t pos)
00110         {
00111             return 1.0/transition_probs[pos*num_symbols+obs];
00112         }
00113 
00120         virtual inline float64_t get_derivative_obsolete(
00121             uint16_t* vector, int32_t len, int32_t pos)
00122         {
00123             ASSERT(pos<len);
00124             return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]];
00125         }
00126 
00131         virtual inline int32_t get_sequence_length() { return sequence_length; }
00132 
00137         virtual inline int32_t get_num_symbols() { return num_symbols; }
00138 
00143         virtual inline int32_t get_num_model_parameters() { return num_params; }
00144 
00151         virtual inline float64_t get_positional_log_parameter(
00152             uint16_t obs, int32_t position)
00153         {
00154             return log_transition_probs[position*num_symbols+obs];
00155         }
00156 
00162         virtual inline float64_t get_log_model_parameter(int32_t num_param)
00163         {
00164             ASSERT(log_transition_probs);
00165             ASSERT(num_param<num_params);
00166 
00167             return log_transition_probs[num_param];
00168         }
00169 
00177         virtual void get_log_transition_probs(float64_t** dst, int32_t* num);
00178 
00185         virtual bool set_log_transition_probs(
00186             const float64_t* src, int32_t num);
00187 
00193         virtual void get_transition_probs(float64_t** dst, int32_t* num);
00194 
00201         virtual bool set_transition_probs(const float64_t* src, int32_t num);
00202 
00203     protected:
00205         int32_t sequence_length;
00207         int32_t num_symbols;
00209         int32_t num_params;
00211         float64_t* transition_probs;
00213         float64_t* log_transition_probs;
00214 };
00215 #endif

SHOGUN Machine Learning Toolbox - Documentation