Package mvpa :: Package algorithms :: Module cvtranserror
[hide private]
[frames] | no frames]

Source Code for Module mvpa.algorithms.cvtranserror

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Cross-validate a classifier on a dataset""" 
 10   
 11  __docformat__ = 'restructuredtext' 
 12   
 13  from copy import copy 
 14   
 15  from mvpa.measures.base import DatasetMeasure 
 16  from mvpa.datasets.splitter import NoneSplitter 
 17  from mvpa.clfs.transerror import ConfusionMatrix 
 18  from mvpa.misc import warning 
 19  from mvpa.misc.state import StateVariable, Harvestable 
 20  from mvpa.misc.transformers import GrandMean 
 21   
 22  if __debug__: 
 23      from mvpa.misc import debug 
 24   
 25   
 26   
class CrossValidatedTransferError(DatasetMeasure, Harvestable):
    """Cross validate a classifier on datasets generated by a splitter from a
    source dataset.

    For every split generated by the splitter, the wrapped `TransferError`
    instance is called (first dataset of a split for training, second one for
    generating predictions). The value it returns for each cross-validation
    fold is collected, and a combiner function aggregates all collected
    values across folds into the final result. Arbitrary performance/error
    values can therefore be computed by choosing the `TransferError` (and
    its error function) and the combiner.
    """

    results = StateVariable(enabled=False, doc=
        """Store individual results in the state""")
    splits = StateVariable(enabled=False, doc=
        """Store the actual splits of the data. Can be memory expensive""")
    transerrors = StateVariable(enabled=False, doc=
        """Store copies of transerrors at each step""")
    confusion = StateVariable(enabled=False, doc=
        """Store total confusion matrix (if available)""")
    training_confusion = StateVariable(enabled=False, doc=
        """Store total training confusion matrix (if available)""")


    def __init__(self,
                 transerror,
                 splitter=None,
                 combiner=GrandMean,
                 harvest_attribs=None,
                 copy_attribs='copy',
                 **kwargs):
        """
        Cheap initialization.

        :Parameters:
          transerror : TransferError instance
            Provides the classifier used for cross-validation.
          splitter : Splitter instance or None
            Used to split the dataset for cross-validation folds. By
            convention the first dataset in the tuple returned by the
            splitter is used to train the provided classifier. If the
            first element is 'None' no training is performed. The second
            dataset is used to generate predictions with the (trained)
            classifier. If None (default), a new `NoneSplitter` is
            created for this instance.
          combiner : Functor
            Used to aggregate the error values of all cross-validation
            folds.
          harvest_attribs : list of basestr
            What attributes of call to store and return within
            harvested state variable
          copy_attribs : None or basestr
            Force copying values of attributes on harvesting
        """
        DatasetMeasure.__init__(self, **kwargs)
        Harvestable.__init__(self, harvest_attribs, copy_attribs)

        # Build the default splitter per instance instead of sharing a single
        # NoneSplitter object that would be created once at class-definition
        # time and reused by every CrossValidatedTransferError.
        if splitter is None:
            splitter = NoneSplitter()

        self.__splitter = splitter
        self.__transerror = transerror
        self.__combiner = combiner

    # TODO: put back in ASAP
    # NOTE(review): the draft below refers to self.__clf and self.__errorfx,
    # which this class no longer defines (only __splitter, __transerror and
    # __combiner are set in __init__), and to indentDoc, which is not imported
    # in this module -- it needs updating before it can be re-enabled.
    # def __repr__(self):
    #     """String summary over the object
    #     """
    #     return """CrossValidatedTransferError /
    # splitter: %s
    # classifier: %s
    # errorfx: %s
    # combiner: %s""" % (indentDoc(self.__splitter), indentDoc(self.__clf),
    #                    indentDoc(self.__errorfx), indentDoc(self.__combiner))


    def _call(self, dataset):
        """Perform cross-validation on a dataset.

        'dataset' is passed to the splitter instance and serves as the source
        dataset to generate split for the single cross-validation folds.

        :Returns:
          The value of the combiner applied to the list of per-fold results
          returned by the `TransferError` instance.
        """
        # store the results of the splitprocessor
        results = []
        self.splits = []

        # which of our aggregated states also need to be enabled in the child
        # TransferError, so that its per-fold values can be summed up here
        terr_enable = [state_var
                       for state_var in ['confusion', 'training_confusion']
                       if self.states.isEnabled(state_var)]

        # charge states with initial values
        self.confusion = ConfusionMatrix()
        self.training_confusion = ConfusionMatrix()
        self.transerrors = []

        # enable requested states in child TransferError instance; they are
        # restored in the ``finally`` clause so the child is left in its
        # original configuration even if one of the folds raises
        if terr_enable:
            self.__transerror.states._changeTemporarily(
                enable_states=terr_enable)
        try:
            for split in self.__splitter(dataset):
                if self.states.isEnabled("splits"):
                    self.splits.append(split)

                # TransferError gets (prediction data, training data); it only
                # trains the classifier if the splitter provides something in
                # the first element of the split tuple -- this is the
                # behavior of TransferError
                result = self.__transerror(split[1], split[0])

                # self._harvest() is handed locals(): the local name
                # 'transerror' is bound here specifically for that call, so
                # keep these local names ('split', 'result', 'transerror', ...)
                # stable
                transerror = self.__transerror
                self._harvest(locals())

                # XXX Look below -- may be we should have not auto added .?
                # then transerrors also could be deprecated
                if self.states.isEnabled("transerrors"):
                    self.transerrors.append(copy(self.__transerror))

                # accumulate the child's per-fold matrices into our totals;
                # this relies on the in-place update of the stored object
                # (the rebound local is not written back into the state)
                for state_var in ['confusion', 'training_confusion']:
                    if self.states.isEnabled(state_var):
                        accumulated = self.states.getvalue(state_var)
                        accumulated += \
                            self.__transerror.states.getvalue(state_var)

                if __debug__:
                    debug("CROSSC", "Split #%d: result %s"
                          % (len(results), repr(result)))
                results.append(result)
        finally:
            # put states of child TransferError back into original config
            if terr_enable:
                self.__transerror.states._resetEnabledTemporarily()

        # store state variable (kept only if the state is enabled)
        self.results = results

        return self.__combiner(results)


    # read-only access to the components given to the constructor
    splitter = property(fget=lambda self: self.__splitter)
    transerror = property(fget=lambda self: self.__transerror)
    combiner = property(fget=lambda self: self.__combiner)
165