1
2
3
4
5
6
7
8
"""k-Nearest-Neighbour (kNN) classifier."""

# docstrings in this module are written in reStructuredText (epydoc hint)
__docformat__ = 'restructuredtext'
12
13
14 import numpy as N
15
16 from mvpa.base import warning
17 from mvpa.misc.support import indentDoc
18 from mvpa.clfs.base import Classifier
19 from mvpa.base.dochelpers import enhancedDocString
20 from mvpa.clfs.distance import squared_euclidean_distance
21
22 if __debug__:
23 from mvpa.base import debug
24
25
class kNN(Classifier):
    """k-nearest-neighbour classifier.

    Training merely stores the dataset; all work happens at prediction
    time, where each test sample is labelled by a vote among its `k`
    nearest training samples.

    If enabled, it stores the votes per class in the 'values' state after
    calling predict().
    """

    _clf_internals = ['knn', 'non-linear', 'binary', 'multiclass',
                      'notrain2predict']

    def __init__(self, k=2, dfx=squared_euclidean_distance,
                 voting='weighted', **kwargs):
        """
        :Parameters:
          k: unsigned integer
            Number of nearest neighbours to be used for voting.
          dfx: functor
            Function to compute the distances between training and test
            samples. It is called as ``dfx(training_samples, test_samples)``
            and must return a (n_training x n_test) distance matrix.
            Default: squared euclidean distance
          voting: str
            Voting method used to derive predictions from the nearest
            neighbors. Possible values are 'majority' (simple majority of
            classes determines vote) and 'weighted' (the vote count of each
            class is multiplied by one minus the relative frequency of that
            class in the training data, which favours rare classes).
          **kwargs:
            Additional arguments are passed to the base class.
        """
        Classifier.__init__(self, **kwargs)

        self.__k = k
        self.__dfx = dfx
        self.__voting = voting
        self.__data = None
        # per-class weights for 'weighted' voting, computed lazily on the
        # first weighted vote after training
        self.__weights = None
        # zeroed per-class vote counter template, built in _train()
        self.__votes_init = None

    # must follow __init__ so that its docstring is visible in locals()
    __doc__ = enhancedDocString('kNN', locals(), Classifier)

    def __repr__(self):
        """Representation of the object
        """
        return "kNN(k=%d, enable_states=%s)" % \
               (self.__k, str(self.states.enabled))

    def _train(self, data):
        """Train the classifier.

        For kNN it is degenerate -- just stores the data.
        """
        self.__data = data
        if __debug__:
            # integer samples can overflow inside the distance function
            if str(data.samples.dtype).startswith(('uint', 'int')):
                warning("kNN: input data is in integers. "
                        "Overflow on arithmetic operations might result in"
                        " errors. Please convert dataset's samples into"
                        " floating datatype if any error is reported.")

        # invalidate weights cached from a previous training dataset
        self.__weights = None

        # zeroed counter for every known class; copied for each test sample
        self.__votes_init = dict.fromkeys(data.uniquelabels, 0)

    def _predict(self, data):
        """Predict the class labels for the provided data.

        Returns a list of class labels (one for each data sample).
        """
        # resolve the voting function first, so that a bad setting fails
        # before the (potentially expensive) distance computation
        if self.__voting == 'majority':
            vfx = self.getMajorityVote
        elif self.__voting == 'weighted':
            vfx = self.getWeightedVote
        else:
            raise ValueError("kNN told to perform unknown voting '%s'."
                             % self.__voting)

        data = N.asarray(data)

        if __debug__:
            if not data.ndim == 2:
                raise ValueError("Data array must be two-dimensional.")

            if not data.shape[1] == self.__data.nfeatures:
                raise ValueError("Length of data samples (features) does "
                                 "not match the classifier.")

        # dfx(train, test) yields (n_train x n_test); transpose so each row
        # holds the distances of one test sample to all training samples
        dists = self.__dfx(self.__data.samples, data).T

        # indices of the k closest training samples for every test sample
        knns = dists.argsort(axis=1)[:, :self.__k]

        # each result is a (predicted label, votes) tuple
        results = [vfx(knn) for knn in knns]
        predicted = [r[0] for r in results]

        # store states
        self.predictions = predicted
        self.values = [r[1] for r in results]

        return predicted

    def _countVotes(self, knn_ids):
        """Return a dict mapping every training label to its number of
        occurrences among the neighbours `knn_ids` (indices into the
        training samples)."""
        labels = self.__data.labels
        votes = self.__votes_init.copy()
        for nn in knn_ids:
            votes[labels[nn]] += 1
        return votes

    def getMajorityVote(self, knn_ids):
        """Simple voting by choosing the majority of class neighbours.

        :Parameters:
          knn_ids: sequence of int
            Indices (into the training samples) of the nearest neighbours
            of a single test sample.

        :Returns:
          Tuple ``(label, votes)`` with the winning label and a dict mapping
          each unique training label to its number of neighbours.
        """
        uniquelabels = self.__data.uniquelabels
        votes = self._countVotes(knn_ids)

        # argmax over counts ordered like `uniquelabels`; ties go to the
        # label listed first (same rule as weighted voting)
        counts = [votes[ul] for ul in uniquelabels]
        return uniquelabels[N.argmax(counts)], votes

    def getWeightedVote(self, knn_ids):
        """Vote with classes weighted by the number of samples per class.

        Each class count is multiplied by ``1 - relative frequency`` of that
        class in the training data.

        :Parameters:
          knn_ids: sequence of int
            Indices (into the training samples) of the nearest neighbours
            of a single test sample.

        :Returns:
          Tuple ``(label, votes)`` with the winning label and the list of
          weighted votes, ordered like the unique training labels.
        """
        uniquelabels = self.__data.uniquelabels

        if self.__weights is None:
            labels = N.asarray(self.__data.labels)
            # N.mean of a boolean mask is a float fraction; the previous
            # `sum() / len` was integer division under Python 2 and made
            # every weight 1.0
            self.__weights = dict(
                (ul, 1.0 - N.mean(labels == ul)) for ul in uniquelabels)

        counts = self._countVotes(knn_ids)
        votes = [self.__weights[ul] * counts[ul] for ul in uniquelabels]

        return uniquelabels[N.argmax(votes)], votes

    def untrain(self):
        """Reset trained state"""
        self.__data = None
        # cached weights belong to the discarded training data
        self.__weights = None
        super(kNN, self).untrain()
213