1
2
3
4
5
6
7
8
"""Cross-validate a classifier on a dataset.

Provides a dataset measure that applies a TransferError to every split
generated by a Splitter and aggregates the per-fold results with a combiner.
"""

__docformat__ = "restructuredtext"
12
13 from copy import copy
14
15 from mvpa.measures.base import DatasetMeasure
16 from mvpa.datasets.splitter import NoneSplitter
17 from mvpa.clfs.transerror import ConfusionMatrix
18 from mvpa.misc import warning
19 from mvpa.misc.state import StateVariable, Harvestable
20 from mvpa.misc.transformers import GrandMean
21
22 if __debug__:
23 from mvpa.misc import debug
24
25
26
"""Cross-validate a classifier on datasets produced by a splitter.

A splitter derives the datasets for every cross-validation fold from one
source dataset. Arbitrary performance/error values can be obtained by
supplying an error function, which is evaluated once per fold, together with
a combiner function that aggregates the per-fold values into one result.
"""

# Every state variable below starts out disabled.
results = StateVariable(
    doc="Store individual results in the state",
    enabled=False)
splits = StateVariable(
    doc="Store the actual splits of the data. Can be memory expensive",
    enabled=False)
transerrors = StateVariable(
    doc="Store copies of transerrors at each step",
    enabled=False)
confusion = StateVariable(
    doc="Store total confusion matrix (if available)",
    enabled=False)
training_confusion = StateVariable(
    doc="Store total training confusion matrix (if available)",
    enabled=False)
47
48
def __init__(self,
             transerror,
             splitter=None,
             combiner=GrandMean,
             harvest_attribs=None,
             copy_attribs='copy',
             **kwargs):
    """
    Cheap initialization.

    :Parameters:
      transerror : TransferError instance
        Provides the classifier used for cross-validation.
      splitter : Splitter instance or None
        Used to split the dataset for cross-validation folds. By
        convention the first dataset in the tuple returned by the
        splitter is used to train the provided classifier. If the
        first element is 'None' no training is performed. The second
        dataset is used to generate predictions with the (trained)
        classifier. If None (the default), a new `NoneSplitter` is
        created for this instance.
      combiner : Functor
        Used to aggregate the error values of all cross-validation
        folds.
      harvest_attribs : list of basestr
        What attributes of call to store and return within
        harvested state variable
      copy_attribs : None or basestr
        Force copying values of attributes on harvesting
    """
    DatasetMeasure.__init__(self, **kwargs)
    Harvestable.__init__(self, harvest_attribs, copy_attribs)

    # Build the default splitter per instance. A `NoneSplitter()` default
    # in the signature would be evaluated once, at definition time, and
    # the same object would then be shared by every instance.
    if splitter is None:
        splitter = NoneSplitter()

    self.__splitter = splitter
    self.__transerror = transerror
    self.__combiner = combiner
84
85
86
87
88
89
90
91
92
93
94
95
96
def _call(self, dataset):
    """Perform cross-validation on a dataset.

    `dataset` is passed to the splitter instance and serves as the source
    dataset to generate the split for each single cross-validation fold.
    For every split the TransferError is called with the second dataset as
    testing data and the first as training data. The list of per-fold
    results is then handed to the combiner.

    :Parameters:
      dataset : Dataset
        Source dataset passed to the splitter.

    :Returns:
      The value returned by the combiner for the list of per-fold results.
    """
    results = []
    self.splits = []

    # Collect the confusion-related states that are enabled on this measure.
    # The same states get temporarily enabled on the wrapped TransferError
    # so that their per-fold values can be accumulated below.
    terr_enable = []
    for state_var in ['confusion', 'training_confusion']:
        if self.states.isEnabled(state_var):
            terr_enable += [state_var]

    # Start each call with fresh accumulators.
    self.confusion = ConfusionMatrix()
    self.training_confusion = ConfusionMatrix()
    self.transerrors = []

    if terr_enable:
        self.__transerror.states._changeTemporarily(
            enable_states=terr_enable)

    # The try/finally guarantees that the TransferError's enabled states
    # are restored even when a fold raises. Without it, a failed call
    # would leave the TransferError permanently reconfigured.
    try:
        for split in self.__splitter(dataset):
            if self.states.isEnabled("splits"):
                self.splits.append(split)

            # By convention split[0] is the training and split[1] the
            # testing dataset.
            result = self.__transerror(split[1], split[0])

            # `transerror` is bound as a local on purpose, so that it is
            # part of the locals() passed to _harvest() (harvest_attribs
            # are resolved against them).
            transerror = self.__transerror
            self._harvest(locals())

            if self.states.isEnabled("transerrors"):
                # NOTE(review): this is a shallow copy. Confirm that the
                # copies do not share state that later folds overwrite.
                self.transerrors.append(copy(self.__transerror))

            # Accumulate this fold's confusion matrices into the totals.
            for state_var in ['confusion', 'training_confusion']:
                if self.states.isEnabled(state_var):
                    self.states.getvalue(state_var).__iadd__(
                        self.__transerror.states.getvalue(state_var))

            if __debug__:
                debug("CROSSC", "Split #%d: result %s"
                      % (len(results), repr(result)))
            results.append(result)
    finally:
        if terr_enable:
            self.__transerror.states._resetEnabledTemporarily()

    # Store the per-fold results. Per the original note, they are kept
    # only if the 'results' state is enabled.
    self.results = results

    return self.__combiner(results)
160
161
@property
def splitter(self):
    """Splitter that generates the cross-validation folds."""
    return self.__splitter

@property
def transerror(self):
    """TransferError applied to every fold."""
    return self.__transerror

@property
def combiner(self):
    """Functor that aggregates the per-fold results."""
    return self.__combiner
165