Package mvpa :: Package clfs :: Package sg :: Module sens
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.clfs.sg.sens

 1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
 2  #ex: set sts=4 ts=4 sw=4 et: 
 3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 4  # 
 5  #   See COPYING file distributed along with the PyMVPA package for the 
 6  #   copyright and license terms. 
 7  # 
 8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 9  """Provide sensitivity measures for sg's SVM.""" 
10   
11  __docformat__ = 'restructuredtext' 
12   
13  import numpy as N 
14   
15  import shogun.Classifier 
16   
17  from mvpa.misc.state import StateVariable 
18  from mvpa.measures.base import Sensitivity 
19   
20  if __debug__: 
21      from mvpa.base import debug 
22   
23   
class LinearSVMWeights(Sensitivity):
    """`Sensitivity` that reports the weights of a linear SVM trained
    on a given `Dataset`.

    For a single SVM the result is a 1D array with one weight per feature.
    For a shogun `MultiClassSVM` the result is a 2D array with one row per
    internal sub-SVM, in the order given by ``get_num_svms()``/``get_svm(i)``.
    The corresponding hyperplane offsets are stored in the `biases` state.
    """

    biases = StateVariable(enabled=True,
                           doc="Offsets of separating hyperplanes")

    def __init__(self, clf, **kwargs):
        """Initialize the analyzer with the classifier it shall use.

        :Parameters:
          clf: LinearSVM
            classifier to use. Only classifiers sub-classed from
            `LinearSVM` may be used.
        """
        # init base classes first
        Sensitivity.__init__(self, clf, **kwargs)

    def __sg_helper(self, svm):
        """Compute weights and bias for a single trained shogun SVM.

        The weights are the alpha-weighted combination of the support
        vectors, taken from the classifier's training samples.

        :Returns:
          tuple (weights, bias), where `weights` is a flat array with one
          entry per feature and `bias` is whatever ``svm.get_bias()`` returns.
        """
        bias = svm.get_bias()
        # SV coefficients as a row matrix so that `*` is a matrix product
        svcoef = N.matrix(svm.get_alphas())
        # indices of the support vectors within the training samples
        svnums = svm.get_support_vectors()
        svs = self.clf.traindataset.samples[svnums, :]
        # assuming get_alphas() yields a 1D vector, svcoef is (1 x nSV) and the
        # product is (1 x nfeatures); the mean over axis 0 then leaves that
        # single row unchanged -- TODO confirm. .A1 flattens to a 1D array.
        weights = (svcoef * svs).mean(axis=0).A1
        return weights, bias

    def _call(self, dataset):
        """Return per-feature weights of the classifier's trained SVM(s).

        Note that `dataset` is not used: weights are reconstructed from the
        classifier's own training dataset (``self.clf.traindataset``).

        Side effect: stores the hyperplane offset(s) in the `biases` state
        variable (a scalar for a single SVM, an array with one entry per
        sub-SVM for a `MultiClassSVM`).
        """
        # XXX Hm... it might make sense to unify access functions
        # naming across our swig libsvm wrapper and sg access
        # functions for svm
        svm = self.clf.svm
        if isinstance(svm, shogun.Classifier.MultiClassSVM):
            results = [self.__sg_helper(svm.get_svm(i))
                       for i in range(svm.get_num_svms())]
            sens = [weights for weights, _ in results]
            biases = N.asarray([bias for _, bias in results])
        else:
            sens, biases = self.__sg_helper(svm)

        # Populate the declared state variable; previously only the stray
        # `offsets` attribute was set, and it kept just the last sub-SVM's bias.
        self.biases = biases
        # Backward-compatible alias for code that read the old attribute name.
        self.offsets = biases
        return N.asarray(sens)
66