PluginEstimate.h
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013
00014 #include "classifier/Classifier.h"
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/hmm/LinearHMM.h"
00018
00020 class CPluginEstimate: public CClassifier
00021 {
00022 public:
00027 CPluginEstimate(DREAL pos_pseudo=1e-10, DREAL neg_pseudo=1e-10);
00028 virtual ~CPluginEstimate();
00029
00034 bool train();
00035
00037 CLabels* classify(CLabels* output=NULL);
00038
00043 virtual inline void set_features(CStringFeatures<WORD>* feat)
00044 {
00045 SG_UNREF(features);
00046 SG_REF(feat);
00047 features=feat;
00048 }
00049
00054 virtual CStringFeatures<WORD>* get_features() { SG_REF(features); return features; }
00055
00057 DREAL classify_example(INT vec_idx);
00058
00065 inline DREAL posterior_log_odds_obsolete(WORD* vector, INT len)
00066 {
00067 return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00068 }
00069
00076 inline DREAL get_parameterwise_log_odds(WORD obs, INT position)
00077 {
00078 return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00079 }
00080
00087 inline DREAL log_derivative_pos_obsolete(WORD obs, INT pos)
00088 {
00089 return pos_model->get_log_derivative_obsolete(obs, pos);
00090 }
00091
00098 inline DREAL log_derivative_neg_obsolete(WORD obs, INT pos)
00099 {
00100 return neg_model->get_log_derivative_obsolete(obs, pos);
00101 }
00102
00111 inline bool get_model_params(DREAL*& pos_params, DREAL*& neg_params, INT &seq_length, INT &num_symbols)
00112 {
00113 INT num;
00114
00115 if ((!pos_model) || (!neg_model))
00116 {
00117 SG_ERROR( "no model available\n");
00118 return false;
00119 }
00120
00121 pos_model->get_log_transition_probs(&pos_params, &num);
00122 neg_model->get_log_transition_probs(&neg_params, &num);
00123
00124 seq_length = pos_model->get_sequence_length();
00125 num_symbols = pos_model->get_num_symbols();
00126 ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00127 ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00128 return true;
00129 }
00130
00137 inline void set_model_params(const DREAL* pos_params, const DREAL* neg_params, INT seq_length, INT num_symbols)
00138 {
00139 INT num_params;
00140
00141 delete pos_model;
00142 pos_model=new CLinearHMM(seq_length, num_symbols);
00143 delete neg_model;
00144 neg_model=new CLinearHMM(seq_length, num_symbols);
00145
00146 num_params=pos_model->get_num_model_parameters();
00147 ASSERT(seq_length*num_symbols==num_params);
00148 ASSERT(num_params==neg_model->get_num_model_parameters());
00149
00150 pos_model->set_log_transition_probs(pos_params, num_params);
00151 neg_model->set_log_transition_probs(neg_params, num_params);
00152 }
00153
00158 inline INT get_num_params()
00159 {
00160 return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00161 }
00162
00167 inline bool check_models()
00168 {
00169 return ( (pos_model!=NULL) && (neg_model!=NULL) );
00170 }
00171
00172 protected:
00174 DREAL m_pos_pseudo;
00176 DREAL m_neg_pseudo;
00177
00179 CLinearHMM* pos_model;
00181 CLinearHMM* neg_model;
00182
00184 CStringFeatures<WORD>* features;
00185 };
00186 #endif