
Source Code for Module mvpa.clfs.base

# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyMVPA package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
  9  """Base class for all classifiers. 
 10   
 11  At the moment, regressions are treated just as a special case of 
 12  classifier (or vise verse), so the same base class `Classifier` is 
 13  utilized for both kinds. 
 14  """ 

__docformat__ = 'restructuredtext'

import numpy as N

from mvpa.support.copy import deepcopy

import time

from mvpa.misc.support import idhash
from mvpa.misc.state import StateVariable, ClassWithCollections
from mvpa.misc.param import Parameter

from mvpa.clfs.transerror import ConfusionMatrix, RegressionStatistics

from mvpa.base import warning

if __debug__:
    from mvpa.base import debug


class DegenerateInputError(Exception):
    """Exception to be thrown by learners if input data is bogus,
    i.e. has no features or samples"""
    pass


class FailedToTrainError(Exception):
    """Exception to be thrown whenever a classifier fails to learn
    for some reason"""
    pass

class Classifier(ClassWithCollections):
    """Abstract classifier class to be inherited by all classifiers
    """

    # Kept separate from __doc__ so we don't pollute help(clf), especially if
    # we include help for the parent class
    _DEV__doc__ = """
    Required behavior:

    It has to be possible to instantiate every classifier without
    having to specify the training pattern.

    Repeated calls to the train() method with different training data have to
    result in a valid classifier, trained for the particular dataset.

    It must be possible to specify all classifier parameters as keyword
    arguments to the constructor.

    Recommended behavior:

    Derived classifiers should provide access to *values* -- i.e. the
    information that is finally used to determine the predicted class label.

    Michael: Maybe it works well if each classifier provides a 'values'
    state member. This variable is a list as long as, and in the same order
    as, Dataset.uniquelabels (training data). Each item in the list
    corresponds to the likelihood of a sample belonging to the
    respective class. However, the semantics might differ between
    classifiers, e.g. kNN would probably store distances to class-
    neighbors, whereas PLR would store the raw function value of the
    logistic function. So in the case of kNN low is predictive and for
    PLR high is predictive. Don't know if there is a need to unify
    that.

    As the storage and/or computation of this information might be
    demanding, its collection should be switchable and off by default.

    Nomenclature
     * predictions : corresponds to the quantized labels if the classifier
       spits out labels by .predict()
     * values : might be different from predictions if a classifier's
       predict() makes a decision based on some internal value such as
       probability or a distance.
    """
    # Dict that contains the parameters of a classifier.
    # This shall provide an interface to plug a generic parameter optimizer
    # on all classifiers (e.g. grid- or line-search optimizer)
    # A dictionary is used because Michael thinks that access by name is nicer.
    # Additionally Michael thinks ATM that additional information might be
    # necessary in some situations (e.g. reasonably predefined parameter range,
    # minimal iteration stepsize, ...), therefore the value to each key should
    # also be a dict or we should use mvpa.misc.param.Parameter ...

    trained_labels = StateVariable(enabled=True,
        doc="Set of unique labels it has been trained on")

    trained_nsamples = StateVariable(enabled=True,
        doc="Number of samples it has been trained on")

    trained_dataset = StateVariable(enabled=False,
        doc="The dataset it has been trained on")

    training_confusion = StateVariable(enabled=False,
        doc="Confusion matrix of learning performance")

    predictions = StateVariable(enabled=True,
        doc="Most recent set of predictions")

    values = StateVariable(enabled=True,
        doc="Internal classifier values the most recent "
            "predictions are based on")

    training_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took the classifier to train")

    predicting_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took the classifier to predict")

    feature_ids = StateVariable(enabled=False,
        doc="Feature IDs which were used for the actual training.")

    _clf_internals = []
    """Describes some specifics about the classifier -- e.g. whether it is
    doing regression..."""

    regression = Parameter(False, allowedtype='bool',
        doc="""Whether to use 'regression' as regression. By default any
        Classifier-derived class serves as a classifier, so regression
        does binary classification.""", index=1001)

    # TODO: make it available only for actually retrainable classifiers
    retrainable = Parameter(False, allowedtype='bool',
        doc="""Whether to enable retraining for a 'retrainable' classifier.""",
        index=1002)

    def __init__(self, **kwargs):
        """Cheap initialization.
        """
        ClassWithCollections.__init__(self, **kwargs)


        self.__trainednfeatures = None
        """Stores number of features for which classifier was trained.
        If None -- it wasn't trained at all"""

        self._setRetrainable(self.params.retrainable, force=True)

        if self.params.regression:
            for statevar in [ "trained_labels"]: #, "training_confusion" ]:
                if self.states.isEnabled(statevar):
                    if __debug__:
                        debug("CLF",
                              "Disabling state %s since doing regression, " %
                              statevar + "not classification")
                    self.states.disable(statevar)
            self._summaryClass = RegressionStatistics
        else:
            self._summaryClass = ConfusionMatrix
            clf_internals = self._clf_internals
            if 'regression' in clf_internals and not ('binary' in clf_internals):
                # regressions are used as binary classifiers if not
                # asked to perform regression explicitly
                # We need a copy of the list, so we don't override class-wide
                self._clf_internals = clf_internals + ['binary']

        # deprecate
        #self.__trainedidhash = None
        #"""Stores id of the dataset on which it was trained to signal
        #in trained() if it was trained already on the same dataset"""

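    # Hedged illustration of the constructor's regression switch (classifier
    # names are hypothetical): with regression=True the summary statistics
    # class becomes RegressionStatistics and label-related states are
    # disabled; otherwise a ConfusionMatrix is used:
    #
    #     clf = SomeRegressionCapableClf(regression=True)
    #     assert clf._summaryClass is RegressionStatistics
    #     clf2 = SomeRegressionCapableClf()      # plain classification
    #     assert clf2._summaryClass is ConfusionMatrix
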
    def __str__(self):
        if __debug__ and 'CLF_' in debug.active:
            return "%s / %s" % (repr(self), super(Classifier, self).__str__())
        else:
            return repr(self)

    def __repr__(self, prefixes=[]):
        return super(Classifier, self).__repr__(prefixes=prefixes)


    def _pretrain(self, dataset):
        """Functionality prior to training
        """
        # So we reset all state variables and maybe free up some memory
        # explicitly
        params = self.params
        if not params.retrainable:
            self.untrain()
        else:
            # just reset the states, do not untrain
            self.states.reset()
            if not self.__changedData_isset:
                self.__resetChangedData()
                _changedData = self._changedData
                __idhashes = self.__idhashes
                __invalidatedChangedData = self.__invalidatedChangedData

                # if we don't know what was changed we need to figure
                # it out
                if __debug__:
                    debug('CLF_', "IDHashes are %s" % (__idhashes))

                # Look at the data to see if any of it was changed
                for key, data_ in (('traindata', dataset.samples),
                                   ('labels', dataset.labels)):
                    _changedData[key] = self.__wasDataChanged(key, data_)
                    # if those idhashes were invalidated by retraining
                    # we need to adjust _changedData accordingly
                    if __invalidatedChangedData.get(key, False):
                        if __debug__ and not _changedData[key]:
                            debug('CLF_', 'Found that idhash for %s was '
                                  'invalidated by retraining' % key)
                        _changedData[key] = True

                # Look at the parameters
                for col in self._paramscols:
                    changedParams = self._collections[col].whichSet()
                    if len(changedParams):
                        _changedData[col] = changedParams

            self.__invalidatedChangedData = {} # reset it on training

            if __debug__:
                debug('CLF_', "Obtained _changedData is %s"
                      % (self._changedData))

        if not params.regression and 'regression' in self._clf_internals \
           and not self.states.isEnabled('trained_labels'):
            # if a classifier internally does regression we need to have
            # the labels it was trained on
            if __debug__:
                debug("CLF", "Enabling trained_labels state since it is needed")
            self.states.enable('trained_labels')


    def _posttrain(self, dataset):
        """Functionality post training

        For instance -- computing confusion matrix
        :Parameters:
          dataset : Dataset
            Data which was used for training
        """
        if self.states.isEnabled('trained_labels'):
            self.trained_labels = dataset.uniquelabels

        self.trained_dataset = dataset
        self.trained_nsamples = dataset.nsamples

        # needs to be assigned first since below we use predict
        self.__trainednfeatures = dataset.nfeatures

        if __debug__ and 'CHECK_TRAINED' in debug.active:
            self.__trainedidhash = dataset.idhash

        if self.states.isEnabled('training_confusion') and \
               not self.states.isSet('training_confusion'):
            # we should not store predictions for training data,
            # it is confusing imho (yoh)
            self.states._changeTemporarily(
                disable_states=["predictions"])
            if self.params.retrainable:
                # we would need to recheck if data is the same,
                # XXX think if there is a way to make this all
                # efficient. For now, probably, retrainable
                # classifiers have no choice but not to use
                # training_confusion... sad
                self.__changedData_isset = False
            predictions = self.predict(dataset.samples)
            self.states._resetEnabledTemporarily()
            self.training_confusion = self._summaryClass(
                targets=dataset.labels,
                predictions=predictions)

            try:
                self.training_confusion.labels_map = dataset.labels_map
            except:
                pass

        if self.states.isEnabled('feature_ids'):
            self.feature_ids = self._getFeatureIds()


    def _getFeatureIds(self):
        """Virtual method to return feature_ids used while training

        Is not intended to be called anywhere but from _posttrain,
        thus the classifier is assumed to be trained at this point
        """
        # By default all features are used
        return range(self.__trainednfeatures)


    def summary(self):
        """Provide a summary of the classifier"""

        s = "Classifier %s" % self
        states = self.states
        states_enabled = states.enabled

        if self.trained:
            s += "\n trained"
            if states.isSet('training_time'):
                s += ' in %.3g sec' % states.training_time
            s += ' on data with'
            if states.isSet('trained_labels'):
                s += ' labels:%s' % list(states.trained_labels)

            nsamples, nchunks = None, None
            if states.isSet('trained_nsamples'):
                nsamples = states.trained_nsamples
            if states.isSet('trained_dataset'):
                td = states.trained_dataset
                nsamples, nchunks = td.nsamples, len(td.uniquechunks)
            if nsamples is not None:
                s += ' #samples:%d' % nsamples
            if nchunks is not None:
                s += ' #chunks:%d' % nchunks

            s += " #features:%d" % self.__trainednfeatures
            if states.isSet('feature_ids'):
                s += ", used #features:%d" % len(states.feature_ids)
            if states.isSet('training_confusion'):
                s += ", training error:%.3g" % states.training_confusion.error
        else:
            s += "\n not yet trained"

        if len(states_enabled):
            s += "\n enabled states:%s" % ', '.join([str(states[x])
                                                     for x in states_enabled])
        return s


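    # The kind of string summary() assembles, sketched from the code above
    # (values below are illustrative, not real output):
    #
    #     >>> print clf.summary()
    #     Classifier SMLR(...)
    #      trained in 0.0424 sec on data with labels:[0, 1] #samples:40 #chunks:8 #features:10
    #      enabled states:predictions, values, training_time, predicting_time
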
    def clone(self):
        """Create a full copy of the classifier.

        It might require the classifier to be untrained first due to
        present SWIG bindings.

        TODO: think about a proper re-implementation, without resorting
        to deepcopy
        """
        if __debug__:
            debug("CLF", "Cloning %s#%s" % (self, id(self)))
        try:
            return deepcopy(self)
        except:
            self.untrain()
            return deepcopy(self)


    def _train(self, dataset):
        """Function to be actually overridden in derived classes
        """
        raise NotImplementedError


    def train(self, dataset):
        """Train classifier on a dataset

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so
        """
        if dataset.nfeatures == 0 or dataset.nsamples == 0:
            raise DegenerateInputError, \
                  "Cannot train classifier on degenerate data %s" % dataset
        if __debug__:
            debug("CLF", "Training classifier %(clf)s on dataset %(dataset)s",
                  msgargs={'clf':self, 'dataset':dataset})

        self._pretrain(dataset)

        # remember the time when training started
        t0 = time.time()

        if dataset.nfeatures > 0:
            result = self._train(dataset)
        else:
            warning("Trying to train on dataset with no features present")
            if __debug__:
                debug("CLF",
                      "No features present for training, no actual training "
                      "is called")
            result = None

        self.training_time = time.time() - t0
        self._posttrain(dataset)
        return result


    def _prepredict(self, data):
        """Functionality prior to prediction
        """
        if not ('notrain2predict' in self._clf_internals):
            # check if classifier was trained, if that is needed
            if not self.trained:
                raise ValueError, \
                      "Classifier %s wasn't yet trained, therefore can't " \
                      "predict" % self
            nfeatures = data.shape[1]
            # check if number of features is the same as in the data
            # it was trained on
            if nfeatures != self.__trainednfeatures:
                raise ValueError, \
                      "Classifier %s was trained on data with %d features, " % \
                      (self, self.__trainednfeatures) + \
                      "thus can't predict for %d features" % nfeatures


        if self.params.retrainable:
            if not self.__changedData_isset:
                self.__resetChangedData()
                _changedData = self._changedData
                _changedData['testdata'] = \
                                        self.__wasDataChanged('testdata', data)
                if __debug__:
                    debug('CLF_', "prepredict: Obtained _changedData is %s"
                          % (_changedData))


    def _postpredict(self, data, result):
        """Functionality after prediction is computed
        """
        self.predictions = result
        if self.params.retrainable:
            self.__changedData_isset = False

    def _predict(self, data):
        """Actual prediction
        """
        raise NotImplementedError


    def predict(self, data):
        """Predict classifier on data

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so. Also, subclasses needing the superclass's prediction
        from within their own _predict should call _predict instead of
        predict(), since the latter would recurse infinitely
        """
        data = N.asarray(data)
        if __debug__:
            debug("CLF", "Predicting classifier %(clf)s on data %(data)s",
                  msgargs={'clf':self, 'data':data.shape})

        # remember the time when prediction started
        t0 = time.time()

        states = self.states
        # to assure that those are reset (could be set due to testing
        # post-training)
        states.reset(['values', 'predictions'])

        self._prepredict(data)

        if self.__trainednfeatures > 0 \
               or 'notrain2predict' in self._clf_internals:
            result = self._predict(data)
        else:
            warning("Trying to predict using classifier trained on no features")
            if __debug__:
                debug("CLF",
                      "No features were present for training, prediction is "
                      "bogus")
            result = [None]*data.shape[0]

        states.predicting_time = time.time() - t0

        if 'regression' in self._clf_internals and not self.params.regression:
            # We need to convert regression values into labels
            # XXX maybe unify labels -> internal_labels conversion.
            #if len(self.trained_labels) != 2:
            #    raise RuntimeError, "Ask developer to implement for " \
            #        "multiclass mapping from regression into classification"

            # must be N.array so we copy it to assign labels directly
            # into labels, or should we just recreate "result"???
            result_ = N.array(result)
            if states.isEnabled('values'):
                # values could be set by now, so assigning 'result' would
                # be misleading
                if not states.isSet('values'):
                    states.values = result_.copy()
                else:
                    # it might be that the values are pointing to result at
                    # the moment, so lets assure in this silly way that
                    # they do not overlap
                    states.values = N.array(states.values, copy=True)

            trained_labels = self.trained_labels
            for i, value in enumerate(result):
                dists = N.abs(value - trained_labels)
                result[i] = trained_labels[N.argmin(dists)]

            if __debug__:
                debug("CLF_", "Converted regression result %(result_)s "
                      "into labels %(result)s for %(self_)s",
                      msgargs={'result_':result_, 'result':result,
                               'self_': self})

        self._postpredict(data, result)
        return result

    # deprecate ???
    def isTrained(self, dataset=None):
        """Whether the classifier was already trained.

        MUST BE USED WITH CARE IF EVER"""
        if dataset is None:
            # simply return if it was trained on anything
            return not self.__trainednfeatures is None
        else:
            res = (self.__trainednfeatures == dataset.nfeatures)
            if __debug__ and 'CHECK_TRAINED' in debug.active:
                res2 = (self.__trainedidhash == dataset.idhash)
                if res2 != res:
                    raise RuntimeError, \
                          "isTrained is weak and shouldn't be relied upon. " \
                          "Got result %s although comparing of idhash says %s" \
                          % (res, res2)
            return res


    def _regressionIsBogus(self):
        """Some classifiers like BinaryClassifier can't be used for
        regression"""

        if self.params.regression:
            raise ValueError, "Regression mode is meaningless for %s" % \
                  self.__class__.__name__ + " thus don't enable it"


    @property
    def trained(self):
        """Whether the classifier was already trained"""
        return self.isTrained()

    def untrain(self):
        """Reset trained state"""
        self.__trainednfeatures = None
        # probably not needed... retrainable shouldn't be fully untrained
        # or should it be???
        #if self.params.retrainable:
        #    # ??? don't duplicate the code ;-)
        #    self.__idhashes = {'traindata': None, 'labels': None,
        #                       'testdata': None, 'testtraindata': None}
        super(Classifier, self).reset()


    def getSensitivityAnalyzer(self, **kwargs):
        """Factory method to return an appropriate sensitivity analyzer for
        the respective classifier."""
        raise NotImplementedError


    #
    # Methods which are needed for retrainable classifiers
    #
    def _setRetrainable(self, value, force=False):
        """Assign value of the retrainable parameter

        If the retrainable flag is to be changed, the classifier has to be
        untrained. Also, internal attributes such as _changedData,
        __changedData_isset, and __idhashes should be initialized if
        it becomes retrainable
        """
        pretrainable = self.params['retrainable']
        if (force or value != pretrainable.value) \
               and 'retrainable' in self._clf_internals:
            if __debug__:
                debug("CLF_", "Setting retrainable to %s" % value)
            if 'meta' in self._clf_internals:
                warning("Retrainability is not yet crafted/tested for "
                        "meta classifiers. Unpredictable behavior might occur")
            # assure that we don't drag anything behind
            if self.trained:
                self.untrain()
            states = self.states
            if not value and states.isKnown('retrained'):
                states.remove('retrained')
                states.remove('repredicted')
            if value:
                if not 'retrainable' in self._clf_internals:
                    warning("Setting of flag retrainable for %s has no effect"
                            " since classifier has no such capability. It would"
                            " just lead to resources consumption and slowdown"
                            % self)
                states.add(StateVariable(enabled=True,
                        name='retrained',
                        doc="Whether the retrainable classifier was retrained"))
                states.add(StateVariable(enabled=True,
                        name='repredicted',
                        doc="Whether the retrainable classifier was repredicted"))

            pretrainable.value = value

            # if retrainable we need to keep track of things
            if value:
                self.__idhashes = {'traindata': None, 'labels': None,
                                   'testdata': None} #, 'testtraindata': None}
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    # ??? it is not clear though if idhash is faster than
                    # simple comparison of (dataset != __traineddataset).any(),
                    # but if we like to get rid of __traineddataset then we
                    # should use idhash anyways
                    self.__trained = self.__idhashes.copy() # just same Nones
                self.__resetChangedData()
                self.__invalidatedChangedData = {}
            elif 'retrainable' in self._clf_internals:
                #self.__resetChangedData()
                self.__changedData_isset = False
                self._changedData = None
                self.__idhashes = None
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    self.__trained = None

    def __resetChangedData(self):
        """For a retrainable classifier we keep track of what was changed.
        This function resets that dictionary
        """
        if __debug__:
            debug('CLF_',
                  'Retrainable: resetting flags on whether data was changed')
        keys = self.__idhashes.keys() + self._paramscols
        # we might like to just reinit values to False???
        #_changedData = self._changedData
        #if isinstance(_changedData, dict):
        #    for key in _changedData.keys():
        #        _changedData[key] = False
        self._changedData = dict(zip(keys, [False]*len(keys)))
        self.__changedData_isset = False


    def __wasDataChanged(self, key, entry, update=True):
        """Check if a given entry was changed from what was known prior.

        If so -- store only the ones needed for the retrainable beastie
        """
        idhash_ = idhash(entry)
        __idhashes = self.__idhashes

        changed = __idhashes[key] != idhash_
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            __trained = self.__trained
            changed2 = entry != __trained[key]
            if isinstance(changed2, N.ndarray):
                changed2 = changed2.any()
            if changed != changed2 and not changed:
                raise RuntimeError, \
                  'idhash found to be weak for %s. Though hashid %s!=%s %s, '\
                  'values %s!=%s %s' % \
                  (key, idhash_, __idhashes[key], changed,
                   entry, __trained[key], changed2)
            if update:
                __trained[key] = entry

        if __debug__ and changed:
            debug('CLF_', "Changed %s from %s to %s.%s"
                  % (key, __idhashes[key], idhash_,
                     ('','updated')[int(update)]))
        if update:
            __idhashes[key] = idhash_

        return changed


    # def __updateHashIds(self, key, data):
    #     """Is a twofold operation: updates hashid if it was said that it
    #     changed.
    #
    #     Or, if it wasn't said that the data changed, but CHECK_RETRAIN is
    #     active and it is found to be changed -- raise Exception
    #     """
    #
    #     check_retrain = __debug__ and 'CHECK_RETRAIN' in debug.active
    #     chd = self._changedData
    #
    #     # we need to update idhashes
    #     if chd[key] or check_retrain:
    #         keychanged = self.__wasDataChanged(key, data)
    #         if check_retrain and keychanged and not chd[key]:
    #             raise RuntimeError, \
    #                   "Data %s found changed although wasn't " \
    #                   "labeled as such" % key


    #
    # Additional API which is specific only to retrainable classifiers.
    # For now it would just puke if asked from a non-retrainable one.
    #
    # Might come useful and efficient for statistics testing, so if just
    # labels of a dataset changed, then
    #   self.retrain(dataset, labels=True)
    # would cause efficient retraining (no kernels recomputed etc)
    # and subsequent self.repredict(data) should also be quite fast ;-)
    #
    def retrain(self, dataset, **kwargs):
        """Helper to avoid checking whether data has actually changed

        Useful if just some aspects of the classifier were changed since
        its previous training. For instance, if the dataset wasn't changed
        but only classifier parameters, then the kernel matrix does not
        have to be computed.

        Words of caution: the classifier must be previously trained, and
        results should always first be compared to the results of a not
        'retrainable' classifier (without calling retrain). Some
        additional checks are enabled if debug id 'CHECK_RETRAIN' is
        enabled, to guard against obvious mistakes.

        :Parameters:
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        # Note that it also demolishes anything for repredicting,
        # which should be ok in most of the cases
        if __debug__:
            if not self.params.retrainable:
                raise RuntimeError, \
                      "Do not use re(train,predict) on non-retrainable %s" % \
                      self

        if kwargs.has_key('params') or kwargs.has_key('kernel_params'):
            raise ValueError, \
                  "Retraining for changed params is not working yet"

        self.__resetChangedData()

        # local bindings
        chd = self._changedData
        ichd = self.__invalidatedChangedData

        chd.update(kwargs)
        # mark for future 'train()' items which are explicitly
        # mentioned as changed
        for key, value in kwargs.iteritems():
            if value:
                ichd[key] = True
        self.__changedData_isset = True

        # To check that we are not fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                # so it wasn't told to be invalid
                if not chd[key] and not ichd.get(key, False):
                    if self.__wasDataChanged(key, data_, update=False):
                        raise RuntimeError, \
                              "Data %s found changed although it wasn't " \
                              "labeled as such" % key

        # TODO: parameters of classifiers... for now there is an explicit
        # 'forbidance' above

        # Below check should be superseded by the check above, thus never occur.
        # remove later on ???
        if __debug__ and 'CHECK_RETRAIN' in debug.active and self.trained \
               and not self._changedData['traindata'] \
               and self.__trained['traindata'].shape != dataset.samples.shape:
            raise ValueError, "In retrain got dataset with %s size, " \
                  "whereas previously it was trained on %s size" \
                  % (dataset.samples.shape, self.__trained['traindata'].shape)
        self.train(dataset)


    def repredict(self, data, **kwargs):
        """Helper to avoid checking whether data has actually changed

        Useful if the classifier was (re)trained on the same data
        (so just parameters were changed), so that it can be
        repredicted easily (on the same data as before) without
        recomputing, for instance, the train/test kernel matrix. Should be
        used with caution and always compared to the results of a not
        'retrainable' classifier. Some additional checks are enabled
        if debug id 'CHECK_RETRAIN' is enabled, to guard against
        obvious mistakes.

        :Parameters:
          data
            data which is conventionally given to predict
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        if len(kwargs)>0:
            raise RuntimeError, \
                  "repredict for now should be used without params since " \
                  "it makes little sense to repredict if anything got changed"
        if __debug__ and not self.params.retrainable:
            raise RuntimeError, \
                  "Do not use retrain/repredict on non-retrainable classifiers"

        self.__resetChangedData()
        chd = self._changedData
        chd.update(**kwargs)
        self.__changedData_isset = True


        # check if we are attempting to predict on the same data
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('testdata', data),):
                # so it wasn't told to be invalid
                #if not chd[key]:# and not ichd.get(key, False):
                if self.__wasDataChanged(key, data_, update=False):
                    raise RuntimeError, \
                          "Data %s found changed although it wasn't " \
                          "labeled as such" % key

        # Should be superseded by the check above
        # remove in future???
        if __debug__ and 'CHECK_RETRAIN' in debug.active \
               and not self._changedData['testdata'] \
               and self.__trained['testdata'].shape != data.shape:
            raise ValueError, "In repredict got dataset with %s size, " \
                  "whereas previously it was trained on %s size" \
                  % (data.shape, self.__trained['testdata'].shape)

        return self.predict(data)


    # TODO: callback into retrainable parameter
    #retrainable = property(fget=_getRetrainable, fset=_setRetrainable,
    #    doc="Specifies whether the classifier should be retrainable")