1
2
3
4
5
6
7
8
9 """Cross-validate a classifier on a dataset"""
10
11 __docformat__ = 'restructuredtext'
12
13 from mvpa.misc.copy import deepcopy
14
15 from mvpa.measures.base import DatasetMeasure
16 from mvpa.datasets.splitter import NoneSplitter
17 from mvpa.base import warning
18 from mvpa.misc.state import StateVariable, Harvestable
19 from mvpa.misc.transformers import GrandMean
20
21 if __debug__:
22 from mvpa.base import debug
23
24
    """Cross validate a classifier on datasets generated by a splitter from a
    source dataset.

    Arbitrary performance/error values can be computed by specifying an error
    function (used to compute an error value for each cross-validation fold)
    and a combiner function that aggregates all computed error values across
    cross-validation folds.
    """

    # State variables below are all disabled by default; they are populated by
    # _call() only when explicitly enabled, since several of them (splits,
    # transerrors) can be memory expensive.
    results = StateVariable(enabled=False, doc=
       """Store individual results in the state""")
    splits = StateVariable(enabled=False, doc=
       """Store the actual splits of the data. Can be memory expensive""")
    transerrors = StateVariable(enabled=False, doc=
       """Store copies of transerrors at each step""")
    confusion = StateVariable(enabled=False, doc=
       """Store total confusion matrix (if available)""")
    training_confusion = StateVariable(enabled=False, doc=
       """Store total training confusion matrix (if available)""")
    samples_error = StateVariable(enabled=False,
                        doc="Per sample errors.")
