WeightedDegreePositionStringKernel.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #ifndef _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___
00013 #define _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___
00014 
00015 #include "lib/common.h"
00016 #include "kernel/StringKernel.h"
00017 #include "kernel/WeightedDegreeStringKernel.h"
00018 #include "lib/Trie.h"
00019 
00020 class CSVM ;
00021 
00043 class CWeightedDegreePositionStringKernel: public CStringKernel<char>
00044 {
00045     public:
00053         CWeightedDegreePositionStringKernel(
00054             int32_t size, int32_t degree,
00055             int32_t max_mismatch=0, int32_t mkl_stepsize=1);
00056 
00067         CWeightedDegreePositionStringKernel(
00068             int32_t size, float64_t* weights, int32_t degree,
00069             int32_t max_mismatch, int32_t* shift, int32_t shift_len,
00070             int32_t mkl_stepsize=1);
00071 
00078         CWeightedDegreePositionStringKernel(
00079             CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t degree);
00080 
00081         virtual ~CWeightedDegreePositionStringKernel();
00082 
00089         virtual bool init(CFeatures* l, CFeatures* r);
00090 
00092         virtual void cleanup();
00093 
00099         bool load_init(FILE* src);
00100 
00106         bool save_init(FILE* dest);
00107 
00112         virtual EKernelType get_kernel_type() { return K_WEIGHTEDDEGREEPOS; }
00113 
00118         virtual const char* get_name() { return "WeightedDegreePos"; }
00119 
00127         inline virtual bool init_optimization(
00128             int32_t p_count, int32_t *IDX, float64_t * alphas)
00129         { 
00130             return init_optimization(p_count, IDX, alphas, -1);
00131         }
00132 
00144         virtual bool init_optimization(
00145             int32_t count, int32_t *IDX, float64_t * alphas, int32_t tree_num,
00146             int32_t upto_tree=-1);
00147 
00152         virtual bool delete_optimization();
00153 
00159         inline virtual float64_t compute_optimized(int32_t idx)
00160         { 
00161             ASSERT(get_is_initialized());
00162             ASSERT(alphabet);
00163             ASSERT(alphabet->get_alphabet()==DNA || alphabet->get_alphabet()==RNA);
00164             return compute_by_tree(idx);
00165         }
00166 
00171         static void* compute_batch_helper(void* p);
00172 
00183         virtual void compute_batch(
00184             int32_t num_vec, int32_t* vec_idx, float64_t* target,
00185             int32_t num_suppvec, int32_t* IDX, float64_t* alphas,
00186             float64_t factor=1.0);
00187 
00191         inline virtual void clear_normal()
00192         {
00193             if ((opt_type==FASTBUTMEMHUNGRY) && (tries.get_use_compact_terminal_nodes()))
00194             {
00195                 tries.set_use_compact_terminal_nodes(false) ;
00196                 SG_DEBUG( "disabling compact trie nodes with FASTBUTMEMHUNGRY\n") ;
00197             }
00198 
00199             if (get_is_initialized())
00200             {
00201                 if (opt_type==SLOWBUTMEMEFFICIENT)
00202                     tries.delete_trees(true); 
00203                 else if (opt_type==FASTBUTMEMHUNGRY)
00204                     tries.delete_trees(false);  // still buggy
00205                 else
00206                     SG_ERROR( "unknown optimization type\n");
00207 
00208                 set_is_initialized(false);
00209             }
00210         }
00211 
00217         inline virtual void add_to_normal(int32_t idx, float64_t weight)
00218         {
00219             add_example_to_tree(idx, weight);
00220             set_is_initialized(true);
00221         }
00222 
00227         inline virtual int32_t get_num_subkernels()
00228         {
00229             if (position_weights!=NULL)
00230                 return (int32_t) ceil(1.0*seq_length/mkl_stepsize) ;
00231             if (length==0)
00232                 return (int32_t) ceil(1.0*get_degree()/mkl_stepsize);
00233             return (int32_t) ceil(1.0*get_degree()*length/mkl_stepsize) ;
00234         }
00235 
00241         inline void compute_by_subkernel(
00242             int32_t idx, float64_t * subkernel_contrib)
00243         { 
00244             if (get_is_initialized())
00245             {
00246                 compute_by_tree(idx, subkernel_contrib);
00247                 return ;
00248             }
00249 
00250             SG_ERROR( "CWeightedDegreePositionStringKernel optimization not initialized\n") ;
00251         }
00252 
00258         inline const float64_t* get_subkernel_weights(int32_t& num_weights)
00259         {
00260             num_weights = get_num_subkernels() ;
00261 
00262             delete[] weights_buffer ;
00263             weights_buffer = new float64_t[num_weights] ;
00264 
00265             if (position_weights!=NULL)
00266                 for (int32_t i=0; i<num_weights; i++)
00267                     weights_buffer[i] = position_weights[i*mkl_stepsize] ;
00268             else
00269                 for (int32_t i=0; i<num_weights; i++)
00270                     weights_buffer[i] = weights[i*mkl_stepsize] ;
00271 
00272             return weights_buffer ;
00273         }
00274 
00280         inline void set_subkernel_weights(
00281             float64_t* weights2, int32_t num_weights2)
00282         {
00283             int32_t num_weights = get_num_subkernels() ;
00284             if (num_weights!=num_weights2)
00285                 SG_ERROR( "number of weights do not match\n") ;
00286 
00287             if (position_weights!=NULL)
00288                 for (int32_t i=0; i<num_weights; i++)
00289                     for (int32_t j=0; j<mkl_stepsize; j++)
00290                     {
00291                         if (i*mkl_stepsize+j<seq_length)
00292                             position_weights[i*mkl_stepsize+j] = weights2[i] ;
00293                     }
00294             else if (length==0)
00295             {
00296                 for (int32_t i=0; i<num_weights; i++)
00297                     for (int32_t j=0; j<mkl_stepsize; j++)
00298                         if (i*mkl_stepsize+j<get_degree())
00299                             weights[i*mkl_stepsize+j] = weights2[i] ;
00300             }
00301             else
00302             {
00303                 for (int32_t i=0; i<num_weights; i++)
00304                     for (int32_t j=0; j<mkl_stepsize; j++)
00305                         if (i*mkl_stepsize+j<get_degree()*length)
00306                             weights[i*mkl_stepsize+j] = weights2[i] ;
00307             }
00308         }
00309 
00310         // other kernel tree operations
00316         float64_t* compute_abs_weights(int32_t & len);
00317 
00322         bool is_tree_initialized() { return tree_initialized; }
00323 
00328         inline int32_t get_max_mismatch() { return max_mismatch; }
00329 
00334         inline int32_t get_degree() { return degree; }
00335 
00341         inline float64_t *get_degree_weights(int32_t& d, int32_t& len)
00342         {
00343             d=degree;
00344             len=length;
00345             return weights;
00346         }
00347 
00353         inline float64_t *get_weights(int32_t& num_weights)
00354         {
00355             if (position_weights!=NULL)
00356             {
00357                 num_weights = seq_length ;
00358                 return position_weights ;
00359             }
00360             if (length==0)
00361                 num_weights = degree ;
00362             else
00363                 num_weights = degree*length ;
00364             return weights;
00365         }
00366 
00372         inline float64_t *get_position_weights(int32_t& len)
00373         {
00374             len=seq_length;
00375             return position_weights;
00376         }
00377 
00383         bool set_shifts(int32_t* shifts, int32_t len);
00384 
00391         virtual bool set_weights(float64_t* weights, int32_t d, int32_t len=0);
00392 
00397         virtual bool set_wd_weights();
00398 
00405         virtual bool set_position_weights(
00406             float64_t* position_weights, int32_t len=0);
00407 
00415         bool set_position_weights_lhs(float64_t* pws, int32_t len, int32_t num);
00416 
00424         bool set_position_weights_rhs(float64_t* pws, int32_t len, int32_t num);
00425 
00430         bool init_block_weights();
00431 
00436         bool init_block_weights_from_wd();
00437 
00442         bool init_block_weights_from_wd_external();
00443 
00448         bool init_block_weights_const();
00449 
00454         bool init_block_weights_linear();
00455 
00460         bool init_block_weights_sqpoly();
00461 
00466         bool init_block_weights_cubicpoly();
00467 
00472         bool init_block_weights_exp();
00473 
00478         bool init_block_weights_log();
00479 
00484         bool init_block_weights_external();
00485 
00490         bool delete_position_weights()
00491         {
00492             delete[] position_weights;
00493             position_weights=NULL;
00494             return true;
00495         }
00496 
00501         bool delete_position_weights_lhs()
00502         {
00503             delete[] position_weights_lhs;
00504             position_weights_lhs=NULL;
00505             return true;
00506         }
00507 
00512         bool delete_position_weights_rhs()
00513         {
00514             delete[] position_weights_rhs;
00515             position_weights_rhs=NULL;
00516             return true;
00517         }
00518 
00524         virtual float64_t compute_by_tree(int32_t idx);
00525 
00531         virtual void compute_by_tree(int32_t idx, float64_t* LevelContrib);
00532 
00545         float64_t* compute_scoring(
00546             int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00547             float64_t* target, int32_t num_suppvec, int32_t* IDX,
00548             float64_t* weights);
00549 
00558         char* compute_consensus(
00559             int32_t &num_feat, int32_t num_suppvec, int32_t* IDX,
00560             float64_t* alphas);
00561 
00573         float64_t* extract_w(
00574             int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00575             float64_t* w_result, int32_t num_suppvec, int32_t* IDX,
00576             float64_t* alphas);
00577 
00590         float64_t* compute_POIM(
00591             int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00592             float64_t* poim_result, int32_t num_suppvec, int32_t* IDX,
00593             float64_t* alphas, float64_t* distrib);
00594 
00601         void prepare_POIM2(
00602             float64_t* distrib, int32_t num_sym, int32_t num_feat);
00603 
00610         void compute_POIM2(int32_t max_degree, CSVM* svm);
00611 
00617         void get_POIM2(float64_t** poim, int32_t* result_len);
00618 
00620         void cleanup_POIM2();
00621         
00622     protected:
00624         void create_empty_tries();
00625 
00631         virtual void add_example_to_tree(
00632             int32_t idx, float64_t weight);
00633 
00640         void add_example_to_single_tree(
00641             int32_t idx, float64_t weight, int32_t tree_num);
00642 
00651         virtual float64_t compute(int32_t idx_a, int32_t idx_b);
00652 
00661         float64_t compute_with_mismatch(
00662             char* avec, int32_t alen, char* bvec, int32_t blen);
00663 
00672         float64_t compute_without_mismatch(
00673             char* avec, int32_t alen, char* bvec, int32_t blen);
00674 
00683         float64_t compute_without_mismatch_matrix(
00684             char* avec, int32_t alen, char* bvec, int32_t blen);
00685 
00696         float64_t compute_without_mismatch_position_weights(
00697             char* avec, float64_t *posweights_lhs, int32_t alen,
00698             char* bvec, float64_t *posweights_rhs, int32_t blen);
00699 
00701         virtual void remove_lhs();
00702 
00703     protected:
00705         float64_t* weights;
00707         float64_t* position_weights;
00709         float64_t* position_weights_lhs;
00711         float64_t* position_weights_rhs;
00713         bool* position_mask;
00714 
00716         float64_t* weights_buffer;
00718         int32_t mkl_stepsize;
00719 
00721         int32_t degree;
00723         int32_t length;
00724 
00726         int32_t max_mismatch;
00728         int32_t seq_length;
00729 
00731         int32_t *shift;
00733         int32_t shift_len;
00735         int32_t max_shift;
00736 
00738         bool block_computation;
00739 
00741         int32_t num_block_weights_external;
00743         float64_t* block_weights_external;
00744 
00746         float64_t* block_weights;
00748         EWDKernType type;
00750         int32_t which_degree;
00751 
00753         CTrie<DNATrie> tries;
00755         CTrie<POIMTrie> poim_tries;
00756 
00758         bool tree_initialized;
00760         bool use_poim_tries;
00761 
00763         float64_t* m_poim_distrib;
00765         float64_t* m_poim;
00766 
00768         int32_t m_poim_num_sym;
00770         int32_t m_poim_num_feat;
00772         int32_t m_poim_result_len;
00773 
00775         CAlphabet* alphabet;
00776 };
00777 #endif /* _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H__ */

SHOGUN Machine Learning Toolbox - Documentation