Previous topic

Compare SMLR to Linear SVM Classifier

Next topic

Minimal Searchlight Example

This Page

Quick search

Classifier Sweep

This example shows a test of various classifiers on different datasets.

>>> # Pull in the whole PyMVPA suite namespace (N = numpy, clfs, datasets, ...)
>>> from mvpa.suite import *
>>>
>>> # no MVPA warnings during whole testsuite
>>> # (clearing the handler list silences PyMVPA's warning output,
>>> # presumably to keep the example's console report readable)
>>> warning.handlers = []
>>>
>>> def main():
>>>     """Sweep a set of classifiers over several datasets.
>>>
>>>     For each (dataset, classifier) pair an N-fold cross-validation is run
>>>     by hand (rather than via CrossValidatedTransferError) so that training
>>>     and prediction times can be measured per fold; accuracy, mean number
>>>     of selected features and timings are printed as a table.
>>>     """
>>>
>>>     # fix seed or set to None for new each time
>>>     N.random.seed(44)
>>>
>>>
>>>     # Load Haxby dataset example
>>>     # NOTE(review): expects 'bold.nii.gz', 'mask.nii.gz' and
>>>     # 'attributes_literal.txt' under ./data — confirm layout before running
>>>     haxby1path = 'data'
>>>     attrs = SampleAttributes(os.path.join(haxby1path,
>>>                                           'attributes_literal.txt'))
>>>     haxby8 = NiftiDataset(samples=os.path.join(haxby1path,
>>>                                                'bold.nii.gz'),
>>>                           labels=attrs.labels,
>>>                           labels_map=True,
>>>                           chunks=attrs.chunks,
>>>                           mask=os.path.join(haxby1path, 'mask.nii.gz'),
>>>                           dtype=N.float32)
>>>
>>>     # preprocess slightly: per-chunk linear detrend, then z-scoring with
>>>     # the 'rest' condition as baseline, and finally drop 'rest' samples
>>>     rest_label = haxby8.labels_map['rest']
>>>     detrend(haxby8, perchunk=True, model='linear')
>>>     zscore(haxby8, perchunk=True, baselinelabels=[rest_label],
>>>            targetdtype='float32')
>>>     haxby8_no0 = haxby8.selectSamples(haxby8.labels != rest_label)
>>>
>>>     # synthetic 2-class dataset: 100 features, only two carry signal
>>>     dummy2 = normalFeatureDataset(perlabel=30, nlabels=2,
>>>                                   nfeatures=100,
>>>                                   nchunks=6, nonbogus_features=[11, 10],
>>>                                   snr=3.0)
>>>
>>>     # each entry pairs a (dataset, description) tuple with the subset of
>>>     # the PyMVPA classifier warehouse `clfs` appropriate for it
>>>     for (dataset, datasetdescr), clfs_ in \
>>>         [
>>>         ((dummy2,
>>>           "Dummy 2-class univariate with 2 useful features out of 100"),
>>>           clfs[:]),
>>>         ((pureMultivariateSignal(8, 3),
>>>           "Dummy XOR-pattern"),
>>>           clfs['non-linear']),
>>>         ((haxby8_no0,
>>>           "Haxby 8-cat subject 1"),
>>>           clfs['multiclass']),
>>>         ]:
>>>         print "%s\n %s" % (datasetdescr, dataset.summary(idhash=False))
>>>         print " Classifier                               " \
>>>               "%corr  #features\t train predict  full"
>>>         for clf in clfs_:
>>>             print "  %-40s: "  % clf.descr,
>>>             # Lets do splits/train/predict explicitely so we could track
>>>             # timing otherwise could be just
>>>             #cv = CrossValidatedTransferError(
>>>             #         TransferError(clf),
>>>             #         NFoldSplitter(),
>>>             #         enable_states=['confusion'])
>>>             #error = cv(dataset)
>>>             #print cv.confusion
>>>
>>>             # to report transfer error
>>>             confusion = ConfusionMatrix(labels_map=dataset.labels_map)
>>>             times = []          # [training_time, predicting_time] per fold
>>>             nf = []             # number of features the classifier used per fold
>>>             t0 = time.time()
>>>             clf.states.enable('feature_ids')
>>>             for nfold, (training_ds, validation_ds) in \
>>>                     enumerate(NFoldSplitter()(dataset)):
>>>                 clf.train(training_ds)
>>>                 nf.append(len(clf.feature_ids))
>>>                 # presumably a feature-selecting classifier can end up with
>>>                 # no features at all; further folds would be pointless then
>>>                 if nf[-1] == 0:
>>>                     break
>>>                 predictions = clf.predict(validation_ds.samples)
>>>                 confusion.add(validation_ds.labels, predictions)
>>>                 times.append([clf.training_time, clf.predicting_time])
>>>             if nf[-1] == 0:
>>>                 print "no features were selected. skipped"
>>>                 continue
>>>             tfull = time.time() - t0
>>>             # average per-fold timings and feature counts for the report
>>>             times = N.mean(times, axis=0)
>>>             nf = N.mean(nf)
>>>             # print "\n", confusion
>>>             print "%5.1f%%   %-4d\t %.2fs  %.2fs   %.2fs" % \
>>>                   (confusion.percentCorrect, nf, times[0], times[1], tfull)
>>>
>>>
>>> # run the sweep only when executed as a script, not on import
>>> if __name__ == "__main__":
>>>     main()

See also

The full source code of this example is included in the PyMVPA source distribution (doc/examples/clfs_examples.py).