Package mvpa :: Package clfs :: Module knn
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.clfs.knn

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """k-Nearest-Neighbour classifier.""" 
 10   
 11  __docformat__ = 'restructuredtext' 
 12   
 13   
 14  import numpy as N 
 15   
 16  from mvpa.misc import warning 
 17  from mvpa.misc.support import indentDoc 
 18  from mvpa.clfs.base import Classifier 
 19   
 20   
class kNN(Classifier):
    """k-nearest-neighbour classifier.

    If enabled it stores the votes per class in the 'values' state after
    calling predict().
    """

    # class-level flag so the integer-dtype warning below is emitted only
    # once per process, not once per training call
    __warned = False

    _clf_internals = [ 'knn', 'non-linear', 'multiclass' ]

    def __init__(self, k=2, **kwargs):
        """
        :Parameters:
          k
            number of nearest neighbours to be used for voting
        """
        # init base class first
        Classifier.__init__(self, train2predict=False, **kwargs)

        self.__k = k
        # XXX So is the voting function fixed forever?
        self.__votingfx = self.getWeightedVote
        self.__data = None


    def __repr__(self):
        """Representation of the object
        """
        return "kNN(k=%d, enable_states=%s)" % \
            (self.__k, str(self.states.enabled))


    def __str__(self):
        return "%s\n data: %s" % \
            (Classifier.__str__(self), indentDoc(self.__data))


    def _train(self, data):
        """Train the classifier.

        For kNN it is degenerate -- just stores the data.
        """
        self.__data = data
        if __debug__:
            # BUGFIX: the original condition read
            #     not kNN.__warned and <uint-check> or <int-check>
            # which, by operator precedence, warned on every int-typed
            # dataset regardless of the __warned guard.  The dtype checks
            # must be parenthesized together.
            if not kNN.__warned and \
               (str(data.samples.dtype).startswith('uint') \
                or str(data.samples.dtype).startswith('int')):
                kNN.__warned = True
                warning("kNN: input data is in integers. " + \
                        "Overflow on arithmetic operations might result in"+\
                        " errors. Please convert dataset's samples into" +\
                        " floating datatype if any error is reported.")
        # weights are (re)computed lazily by getWeightedVote
        self.__weights = None

        # create dictionary with an item for each condition
        uniquelabels = data.uniquelabels
        self.__votes_init = dict(zip(uniquelabels,
                                     [0] * len(uniquelabels)))


    def _predict(self, data):
        """Predict the class labels for the provided data.

        Returns a list of class labels (one for each data sample).
        """
        # make sure we're talking about arrays
        data = N.asarray(data)

        if not data.ndim == 2:
            raise ValueError("Data array must be two-dimensional.")

        if not data.shape[1] == self.__data.nfeatures:
            raise ValueError("Length of data samples (features) does "
                             "not match the classifier.")

        # predicted class labels will go here
        predicted = []
        votes = []

        # for all test pattern
        for p in data:
            # calc the euclidean distance of the pattern vector to all
            # patterns in the training data
            dists = N.sqrt(
                        N.sum(
                            (self.__data.samples - p )**2, axis=1
                            )
                        )
            # get the k nearest neighbours from the sorted list of distances
            knn = dists.argsort()[:self.__k]

            # finally get the class label
            prediction, vote = self.__votingfx(knn)
            predicted.append(prediction)
            votes.append(vote)

        # store the predictions in the state. Relies on State._setitem to do
        # nothing if the relevant state member is not enabled
        self.predictions = predicted
        self.values = votes

        return predicted


    def getMajorityVote(self, knn_ids):
        """Simple voting by choosing the majority of class neighbours.

        :Parameters:
          knn_ids
            sequence of indices (into the training samples) of the k
            nearest neighbours

        Returns a tuple (winning label, votes per class), with votes
        ordered to match `self.__data.uniquelabels`.
        """
        uniquelabels = self.__data.uniquelabels
        # BUGFIX: use the training labels directly instead of
        # self.__labels, which was only initialized as a side effect of
        # getWeightedVote's lazy setup (AttributeError otherwise)
        labels = self.__data.labels

        # number of occurrences for each unique class in kNNs
        votes = self.__votes_init.copy()
        for nn in knn_ids:
            votes[labels[nn]] += 1

        # BUGFIX: the original called N.asarray(votes).argmax() on the
        # *dict* itself, which yields a 0-d object array whose argmax is
        # always 0, i.e. the first label always won.  Convert to a list
        # ordered like uniquelabels so argmax indexes correctly.
        votes = [ votes[ul] for ul in uniquelabels ]

        # find the class with most votes
        # return votes as well to store them in the state
        return uniquelabels[N.asarray(votes).argmax()], \
               votes


    def getWeightedVote(self, knn_ids):
        """Vote with classes weighted by the number of samples per class.

        :Parameters:
          knn_ids
            sequence of indices (into the training samples) of the k
            nearest neighbours

        Returns a tuple (winning label, weighted votes per class), with
        votes ordered to match `self.__data.uniquelabels`.
        """
        uniquelabels = self.__data.uniquelabels

        # Lazy evaluation
        if self.__weights is None:
            #
            # It seemed to Yarik that this has to be evaluated just once per
            # training dataset.
            #
            self.__labels = self.__data.labels
            Nlabels = len(self.__labels)

            # TODO: To get proper speed up for the next line only,
            #       histogram should be computed
            #       via sorting + counting "same" elements while reducing.
            #       Guaranteed complexity is NlogN whenever now it is N^2
            # compute the relative proportion of samples belonging to each
            # class (do it in one loop to improve speed)
            #
            # BUGFIX: float() is essential here -- under Python 2 the
            # integer division made every weight 1.0 - 0 == 1.0, so the
            # "weighted" vote was silently unweighted.
            self.__weights = \
                [ 1.0 - (float((self.__labels == label).sum()) / Nlabels) \
                  for label in uniquelabels ]
            self.__weights = dict(zip(uniquelabels, self.__weights))


        # number of occurrences for each unique class in kNNs
        votes = self.__votes_init.copy()
        for nn in knn_ids:
            votes[self.__labels[nn]] += 1

        # weight votes
        votes = [ self.__weights[ul] * votes[ul] for ul in uniquelabels]

        # find the class with most votes
        # return votes as well to store them in the state
        return uniquelabels[N.asarray(votes).argmax()], \
               votes


    def untrain(self):
        """Reset trained state"""
        self.__data = None
        super(kNN, self).untrain()