LinearHMM.h
Go to the documentation of this file.
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _LINEARHMM_H__
00013 #define _LINEARHMM_H__
00014
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/Distribution.h"
00018
00036 class CLinearHMM : public CDistribution
00037 {
00038 public:
00043 CLinearHMM(CStringFeatures<uint16_t>* f);
00044
00050 CLinearHMM(int32_t p_num_features, int32_t p_num_symbols);
00051 ~CLinearHMM();
00052
00057 bool train();
00058
00066 bool train(
00067 const int32_t* indizes, int32_t num_indizes,
00068 float64_t pseudo_count);
00069
00076 float64_t get_log_likelihood_example(uint16_t* vector, int32_t len);
00077
00084 float64_t get_likelihood_example(uint16_t* vector, int32_t len);
00085
00091 virtual float64_t get_log_likelihood_example(int32_t num_example);
00092
00099 virtual float64_t get_log_derivative(
00100 int32_t num_param, int32_t num_example);
00101
00108 virtual inline float64_t get_log_derivative_obsolete(
00109 uint16_t obs, int32_t pos)
00110 {
00111 return 1.0/transition_probs[pos*num_symbols+obs];
00112 }
00113
00120 virtual inline float64_t get_derivative_obsolete(
00121 uint16_t* vector, int32_t len, int32_t pos)
00122 {
00123 ASSERT(pos<len);
00124 return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]];
00125 }
00126
00131 virtual inline int32_t get_sequence_length() { return sequence_length; }
00132
00137 virtual inline int32_t get_num_symbols() { return num_symbols; }
00138
00143 virtual inline int32_t get_num_model_parameters() { return num_params; }
00144
00151 virtual inline float64_t get_positional_log_parameter(
00152 uint16_t obs, int32_t position)
00153 {
00154 return log_transition_probs[position*num_symbols+obs];
00155 }
00156
00162 virtual inline float64_t get_log_model_parameter(int32_t num_param)
00163 {
00164 ASSERT(log_transition_probs);
00165 ASSERT(num_param<num_params);
00166
00167 return log_transition_probs[num_param];
00168 }
00169
00177 virtual void get_log_transition_probs(float64_t** dst, int32_t* num);
00178
00185 virtual bool set_log_transition_probs(
00186 const float64_t* src, int32_t num);
00187
00193 virtual void get_transition_probs(float64_t** dst, int32_t* num);
00194
00201 virtual bool set_transition_probs(const float64_t* src, int32_t num);
00202
00203 protected:
00205 int32_t sequence_length;
00207 int32_t num_symbols;
00209 int32_t num_params;
00211 float64_t* transition_probs;
00213 float64_t* log_transition_probs;
00214 };
00215 #endif