48
49 - def __init__(self,
50 transerror,
51 splitter=NoneSplitter(),
52 combiner=GrandMean,
53 expose_testdataset=False,
54 harvest_attribs=None,
55 copy_attribs='copy',
56 **kwargs):
57 """
58 Cheap initialization.
59
60 :Parameters:
61 transerror : TransferError instance
62 Provides the classifier used for cross-validation.
63 splitter : Splitter instance
64 Used to split the dataset for cross-validation folds. By
65 convention the first dataset in the tuple returned by the
66 splitter is used to train the provided classifier. If the
67 first element is 'None' no training is performed. The second
68 dataset is used to generate predictions with the (trained)
69 classifier.
70 combiner : Functor
71 Used to aggregate the error values of all cross-validation
72 folds.
73 expose_testdataset : bool
74 In the proper pipeline, classifier must not know anything
75 about testing data, but in some cases it might lead only
76 to marginal harm, thus migth wanted to be enabled (provide
77 testdataset for RFE to determine stopping point).
78 harvest_attribs : list of basestr
79 What attributes of call to store and return within
80 harvested state variable
81 copy_attribs : None or basestr
82 Force copying values of attributes on harvesting
83 """
84 DatasetMeasure.__init__(self, **kwargs)
85 Harvestable.__init__(self, harvest_attribs, copy_attribs)
86
87 self.__splitter = splitter
88 self.__transerror = transerror
89 self.__combiner = combiner
90 self.__expose_testdataset = expose_testdataset
91
92
93
94
95
96
97
98
99
100
101
102
103
    def _call(self, dataset):
        """Perform cross-validation on a dataset.

        'dataset' is passed to the splitter instance and serves as the source
        dataset to generate split for the single cross-validation folds.

        Returns the combiner's aggregate of the per-fold results; also fills
        whichever state variables (results, splits, transerrors, confusion,
        training_confusion, samples_error) are currently enabled.
        """
        # Per-fold results accumulated over all splits.
        results = []
        self.splits = []

        # Local aliases to avoid repeated attribute lookups in the fold loop.
        states = self.states
        clf = self.__transerror.clf
        expose_testdataset = self.__expose_testdataset

        # These states are computed by the TransferError itself; enable them
        # there (temporarily) only if the caller enabled them on us.
        terr_enable = []
        for state_var in ['confusion', 'training_confusion', 'samples_error']:
            if states.isEnabled(state_var):
                terr_enable += [state_var]

        # Summary/confusion container class is taken from the classifier.
        summaryClass = clf._summaryClass
        clf_hastestdataset = hasattr(clf, 'testdataset')

        # Reset accumulators unconditionally so stale values from a previous
        # call never leak through, even for disabled states.
        self.confusion = summaryClass()
        self.training_confusion = summaryClass()
        self.transerrors = []
        # NOTE(review): the comprehension variable shadows the builtin `id`;
        # keys are the dataset's sample origids, each mapped to a list of
        # per-fold errors for that sample.
        self.samples_error = dict([(id, []) for id in dataset.origids])

        # Temporarily enable the required states on the transfer error; they
        # are restored after the loop via _resetEnabledTemporarily().
        if len(terr_enable):
            self.__transerror.states._changeTemporarily(
                enable_states=terr_enable)

        # Main cross-validation loop: one iteration per split produced by the
        # splitter. By convention split[0] trains, split[1] tests.
        for split in self.__splitter(dataset):

            if states.isEnabled("splits"):
                self.splits.append(split)

            if states.isEnabled("transerrors"):
                # Keep an independent snapshot of the transfer error (and the
                # classifier inside it) for this fold.
                transerror = deepcopy(self.__transerror)
            else:
                transerror = self.__transerror

            # Optionally expose the testing dataset to the classifier (e.g.
            # so RFE can determine its stopping point); see __init__ docs.
            if clf_hastestdataset and expose_testdataset:
                clf.testdataset = split[1]
                pass

            # Train on split[0] (unless it is None) and test on split[1].
            result = transerror(split[1], split[0])

            # Withdraw the testing dataset again to avoid leakage beyond
            # this fold.
            if clf_hastestdataset and expose_testdataset:
                clf.testdataset = None

            # Harvest requested attributes from this call's local namespace
            # (Harvestable machinery).
            self._harvest(locals())

            if states.isEnabled("transerrors"):
                self.transerrors.append(transerror)

            # Merge per-sample errors of this fold into the per-origid lists.
            if states.isEnabled("samples_error"):
                for k, v in \
                  transerror.states.getvalue("samples_error").iteritems():
                    self.samples_error[k].append(v)

            # Accumulate fold confusion matrices into the totals via
            # in-place addition on the summary objects.
            for state_var in ['confusion', 'training_confusion']:
                if states.isEnabled(state_var):
                    states.getvalue(state_var).__iadd__(
                        transerror.states.getvalue(state_var))

            if __debug__:
                debug("CROSSC", "Split #%d: result %s" \
                      % (len(results), `result`))
            results.append(result)

        # NOTE(review): after the loop self.__transerror is rebound to the
        # last fold's transerror -- which is a deepcopy when "transerrors"
        # was enabled. Presumably intentional (keep the last trained state);
        # confirm before changing.
        self.__transerror = transerror

        # Restore the temporarily enabled states on the transfer error.
        if len(terr_enable):
            self.__transerror.states._resetEnabledTemporarily()

        self.results = results
        """Store state variable if it is enabled"""

        # Best-effort propagation of the dataset's labels mapping into the
        # confusion summaries; the bare except deliberately swallows any
        # failure (e.g. dataset without labels_map).
        try:
            if states.isEnabled("confusion"):
                states.confusion.labels_map = dataset.labels_map
            if states.isEnabled("training_confusion"):
                states.training_confusion.labels_map = dataset.labels_map
        except:
            pass

        # Aggregate per-fold results into a single return value.
        return self.__combiner(results)
212
213
    # Read-only accessors for the collaborators supplied at construction.
    # NOTE(review): `transerror` may return a per-fold deepcopy after _call()
    # ran with the "transerrors" state enabled, since _call rebinds
    # self.__transerror to the last fold's instance.
    splitter = property(fget=lambda self:self.__splitter)
    transerror = property(fget=lambda self:self.__transerror)
    combiner = property(fget=lambda self:self.__combiner)
217