9 """Base class for all classifiers.
10
11 At the moment, regressions are treated just as a special case of
12 classifier (or vise verse), so the same base class `Classifier` is
13 utilized for both kinds.
14 """

__docformat__ = 'restructuredtext'

import numpy as N

from mvpa.support.copy import deepcopy

import time

from mvpa.misc.support import idhash
from mvpa.misc.state import StateVariable, ClassWithCollections
from mvpa.misc.param import Parameter

from mvpa.clfs.transerror import ConfusionMatrix, RegressionStatistics

from mvpa.base import warning

if __debug__:
    from mvpa.base import debug


class DegenerateInputError(Exception):
    """Exception to be thrown whenever classifier fails to learn for
    some reason"""
    pass


class Classifier(ClassWithCollections):
    """Abstract classifier class to be inherited by all classifiers
    """

    _DEV__doc__ = """
    Required behavior:

    It must be possible to instantiate every classifier without having
    to specify the training pattern.

    Repeated calls to the train() method with different training data have to
    result in a valid classifier, trained for the particular dataset.

    It must be possible to specify all classifier parameters as keyword
    arguments to the constructor.

    Recommended behavior:

    Derived classifiers should provide access to *values* -- i.e. that
    information that is finally used to determine the predicted class label.

    Michael: Maybe it works well if each classifier provides a 'values'
    state member. This variable is a list as long as, and in the same
    order as, Dataset.uniquelabels (training data). Each item in the list
    corresponds to the likelihood of a sample to belong to the
    respective class. However the semantics might differ between
    classifiers, e.g. kNN would probably store distances to class-
    neighbors, whereas PLR would store the raw function value of the
    logistic function. So in the case of kNN low is predictive and for
    PLR high is predictive. Don't know if there is the need to unify
    that.

    As the storage and/or computation of this information might be
    demanding, its collection should be switchable and off by default.

    Nomenclature
     * predictions : corresponds to the quantized labels if classifier spits
       out labels by .predict()
     * values : might be different from predictions if a classifier's predict()
       makes a decision based on some internal value such as
       probability or a distance.
    """

    trained_labels = StateVariable(enabled=True,
        doc="Set of unique labels it has been trained on")

    trained_nsamples = StateVariable(enabled=True,
        doc="Number of samples it has been trained on")

    trained_dataset = StateVariable(enabled=False,
        doc="The dataset it has been trained on")

    training_confusion = StateVariable(enabled=False,
        doc="Confusion matrix of learning performance")

    predictions = StateVariable(enabled=True,
        doc="Most recent set of predictions")

    values = StateVariable(enabled=True,
        doc="Internal classifier values the most recent "
            "predictions are based on")

    training_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took to train the classifier")

    predicting_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took the classifier to predict")

    feature_ids = StateVariable(enabled=False,
        doc="Feature IDs which were used for the actual training")

    _clf_internals = []
    """Describes some specifics about the classifier -- e.g. whether it
    is doing regression..."""

    regression = Parameter(False, allowedtype='bool',
        doc="""Whether to use the classifier as a regression. By default
        any Classifier-derived class serves as a classifier, so a
        regression-capable classifier performs binary classification
        instead.""", index=1001)

    retrainable = Parameter(False, allowedtype='bool',
        doc="""Whether to enable retraining for a 'retrainable' classifier.""",
        index=1002)


    def __init__(self, **kwargs):
        """Cheap initialization.
        """
        ClassWithCollections.__init__(self, **kwargs)

        self.__trainednfeatures = None
        """Stores number of features for which classifier was trained.
        If None -- it wasn't trained at all"""

        self._setRetrainable(self.params.retrainable, force=True)

        if self.params.regression:
            for statevar in ["trained_labels"]:
                if self.states.isEnabled(statevar):
                    if __debug__:
                        debug("CLF",
                              "Disabling state %s since doing regression, "
                              % statevar + "not classification")
                    self.states.disable(statevar)
            self._summaryClass = RegressionStatistics
        else:
            self._summaryClass = ConfusionMatrix
            clf_internals = self._clf_internals
            if 'regression' in clf_internals \
                   and not ('binary' in clf_internals):
                # regressions are used as binary classifiers if not
                # asked to perform regression explicitly; copy before
                # appending to avoid altering the class-level list
                self._clf_internals = clf_internals + ['binary']


    def __str__(self):
        if __debug__ and 'CLF_' in debug.active:
            return "%s / %s" % (repr(self), super(Classifier, self).__str__())
        else:
            return repr(self)


    def _pretrain(self, dataset):
        """Functionality prior to training
        """
        params = self.params
        if not params.retrainable:
            self.untrain()
        else:
            # just reset the states -- retrainable classifiers keep
            # their trained state around
            self.states.reset()
            if not self.__changedData_isset:
                self.__resetChangedData()
                _changedData = self._changedData
                __idhashes = self.__idhashes
                __invalidatedChangedData = self.__invalidatedChangedData

                if __debug__:
                    debug('CLF_', "IDHashes are %s" % (__idhashes))

                # figure out what has changed since the previous training
                for key, data_ in (('traindata', dataset.samples),
                                   ('labels', dataset.labels)):
                    _changedData[key] = self.__wasDataChanged(key, data_)
                    # if the idhash was invalidated by retrain(), mark
                    # the entry as changed regardless of the hash
                    if __invalidatedChangedData.get(key, False):
                        if __debug__ and not _changedData[key]:
                            debug('CLF_', 'Found that idhash for %s was '
                                  'invalidated by retraining' % key)
                        _changedData[key] = True

                # check the parameter collections as well
                for col in self._paramscols:
                    changedParams = self._collections[col].whichSet()
                    if len(changedParams):
                        _changedData[col] = changedParams

                self.__invalidatedChangedData = {}

                if __debug__:
                    debug('CLF_', "Obtained _changedData is %s"
                          % (self._changedData))

        if not params.regression and 'regression' in self._clf_internals \
           and not self.states.isEnabled('trained_labels'):
            # a regression-based classifier needs the labels it was
            # trained on to map predicted values back onto labels
            if __debug__:
                debug("CLF", "Enabling trained_labels state since it is needed")
            self.states.enable('trained_labels')


    def _posttrain(self, dataset):
        """Functionality post training

        For instance -- computing confusion matrix
        :Parameters:
          dataset : Dataset
            Data which was used for training
        """
        if self.states.isEnabled('trained_labels'):
            self.trained_labels = dataset.uniquelabels

        self.trained_dataset = dataset
        self.trained_nsamples = dataset.nsamples

        # needs to be set prior to predict() being called below
        self.__trainednfeatures = dataset.nfeatures

        if __debug__ and 'CHECK_TRAINED' in debug.active:
            self.__trainedidhash = dataset.idhash

        if self.states.isEnabled('training_confusion') and \
               not self.states.isSet('training_confusion'):
            # we should not store predictions for the training data,
            # so disable them temporarily
            self.states._changeTemporarily(
                disable_states=["predictions"])
            if self.params.retrainable:
                # mark data as unchanged, so the predict() below does
                # not attempt to re-detect the changes
                self.__changedData_isset = False
            predictions = self.predict(dataset.samples)
            self.states._resetEnabledTemporarily()
            self.training_confusion = self._summaryClass(
                targets=dataset.labels,
                predictions=predictions)

            try:
                self.training_confusion.labels_map = dataset.labels_map
            except:
                pass

        if self.states.isEnabled('feature_ids'):
            self.feature_ids = self._getFeatureIds()


    def _getFeatureIds(self):
        """Virtual method to return feature_ids used while training

        Is not intended to be called anywhere but from _posttrain,
        thus classifier is assumed to be trained at this point
        """
        # by default, all features are used
        return range(self.__trainednfeatures)


    def summary(self):
        """Provide a summary of the classifier"""
        s = "Classifier %s" % self
        states = self.states
        states_enabled = states.enabled

        if self.trained:
            s += "\n trained"
            if states.isSet('training_time'):
                s += ' in %.3g sec' % states.training_time
            s += ' on data with'
            if states.isSet('trained_labels'):
                s += ' labels:%s' % list(states.trained_labels)

            nsamples, nchunks = None, None
            if states.isSet('trained_nsamples'):
                nsamples = states.trained_nsamples
            if states.isSet('trained_dataset'):
                td = states.trained_dataset
                nsamples, nchunks = td.nsamples, len(td.uniquechunks)
            if nsamples is not None:
                s += ' #samples:%d' % nsamples
            if nchunks is not None:
                s += ' #chunks:%d' % nchunks

            s += " #features:%d" % self.__trainednfeatures
            if states.isSet('feature_ids'):
                s += ", used #features:%d" % len(states.feature_ids)
            if states.isSet('training_confusion'):
                s += ", training error:%.3g" % states.training_confusion.error
        else:
            s += "\n not yet trained"

        if len(states_enabled):
            s += "\n enabled states:%s" % ', '.join([str(states[x])
                                                     for x in states_enabled])
        return s


    def clone(self):
        """Create full copy of the classifier.

        It might require classifier to be untrained first due to
        present SWIG bindings.

        TODO: think about proper re-implementation, without relying on
        deepcopy
        """
        if __debug__:
            debug("CLF", "Cloning %s#%s" % (self, id(self)))
        try:
            return deepcopy(self)
        except:
            self.untrain()
            return deepcopy(self)


    def _train(self, dataset):
        """Function to be actually overridden in derived classes
        """
        raise NotImplementedError


    def train(self, dataset):
        """Train classifier on a dataset

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so
        """
        if dataset.nfeatures == 0 or dataset.nsamples == 0:
            raise DegenerateInputError, \
                  "Cannot train classifier on degenerate data %s" % dataset
        if __debug__:
            debug("CLF", "Training classifier %(clf)s on dataset %(dataset)s",
                  msgargs={'clf':self, 'dataset':dataset})

        self._pretrain(dataset)

        # remember the time when training started
        t0 = time.time()

        if dataset.nfeatures > 0:
            result = self._train(dataset)
        else:
            warning("Trying to train on dataset with no features present")
            if __debug__:
                debug("CLF",
                      "No features present for training, no actual training "
                      "is called")
            result = None

        self.training_time = time.time() - t0
        self._posttrain(dataset)
        return result
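
    # Typical call sequence (sketch; `dataset` stands for any Dataset
    # with .samples and .labels):
    #   clf.train(dataset)
    #   predictions = clf.predict(dataset.samples)
    #   print clf.training_time, clf.predicting_time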


    def _prepredict(self, data):
        """Functionality prior to prediction
        """
        if not ('notrain2predict' in self._clf_internals):
            # check if classifier was trained at all
            if not self.trained:
                raise ValueError, \
                      "Classifier %s wasn't yet trained, therefore can't " \
                      "predict" % self
            nfeatures = data.shape[1]
            # check if the number of features matches the data it was
            # trained on
            if nfeatures != self.__trainednfeatures:
                raise ValueError, \
                      "Classifier %s was trained on data with %d features, " % \
                      (self, self.__trainednfeatures) + \
                      "thus can't predict for %d features" % nfeatures

        if self.params.retrainable:
            if not self.__changedData_isset:
                self.__resetChangedData()
                _changedData = self._changedData
                _changedData['testdata'] = \
                    self.__wasDataChanged('testdata', data)
                if __debug__:
                    debug('CLF_', "prepredict: Obtained _changedData is %s"
                          % (_changedData))


    def _postpredict(self, data, result):
        """Functionality after prediction is computed
        """
        self.predictions = result
        if self.params.retrainable:
            self.__changedData_isset = False


    def _predict(self, data):
        """Actual prediction
        """
        raise NotImplementedError


    def predict(self, data):
        """Predict classifier on data

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so. Also, subclasses trying to call the super class's
        predict from within _predict should call _predict instead of
        predict(), since otherwise it would loop forever
        """
        data = N.asarray(data)
        if __debug__:
            debug("CLF", "Predicting classifier %(clf)s on data %(data)s",
                  msgargs={'clf':self, 'data':data.shape})

        # remember the time when prediction started
        t0 = time.time()

        states = self.states
        # reset the relevant states first, so leftovers from a previous
        # call cannot be mistaken for fresh results
        states.reset(['values', 'predictions'])

        self._prepredict(data)

        if self.__trainednfeatures > 0 \
               or 'notrain2predict' in self._clf_internals:
            result = self._predict(data)
        else:
            warning("Trying to predict using classifier trained on no features")
            if __debug__:
                debug("CLF",
                      "No features were present for training, prediction is "
                      "bogus")
            result = [None]*data.shape[0]

        states.predicting_time = time.time() - t0

        if 'regression' in self._clf_internals and not self.params.regression:
            # The classifier performs regression internally but is used
            # as a classifier, so predicted values have to be mapped
            # back onto the labels it was trained with.
            # Use N.array so the result can be modified in-place below
            result_ = N.array(result)
            if states.isEnabled('values'):
                # keep the raw values before they get quantized
                if not states.isSet('values'):
                    states.values = result_.copy()
                else:
                    # _predict stored values already -- keep an
                    # independent copy so the quantization below cannot
                    # alter them
                    states.values = N.array(states.values, copy=True)

            # quantize each value onto the closest trained label
            trained_labels = self.trained_labels
            for i, value in enumerate(result):
                dists = N.abs(value - trained_labels)
                result[i] = trained_labels[N.argmin(dists)]

            if __debug__:
                debug("CLF_", "Converted regression result %(result_)s "
                      "into labels %(result)s for %(self_)s",
                      msgargs={'result_':result_, 'result':result,
                               'self_': self})

        self._postpredict(data, result)
        return result
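
    # Worked example of the regression-to-label mapping above
    # (hypothetical values): with trained_labels == [-1, 1] and raw
    # predictions [-0.3, 0.8], the absolute distances are [0.7, 1.3]
    # and [1.8, 0.2] respectively, so the quantized labels are [-1, 1].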


    def isTrained(self, dataset=None):
        """Whether classifier was already trained.

        MUST BE USED WITH CARE IF EVER"""
        if dataset is None:
            # simply return whether it was trained on anything
            return self.__trainednfeatures is not None
        else:
            res = (self.__trainednfeatures == dataset.nfeatures)
            if __debug__ and 'CHECK_TRAINED' in debug.active:
                res2 = (self.__trainedidhash == dataset.idhash)
                if res2 != res:
                    raise RuntimeError, \
                          "isTrained is weak and shouldn't be relied upon. " \
                          "Got result %s although comparison of idhash says %s" \
                          % (res, res2)
            return res


    def _regressionIsBogus(self):
        """Some classifiers like BinaryClassifier can't be used for
        regression"""
        if self.params.regression:
            raise ValueError, "Regression mode is meaningless for %s" % \
                  self.__class__.__name__ + " thus don't enable it"


    @property
    def trained(self):
        """Whether classifier was already trained"""
        return self.isTrained()


    def untrain(self):
        """Reset trained state"""
        self.__trainednfeatures = None
        # the rest of the trained state lives in the state variables,
        # which are reset by the parent class
        super(Classifier, self).reset()


    def getSensitivityAnalyzer(self, **kwargs):
        """Factory method to return an appropriate sensitivity analyzer for
        the respective classifier."""
        raise NotImplementedError
566 """Assign value of retrainable parameter
567
568 If retrainable flag is to be changed, classifier has to be
569 untrained. Also internal attributes such as _changedData,
570 __changedData_isset, and __idhashes should be initialized if
571 it becomes retrainable
572 """
573 pretrainable = self.params['retrainable']
574 if (force or value != pretrainable.value) \
575 and 'retrainable' in self._clf_internals:
576 if __debug__:
577 debug("CLF_", "Setting retrainable to %s" % value)
578 if 'meta' in self._clf_internals:
579 warning("Retrainability is not yet crafted/tested for "
580 "meta classifiers. Unpredictable behavior might occur")
581
582 if self.trained:
583 self.untrain()
584 states = self.states
585 if not value and states.isKnown('retrained'):
586 states.remove('retrained')
587 states.remove('repredicted')
588 if value:
589 if not 'retrainable' in self._clf_internals:
590 warning("Setting of flag retrainable for %s has no effect"
591 " since classifier has no such capability. It would"
592 " just lead to resources consumption and slowdown"
593 % self)
594 states.add(StateVariable(enabled=True,
595 name='retrained',
596 doc="Either retrainable classifier was retrained"))
597 states.add(StateVariable(enabled=True,
598 name='repredicted',
599 doc="Either retrainable classifier was repredicted"))
600
601 pretrainable.value = value
602
603
604 if value:
605 self.__idhashes = {'traindata': None, 'labels': None,
606 'testdata': None}
607 if __debug__ and 'CHECK_RETRAIN' in debug.active:
608
609
610
611
612 self.__trained = self.__idhashes.copy()
613 self.__resetChangedData()
614 self.__invalidatedChangedData = {}
615 elif 'retrainable' in self._clf_internals:
616
617 self.__changedData_isset = False
618 self._changedData = None
619 self.__idhashes = None
620 if __debug__ and 'CHECK_RETRAIN' in debug.active:
621 self.__trained = None
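
    # Retrainability is requested at construction time (sketch, for a
    # classifier listing 'retrainable' in its _clf_internals):
    #   clf = SomeClassifier(retrainable=True)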


    def __resetChangedData(self):
        """For retrainable classifier we keep track of what was changed.
        This function resets that dictionary
        """
        if __debug__:
            debug('CLF_',
                  'Retrainable: resetting flags on whether data was changed')
        keys = self.__idhashes.keys() + self._paramscols
        # mark everything as unchanged
        self._changedData = dict(zip(keys, [False]*len(keys)))
        self.__changedData_isset = False


    def __wasDataChanged(self, key, entry, update=True):
        """Check if given entry was changed from what was known prior.

        If so -- store only the ones needed for the retrainable machinery
        """
        idhash_ = idhash(entry)
        __idhashes = self.__idhashes

        changed = __idhashes[key] != idhash_
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            # verify the cheap idhash comparison against a full
            # element-wise comparison of the stored entries
            __trained = self.__trained
            changed2 = entry != __trained[key]
            if isinstance(changed2, N.ndarray):
                changed2 = changed2.any()
            if changed != changed2 and not changed:
                raise RuntimeError, \
                      'idhash found to be weak for %s. Though hashid %s!=%s %s, '\
                      'values %s!=%s %s' % \
                      (key, idhash_, __idhashes[key], changed,
                       entry, __trained[key], changed2)
            if update:
                __trained[key] = entry

        if __debug__ and changed:
            debug('CLF_', "Changed %s from %s to %s.%s"
                  % (key, __idhashes[key], idhash_,
                     ('', 'updated')[int(update)]))
        if update:
            __idhashes[key] = idhash_

        return changed


    def retrain(self, dataset, **kwargs):
        """Helper to avoid checking if data has actually changed

        Useful if just some aspects of classifier were changed since
        its previous training. For instance if dataset wasn't changed
        but only classifier parameters, then kernel matrix does not
        have to be computed.

        Words of caution: classifier must have been trained previously,
        and results should always first be compared to the results of a
        not 'retrainable' classifier (without calling retrain). Some
        additional checks are enabled if debug id 'CHECK_RETRAIN' is
        enabled, to guard against obvious mistakes.

        :Parameters:
          dataset : Dataset
            Data to train on
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        # Note that this also demolishes anything for repredicting,
        # which should be fine in the general case
        if __debug__:
            if not self.params.retrainable:
                raise RuntimeError, \
                      "Do not use re(train,predict) on non-retrainable %s" % \
                      self

            if kwargs.has_key('params') or kwargs.has_key('kernel_params'):
                raise ValueError, \
                      "Retraining for changed params not working yet"

        self.__resetChangedData()

        # local bindings
        chd = self._changedData
        ichd = self.__invalidatedChangedData

        chd.update(kwargs)

        # remember the entries explicitly announced as changed, so a
        # later plain train() treats them as changed even if the
        # idhashes happen to match
        for key, value in kwargs.iteritems():
            if value:
                ichd[key] = True
        self.__changedData_isset = True

        # check if we are not fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                # so it wasn't told to be invalid
                if not chd[key] and not ichd.get(key, False):
                    if self.__wasDataChanged(key, data_, update=False):
                        raise RuntimeError, \
                              "Data %s found changed although wasn't " \
                              "labeled as such" % key

        # Below check should be superseded by the check above, thus
        # never occur. Remove later on ???
        if __debug__ and 'CHECK_RETRAIN' in debug.active and self.trained \
               and not self._changedData['traindata'] \
               and self.__trained['traindata'].shape != dataset.samples.shape:
            raise ValueError, "In retrain got dataset with %s size, " \
                  "whereas previously was trained on %s size" \
                  % (dataset.samples.shape, self.__trained['traindata'].shape)
        self.train(dataset)
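
    # Usage sketch: after training once on `dataset` and then changing
    # only its labels in-place, a cheap retraining would be
    #   clf.retrain(dataset, labels=True)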


    def repredict(self, data, **kwargs):
        """Helper to avoid checking if data has actually changed

        Useful if classifier was (re)trained but with the same data
        (so just parameters were changed), so that it could be
        repredicted easily (on the same data as before) without
        recomputing for instance the train/test kernel matrix. Should be
        used with caution and always compared to the results on a not
        'retrainable' classifier. Some additional checks are enabled
        if debug id 'CHECK_RETRAIN' is enabled, to guard against
        obvious mistakes.

        :Parameters:
          data
            data which is conventionally given to predict
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        if len(kwargs) > 0:
            raise RuntimeError, \
                  "repredict for now should be used without params since " \
                  "it makes little sense to repredict if anything got changed"
        if __debug__ and not self.params.retrainable:
            raise RuntimeError, \
                  "Do not use retrain/repredict on non-retrainable classifiers"

        self.__resetChangedData()
        chd = self._changedData
        chd.update(**kwargs)
        self.__changedData_isset = True

        # check if we are not fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('testdata', data),):
                # so it wasn't told to be invalid
                if self.__wasDataChanged(key, data_, update=False):
                    raise RuntimeError, \
                          "Data %s found changed although wasn't " \
                          "labeled as such" % key

        # Below check should be superseded by the check above, thus
        # never occur. Remove later on ???
        if __debug__ and 'CHECK_RETRAIN' in debug.active \
               and not self._changedData['testdata'] \
               and self.__trained['testdata'].shape != data.shape:
            raise ValueError, "In repredict got dataset with %s size, " \
                  "whereas previously was trained on %s size" \
                  % (data.shape, self.__trained['testdata'].shape)

        return self.predict(data)
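
    # Usage sketch: after such a retrain(), predictions on the very same
    # test data can then be obtained cheaply via
    #   predictions = clf.repredict(data)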