Package mvpa :: Package algorithms :: Module cvtranserror
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.algorithms.cvtranserror

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Cross-validate a classifier on a dataset""" 
 10   
 11  __docformat__ = 'restructuredtext' 
 12   
 13  from mvpa.misc.copy import deepcopy 
 14   
 15  from mvpa.measures.base import DatasetMeasure 
 16  from mvpa.datasets.splitter import NoneSplitter 
 17  from mvpa.base import warning 
 18  from mvpa.misc.state import StateVariable, Harvestable 
 19  from mvpa.misc.transformers import GrandMean 
 20   
 21  if __debug__: 
 22      from mvpa.base import debug 
 23   
 24   
class CrossValidatedTransferError(DatasetMeasure, Harvestable):
    """Cross validate a classifier on datasets generated by a splitter from a
    source dataset.

    Arbitrary performance/error values can be computed by specifying an error
    function (used to compute an error value for each cross-validation fold)
    and a combiner function that aggregates all computed error values across
    cross-validation folds.
    """

    results = StateVariable(enabled=False, doc=
       """Store individual results in the state""")
    splits = StateVariable(enabled=False, doc=
       """Store the actual splits of the data. Can be memory expensive""")
    transerrors = StateVariable(enabled=False, doc=
       """Store copies of transerrors at each step""")
    confusion = StateVariable(enabled=False, doc=
       """Store total confusion matrix (if available)""")
    training_confusion = StateVariable(enabled=False, doc=
       """Store total training confusion matrix (if available)""")
    samples_error = StateVariable(enabled=False,
                                  doc="Per sample errors.")


    def __init__(self,
                 transerror,
                 splitter=None,
                 combiner=GrandMean,
                 expose_testdataset=False,
                 harvest_attribs=None,
                 copy_attribs='copy',
                 **kwargs):
        """
        Cheap initialization.

        :Parameters:
          transerror : TransferError instance
            Provides the classifier used for cross-validation.
          splitter : None or Splitter instance
            Used to split the dataset for cross-validation folds. By
            convention the first dataset in the tuple returned by the
            splitter is used to train the provided classifier. If the
            first element is 'None' no training is performed. The second
            dataset is used to generate predictions with the (trained)
            classifier. If None (default), a fresh `NoneSplitter` is
            created for this instance.
          combiner : Functor
            Used to aggregate the error values of all cross-validation
            folds.
          expose_testdataset : bool
            In the proper pipeline, classifier must not know anything
            about testing data, but in some cases it might lead only
            to marginal harm, thus might be wanted to be enabled (provide
            testdataset for RFE to determine stopping point).
          harvest_attribs : list of basestr
            What attributes of call to store and return within
            harvested state variable
          copy_attribs : None or basestr
            Force copying values of attributes on harvesting
        """
        DatasetMeasure.__init__(self, **kwargs)
        Harvestable.__init__(self, harvest_attribs, copy_attribs)

        # Create the default splitter per instance instead of sharing a
        # single NoneSplitter() evaluated once at definition time
        # (mutable-default-argument pitfall).
        self.__splitter = NoneSplitter() if splitter is None else splitter
        self.__transerror = transerror
        self.__combiner = combiner
        self.__expose_testdataset = expose_testdataset

    # TODO: put back in ASAP
    #def __repr__(self):
    #    """String summary over the object
    #    """
    #    return """CrossValidatedTransferError /
    # splitter: %s
    # classifier: %s
    # errorfx: %s
    # combiner: %s""" % (indentDoc(self.__splitter), indentDoc(self.__clf),
    #                    indentDoc(self.__errorfx), indentDoc(self.__combiner))


    def _call(self, dataset):
        """Perform cross-validation on a dataset.

        'dataset' is passed to the splitter instance and serves as the source
        dataset to generate split for the single cross-validation folds.

        Returns the combiner's aggregate over the per-fold error values.
        """
        # store the results of the splitprocessor
        results = []
        self.splits = []

        # local bindings
        states = self.states
        clf = self.__transerror.clf
        expose_testdataset = self.__expose_testdataset

        # what states to enable in terr
        terr_enable = []
        for state_var in ['confusion', 'training_confusion', 'samples_error']:
            if states.isEnabled(state_var):
                terr_enable += [state_var]

        # charge states with initial values
        summaryClass = clf._summaryClass
        clf_hastestdataset = hasattr(clf, 'testdataset')

        self.confusion = summaryClass()
        self.training_confusion = summaryClass()
        self.transerrors = []
        # one error list per sample origid ('sample_id' avoids shadowing
        # the builtin 'id')
        self.samples_error = dict(
            [(sample_id, []) for sample_id in dataset.origids])

        # enable requested states in child TransferError instance (restored
        # again below)
        if len(terr_enable):
            self.__transerror.states._changeTemporarily(
                enable_states=terr_enable)

        # guard against an UnboundLocalError below should the splitter yield
        # no splits at all
        transerror = self.__transerror

        # splitter
        for split in self.__splitter(dataset):
            # only train classifier if splitter provides something in first
            # element of tuple -- this is the behavior of TransferError
            if states.isEnabled("splits"):
                self.splits.append(split)

            if states.isEnabled("transerrors"):
                # copy first and then train, as some classifiers cannot be
                # copied when already trained, e.g. SWIG'ed stuff
                transerror = deepcopy(self.__transerror)
            else:
                transerror = self.__transerror

            # assign testing dataset if given classifier can digest it
            # NOTE(review): when 'transerrors' is enabled this assigns to the
            # *original* classifier, not the deep copy used for this fold --
            # confirm whether that is intended
            if clf_hastestdataset and expose_testdataset:
                clf.testdataset = split[1]

            # run the beast
            result = transerror(split[1], split[0])

            # unbind the testdataset from the classifier
            if clf_hastestdataset and expose_testdataset:
                clf.testdataset = None

            # next line is important for 'self._harvest' call
            self._harvest(locals())

            # XXX Look below -- may be we should have not auto added .?
            #     then transerrors also could be deprecated
            if states.isEnabled("transerrors"):
                self.transerrors.append(transerror)

            # XXX: could be merged with next for loop using a utility class
            #      that can add dict elements into a list
            if states.isEnabled("samples_error"):
                for k, v in \
                    transerror.states.getvalue("samples_error").iteritems():
                    self.samples_error[k].append(v)

            # pull in child states (in-place accumulation into the summary
            # objects charged above)
            for state_var in ['confusion', 'training_confusion']:
                if states.isEnabled(state_var):
                    states.getvalue(state_var).__iadd__(
                        transerror.states.getvalue(state_var))

            if __debug__:
                # repr() instead of Python-2-only backticks
                debug("CROSSC", "Split #%d: result %s"
                      % (len(results), repr(result)))
            results.append(result)

        # Since we could have operated with a copy -- bind the last used one
        # back
        self.__transerror = transerror

        # put states of child TransferError back into original config
        if len(terr_enable):
            self.__transerror.states._resetEnabledTemporarily()

        # store state variable if it is enabled
        self.results = results

        # Provide those labels_map if appropriate; best-effort since not all
        # datasets/summary objects carry labels_map -- but never swallow
        # KeyboardInterrupt/SystemExit as the former bare 'except' did
        try:
            if states.isEnabled("confusion"):
                states.confusion.labels_map = dataset.labels_map
            if states.isEnabled("training_confusion"):
                states.training_confusion.labels_map = dataset.labels_map
        except Exception:
            pass

        return self.__combiner(results)


    splitter = property(fget=lambda self: self.__splitter)
    transerror = property(fget=lambda self: self.__transerror)
    combiner = property(fget=lambda self: self.__combiner)