"""k-Nearest-Neighbour classifier."""

__docformat__ = 'restructuredtext'


import numpy as N

from mvpa.misc import warning
from mvpa.misc.support import indentDoc
from mvpa.clfs.base import Classifier


class kNN(Classifier):
    """k-nearest-neighbour classifier.

    If enabled it stores the votes per class in the 'values' state after
    calling predict().
    """

    __warned = False

    _clf_internals = [ 'knn', 'non-linear', 'multiclass' ]

    def __init__(self, k=2, **kwargs):
        """
        :Parameters:
          k
            number of nearest neighbours to be used for voting
        """
        # init the base class first
        Classifier.__init__(self, train2predict=False, **kwargs)

        self.__k = k

        # weighted voting of the neighbours is used by default
        self.__votingfx = self.getWeightedVote
        self.__data = None


    def __repr__(self):
        """Representation of the object
        """
        return "kNN(k=%d, enable_states=%s)" % \
            (self.__k, str(self.states.enabled))


    def __str__(self):
        return "%s\n data: %s" % \
            (Classifier.__str__(self), indentDoc(self.__data))


    def _train(self, data):
        """Train the classifier.

        For kNN it is degenerate -- just stores the data.
        """
        self.__data = data
        if __debug__:
            # warn (once per session) about integer input, since the
            # distance computation in _predict() may overflow on integers
            if not kNN.__warned and \
               (str(data.samples.dtype).startswith('uint') or
                str(data.samples.dtype).startswith('int')):
                kNN.__warned = True
                warning("kNN: input data is in integers. "
                        "Overflow on arithmetic operations might result in "
                        "errors. Please convert dataset's samples into "
                        "floating datatype if any error is reported.")
        self.__weights = None

        # create a dictionary with a zeroed vote counter for each class
        uniquelabels = data.uniquelabels
        self.__votes_init = dict(zip(uniquelabels,
                                     [0] * len(uniquelabels)))


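    # Note on the integer-dtype warning in _train(): it asks for samples in
    # a floating point type.  A minimal sketch of the conversion, assuming
    # the dataset's ``samples`` attribute is a plain NumPy array that may be
    # reassigned (otherwise convert the array before constructing the
    # dataset):
    #
    #   dataset.samples = dataset.samples.astype('float64')
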
    def _predict(self, data):
        """Predict the class labels for the provided data.

        Returns a list of class labels (one for each data sample).
        """
        # make sure we are dealing with a 2-d array of samples
        data = N.asarray(data)

        if not data.ndim == 2:
            raise ValueError("Data array must be two-dimensional.")

        if not data.shape[1] == self.__data.nfeatures:
            raise ValueError("Length of data samples (features) does "
                             "not match the classifier.")

        # predicted labels and the per-class votes for each sample
        predicted = []
        votes = []

        for p in data:
            # Euclidean distances from the current sample to all
            # training samples
            dists = N.sqrt(
                        N.sum(
                            (self.__data.samples - p)**2, axis=1
                            )
                        )
            # indices of the k closest training samples
            knn = dists.argsort()[:self.__k]

            # let the configured voting function decide on the label
            prediction, vote = self.__votingfx(knn)
            predicted.append(prediction)
            votes.append(vote)

        # store predictions and the votes per class in the state variables
        self.predictions = predicted
        self.values = votes

        return predicted


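    # Worked example of the neighbour selection in _predict() above: with
    # distances dists = [0.9, 0.1, 0.5, 0.3] and k = 2, dists.argsort()
    # gives [1, 3, 2, 0], so dists.argsort()[:2] picks the indices [1, 3]
    # of the two closest training samples, whose labels are then handed to
    # the voting function.
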
    def getMajorityVote(self, knn_ids):
        """Simple voting by choosing the majority of class neighbours.
        """
        # local binding
        uniquelabels = self.__data.uniquelabels

        # labels of the nearest neighbours
        knn_labels = N.array([ self.__data.labels[nn] for nn in knn_ids ])

        # count the occurrences of each unique class among the neighbours
        votes = self.__votes_init.copy()
        for label in knn_labels:
            votes[label] += 1

        # rearrange the counts into the order of uniquelabels so that
        # argmax() identifies the winning class
        votes = [ votes[ul] for ul in uniquelabels ]

        # return the predicted label and the votes to be stored in the state
        return uniquelabels[N.asarray(votes).argmax()], \
               votes


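    # Worked example: with k = 5 and neighbour labels [0, 1, 1, 0, 1],
    # getMajorityVote() counts {0: 2, 1: 3} and therefore predicts label 1.
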
    def getWeightedVote(self, knn_ids):
        """Vote with classes weighted by the number of samples per class.
        """
        # local binding
        uniquelabels = self.__data.uniquelabels

        # lazily compute the class weights on the first call after training;
        # they depend only on the training labels
        if self.__weights is None:
            # local bindings for faster access
            self.__labels = self.__data.labels
            Nlabels = len(self.__labels)
            Nuniquelabels = len(uniquelabels)

            # Each class is weighted by one minus its relative frequency in
            # the training data, so that neighbours from rare classes carry
            # stronger votes than neighbours from abundant classes.  Force
            # floating point division to avoid truncation to zero.
            self.__weights = \
                [ 1.0 - ((self.__labels == label).sum() / float(Nlabels)) \
                    for label in uniquelabels ]
            self.__weights = dict(zip(uniquelabels, self.__weights))

        # count the occurrences of each unique class among the neighbours
        votes = self.__votes_init.copy()
        for nn in knn_ids:
            votes[self.__labels[nn]] += 1

        # scale the raw counts by the class weights (ordered as uniquelabels)
        votes = [ self.__weights[ul] * votes[ul] for ul in uniquelabels ]

        # return the predicted label and the votes to be stored in the state
        return uniquelabels[N.asarray(votes).argmax()], \
               votes


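    # Worked example of the weighting above: with 10 training samples of
    # class A and 30 of class B, the weights are A: 1.0 - 10/40.0 = 0.75 and
    # B: 1.0 - 30/40.0 = 0.25.  If the k = 4 nearest neighbours split
    # 1 x A / 3 x B, the weighted votes are A: 0.75 * 1 = 0.75 and
    # B: 0.25 * 3 = 0.75, so the rare class is no longer simply outvoted by
    # the abundant one.
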
    def untrain(self):
        """Reset trained state"""
        self.__data = None
        super(kNN, self).untrain()

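
# A minimal usage sketch.  It assumes that ``Dataset`` is importable from
# ``mvpa.datasets`` and that its constructor accepts ``samples`` and
# ``labels`` keyword arguments; adjust the dataset construction to match
# your PyMVPA installation.
if __name__ == '__main__':
    from mvpa.datasets import Dataset

    train_samples = N.random.rand(20, 5)     # 20 samples, 5 features
    train_labels = [0] * 10 + [1] * 10       # two balanced classes

    clf = kNN(k=3)
    clf.train(Dataset(samples=train_samples, labels=train_labels))

    # one predicted label per test sample; if the 'values' state is enabled,
    # the per-class votes are available afterwards as clf.values
    print(clf.predict(N.random.rand(4, 5)))
    print(clf.values)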