Package mvpa :: Package clfs :: Module knn
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.clfs.knn

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """k-Nearest-Neighbour classifier.""" 
 10   
 11  __docformat__ = 'restructuredtext' 
 12   
 13   
 14  import numpy as N 
 15   
 16  from mvpa.base import warning 
 17  from mvpa.misc.support import indentDoc 
 18  from mvpa.clfs.base import Classifier 
 19  from mvpa.base.dochelpers import enhancedDocString 
 20  from mvpa.clfs.distance import squared_euclidean_distance 
 21   
 22  if __debug__: 
 23      from mvpa.base import debug 
 24   
 25   
class kNN(Classifier):
    """k-nearest-neighbour classifier.

    If enabled, it stores the votes per class in the 'values' state after
    calling predict().
    """

    _clf_internals = [ 'knn', 'non-linear', 'binary', 'multiclass',
                       'notrain2predict' ]

    def __init__(self, k=2, dfx=squared_euclidean_distance,
                 voting='weighted', **kwargs):
        """
        :Parameters:
          k: unsigned integer
            Number of nearest neighbours to be used for voting.
          dfx: functor
            Function to compute the distances between training and test
            samples.  Default: squared euclidean distance
          voting: str
            Voting method used to derive predictions from the nearest
            neighbors.  Possible values are 'majority' (simple majority of
            classes determines vote) and 'weighted' (votes are weighted
            according to the relative frequencies of each class in the
            training data).
          **kwargs:
            Additional arguments are passed to the base class.
        """
        # let the base class set up its state machinery first
        Classifier.__init__(self, **kwargs)

        self.__k = k            # number of neighbours consulted per sample
        self.__dfx = dfx        # distance functor (training x test -> dists)
        self.__voting = voting  # 'majority' or 'weighted'
        self.__data = None      # training dataset, assigned by _train()


    __doc__ = enhancedDocString('kNN', locals(), Classifier)
65 - def __repr__(self):
66 """Representation of the object 67 """ 68 return "kNN(k=%d, enable_states=%s)" % \ 69 (self.__k, str(self.states.enabled))
70 71
72 - def __str__(self):
73 return "%s\n data: %s" % \ 74 (Classifier.__str__(self), indentDoc(self.__data))
75 76
77 - def _train(self, data):
78 """Train the classifier. 79 80 For kNN it is degenerate -- just stores the data. 81 """ 82 self.__data = data 83 if __debug__: 84 if str(data.samples.dtype).startswith('uint') \ 85 or str(data.samples.dtype).startswith('int'): 86 warning("kNN: input data is in integers. " + \ 87 "Overflow on arithmetic operations might result in"+\ 88 " errors. Please convert dataset's samples into" +\ 89 " floating datatype if any error is reported.") 90 self.__weights = None 91 92 # create dictionary with an item for each condition 93 uniquelabels = data.uniquelabels 94 self.__votes_init = dict(zip(uniquelabels, 95 [0] * len(uniquelabels)))
96 97
98 - def _predict(self, data):
99 """Predict the class labels for the provided data. 100 101 Returns a list of class labels (one for each data sample). 102 """ 103 # make sure we're talking about arrays 104 data = N.asarray(data) 105 106 # checks only in debug mode 107 if __debug__: 108 if not data.ndim == 2: 109 raise ValueError, "Data array must be two-dimensional." 110 111 if not data.shape[1] == self.__data.nfeatures: 112 raise ValueError, "Length of data samples (features) does " \ 113 "not match the classifier." 114 115 # compute the distance matrix between training and test data with 116 # distances stored row-wise, ie. distances between test sample [0] 117 # and all training samples will end up in row 0 118 dists = self.__dfx(self.__data.samples, data).T 119 120 # determine the k nearest neighbors per test sample 121 knns = dists.argsort(axis=1)[:, :self.__k] 122 123 # predicted class labels will go here 124 predicted = [] 125 votes = [] 126 127 if self.__voting == 'majority': 128 vfx = self.getMajorityVote 129 elif self.__voting == 'weighted': 130 vfx = self.getWeightedVote 131 else: 132 raise ValueError, "kNN told to perform unknown voting '%s'." % self.__voting 133 134 # perform voting 135 results = [vfx(knn) for knn in knns] 136 137 # extract predictions 138 predicted = [r[0] for r in results] 139 140 # store the predictions in the state. Relies on State._setitem to do 141 # nothing if the relevant state member is not enabled 142 self.predictions = predicted 143 self.values = [r[1] for r in results] 144 145 return predicted
146 147
148 - def getMajorityVote(self, knn_ids):
149 """Simple voting by choosing the majority of class neighbours. 150 """ 151 152 uniquelabels = self.__data.uniquelabels 153 154 # translate knn ids into class labels 155 knn_labels = N.array([ self.__data.labels[nn] for nn in knn_ids ]) 156 157 # number of occerences for each unique class in kNNs 158 votes = self.__votes_init.copy() 159 for nn in knn_ids: 160 votes[self.__labels[nn]] += 1 161 162 # find the class with most votes 163 # return votes as well to store them in the state 164 return uniquelabels[N.asarray(votes).argmax()], \ 165 votes
166 167
168 - def getWeightedVote(self, knn_ids):
169 """Vote with classes weighted by the number of samples per class. 170 """ 171 uniquelabels = self.__data.uniquelabels 172 173 # Lazy evaluation 174 if self.__weights is None: 175 # 176 # It seemed to Yarik that this has to be evaluated just once per 177 # training dataset. 178 # 179 self.__labels = self.__data.labels 180 Nlabels = len(self.__labels) 181 Nuniquelabels = len(uniquelabels) 182 183 # TODO: To get proper speed up for the next line only, 184 # histogram should be computed 185 # via sorting + counting "same" elements while reducing. 186 # Guaranteed complexity is NlogN whenever now it is N^2 187 # compute the relative proportion of samples belonging to each 188 # class (do it in one loop to improve speed and reduce readability 189 self.__weights = \ 190 [ 1.0 - ((self.__labels == label).sum() / Nlabels) \ 191 for label in uniquelabels ] 192 self.__weights = dict(zip(uniquelabels, self.__weights)) 193 194 195 # number of occerences for each unique class in kNNs 196 votes = self.__votes_init.copy() 197 for nn in knn_ids: 198 votes[self.__labels[nn]] += 1 199 200 # weight votes 201 votes = [ self.__weights[ul] * votes[ul] for ul in uniquelabels] 202 203 # find the class with most votes 204 # return votes as well to store them in the state 205 return uniquelabels[N.asarray(votes).argmax()], \ 206 votes
207 208
209 - def untrain(self):
210 """Reset trained state""" 211 self.__data = None 212 super(kNN, self).untrain()
213