Package mvpa :: Package tests :: Module test_rfe
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_rfe

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA recursive feature elimination""" 
 10   
 11  from sets import Set 
 12   
 13  from mvpa.datasets.splitters import NFoldSplitter 
 14  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 15  from mvpa.datasets.masked import MaskedDataset 
 16  from mvpa.measures.base import FeaturewiseDatasetMeasure 
 17  from mvpa.featsel.rfe import RFE 
 18  from mvpa.featsel.base import \ 
 19       SensitivityBasedFeatureSelection, \ 
 20       FeatureSelectionPipeline 
 21  from mvpa.featsel.helpers import \ 
 22       NBackHistoryStopCrit, FractionTailSelector, FixedErrorThresholdStopCrit, \ 
 23       MultiStopCrit, NStepsStopCrit, \ 
 24       FixedNElementTailSelector, BestDetector, RangeElementSelector 
 25   
 26  from mvpa.clfs.meta import FeatureSelectionClassifier, SplitClassifier 
 27  from mvpa.clfs.transerror import TransferError, ConfusionBasedError 
 28  from mvpa.clfs.stats import MCNullDist 
 29  from mvpa.misc.transformers import Absolute, FirstAxisMean 
 30   
 31  from mvpa.misc.state import UnknownStateError 
 32   
 33  from tests_warehouse import * 
 34  from tests_warehouse_clfs import * 
class SillySensitivityAnalyzer(FeaturewiseDatasetMeasure):
    """Trivial, deterministic sensitivity measure for testing.

    Returns ``mult * (N.arange(N_feat) - N_feat/2)``, i.e. a monotonic
    ramp of per-feature "sensitivities" spanning roughly [-N_feat/2,
    N_feat/2), where N_feat is the number of features in the dataset.
    Tests can therefore predict exactly which features a selector will
    keep or discard.
    """

    def __init__(self, mult=1, **kwargs):
        """Initialize the measure.

        :Parameters:
          mult
            Scaling factor applied to the generated ramp.
        """
        FeaturewiseDatasetMeasure.__init__(self, **kwargs)
        self.__mult = mult

    def __call__(self, dataset):
        """Return the scaled ramp of pseudo-sensitivities for `dataset`."""
        # NOTE: no classifier is trained here -- the original docstring
        # ("Train linear SVM ... extract weights") was a copy-paste
        # leftover.  This just emits a fixed ramp.
        return self.__mult * (N.arange(dataset.nfeatures)
                              - int(dataset.nfeatures / 2))
class RFETests(unittest.TestCase):
    """Unit tests for RFE and its helpers (stopping criteria, element
    selectors, sensitivity-based feature selection, pipelines)."""

    def getData(self):
        # Training dataset shared by the tests (from tests_warehouse).
        return datasets['uni2medium_train']
55
    def getDataT(self):
        # Matching test (transfer) dataset for getData().
        return datasets['uni2medium_test']
58 59
60 - def testBestDetector(self):
61 bd = BestDetector() 62 63 # for empty history -- no best 64 self.failUnless(bd([]) == False) 65 # we got the best if we have just 1 66 self.failUnless(bd([1]) == True) 67 # we got the best if we have the last minimal 68 self.failUnless(bd([1, 0.9, 0.8]) == True) 69 70 # test for alternative func 71 bd = BestDetector(func=max) 72 self.failUnless(bd([0.8, 0.9, 1.0]) == True) 73 self.failUnless(bd([0.8, 0.9, 1.0]+[0.9]*9) == False) 74 self.failUnless(bd([0.8, 0.9, 1.0]+[0.9]*10) == False) 75 76 # test to detect earliest and latest minimum 77 bd = BestDetector(lastminimum=True) 78 self.failUnless(bd([3, 2, 1, 1, 1, 2, 1]) == True) 79 bd = BestDetector() 80 self.failUnless(bd([3, 2, 1, 1, 1, 2, 1]) == False)
81 82
83 - def testNBackHistoryStopCrit(self):
84 """Test stopping criterion""" 85 stopcrit = NBackHistoryStopCrit() 86 # for empty history -- no best but just go 87 self.failUnless(stopcrit([]) == False) 88 # should not stop if we got 10 more after minimal 89 self.failUnless(stopcrit( 90 [1, 0.9, 0.8]+[0.9]*(stopcrit.steps-1)) == False) 91 # should stop if we got 10 more after minimal 92 self.failUnless(stopcrit( 93 [1, 0.9, 0.8]+[0.9]*stopcrit.steps) == True) 94 95 # test for alternative func 96 stopcrit = NBackHistoryStopCrit(BestDetector(func=max)) 97 self.failUnless(stopcrit([0.8, 0.9, 1.0]+[0.9]*9) == False) 98 self.failUnless(stopcrit([0.8, 0.9, 1.0]+[0.9]*10) == True) 99 100 # test to detect earliest and latest minimum 101 stopcrit = NBackHistoryStopCrit(BestDetector(lastminimum=True)) 102 self.failUnless(stopcrit([3, 2, 1, 1, 1, 2, 1]) == False) 103 stopcrit = NBackHistoryStopCrit(steps=4) 104 self.failUnless(stopcrit([3, 2, 1, 1, 1, 2, 1]) == True)
105 106
108 """Test stopping criterion""" 109 stopcrit = FixedErrorThresholdStopCrit(0.5) 110 111 self.failUnless(stopcrit([]) == False) 112 self.failUnless(stopcrit([0.8, 0.9, 0.5]) == False) 113 self.failUnless(stopcrit([0.8, 0.9, 0.4]) == True) 114 # only last error has to be below to stop 115 self.failUnless(stopcrit([0.8, 0.4, 0.6]) == False)
116 117
118 - def testNStepsStopCrit(self):
119 """Test stopping criterion""" 120 stopcrit = NStepsStopCrit(2) 121 122 self.failUnless(stopcrit([]) == False) 123 self.failUnless(stopcrit([0.8, 0.9]) == True) 124 self.failUnless(stopcrit([0.8]) == False)
125 126
127 - def testMultiStopCrit(self):
128 """Test multiple stop criteria""" 129 stopcrit = MultiStopCrit([FixedErrorThresholdStopCrit(0.5), 130 NBackHistoryStopCrit(steps=4)]) 131 132 # default 'or' mode 133 # nback triggers 134 self.failUnless(stopcrit([1, 0.9, 0.8]+[0.9]*4) == True) 135 # threshold triggers 136 self.failUnless(stopcrit([1, 0.9, 0.2]) == True) 137 138 # alternative 'and' mode 139 stopcrit = MultiStopCrit([FixedErrorThresholdStopCrit(0.5), 140 NBackHistoryStopCrit(steps=4)], 141 mode = 'and') 142 # nback triggers not 143 self.failUnless(stopcrit([1, 0.9, 0.8]+[0.9]*4) == False) 144 # threshold triggers not 145 self.failUnless(stopcrit([1, 0.9, 0.2]) == False) 146 # only both satisfy 147 self.failUnless(stopcrit([1, 0.9, 0.4]+[0.4]*4) == True)
148 149
    def testFeatureSelector(self):
        """Test element selectors: fraction tail, fixed-N tail and ranges."""
        # remove 10% weakest
        selector = FractionTailSelector(0.1)
        data = N.array([3.5, 10, 7, 5, -0.4, 0, 0, 2, 10, 9])
        # == rank [4, 5, 6, 7, 0, 3, 2, 9, 1, 8]
        target10 = N.array([0, 1, 2, 3, 5, 6, 7, 8, 9])
        target30 = N.array([0, 1, 2, 3, 7, 8, 9])

        # 'ndiscarded' is a conditional state -- reading it before the
        # first call must raise UnknownStateError
        self.failUnlessRaises(UnknownStateError,
                              selector.__getattribute__, 'ndiscarded')
        self.failUnless((selector(data) == target10).all())
        selector.felements = 0.30 # discard 30%
        self.failUnless(selector.felements == 0.3)
        self.failUnless((selector(data) == target30).all())
        self.failUnless(selector.ndiscarded == 3) # check that 3 were discarded

        selector = FixedNElementTailSelector(1)
        #                 0   1   2  3   4    5  6  7  8   9
        data = N.array([3.5, 10, 7, 5, -0.4, 0, 0, 2, 10, 9])
        self.failUnless((selector(data) == target10).all())

        selector.nelements = 3
        self.failUnless(selector.nelements == 3)
        self.failUnless((selector(data) == target30).all())
        self.failUnless(selector.ndiscarded == 3)

        # test range selector
        # simple range 'above'
        self.failUnless((RangeElementSelector(lower=0)(data) == \
                         N.array([0,1,2,3,7,8,9])).all())

        self.failUnless((RangeElementSelector(lower=0,
                                              inclusive=True)(data) == \
                         N.array([0,1,2,3,5,6,7,8,9])).all())

        self.failUnless((RangeElementSelector(lower=0, mode='discard',
                                              inclusive=True)(data) == \
                         N.array([4])).all())

        # simple range 'below'
        self.failUnless((RangeElementSelector(upper=2)(data) == \
                         N.array([4,5,6])).all())

        self.failUnless((RangeElementSelector(upper=2,
                                              inclusive=True)(data) == \
                         N.array([4,5,6,7])).all())

        self.failUnless((RangeElementSelector(upper=2, mode='discard',
                                              inclusive=True)(data) == \
                         N.array([0,1,2,3,8,9])).all())


        # ranges
        self.failUnless((RangeElementSelector(lower=2, upper=9)(data) == \
                         N.array([0,2,3])).all())

        self.failUnless((RangeElementSelector(lower=2, upper=9,
                                              inclusive=True)(data) == \
                         N.array([0,2,3,7,9])).all())

        # lower > upper selects *outside* the band; discarding that must
        # be equivalent to the exclusive in-band selection
        self.failUnless((RangeElementSelector(upper=2, lower=9, mode='discard',
                                              inclusive=True)(data) ==
                         RangeElementSelector(lower=2, upper=9,
                                              inclusive=False)(data)).all())

        # non-0 elements -- should be equivalent to N.nonzero()[0]
        self.failUnless((RangeElementSelector()(data) == \
                         N.nonzero(data)[0]).all())
219 220 221 @sweepargs(clf=clfswh['has_sensitivity', '!meta'])
223 224 # sensitivity analyser and transfer error quantifier use the SAME clf! 225 sens_ana = clf.getSensitivityAnalyzer() 226 227 # of features to remove 228 Nremove = 2 229 230 # because the clf is already trained when computing the sensitivity 231 # map, prevent retraining for transfer error calculation 232 # Use absolute of the svm weights as sensitivity 233 fe = SensitivityBasedFeatureSelection(sens_ana, 234 feature_selector=FixedNElementTailSelector(2), 235 enable_states=["sensitivity", "selected_ids"]) 236 237 wdata = self.getData() 238 wdata_nfeatures = wdata.nfeatures 239 tdata = self.getDataT() 240 tdata_nfeatures = tdata.nfeatures 241 242 sdata, stdata = fe(wdata, tdata) 243 244 # fail if orig datasets are changed 245 self.failUnless(wdata.nfeatures == wdata_nfeatures) 246 self.failUnless(tdata.nfeatures == tdata_nfeatures) 247 248 # silly check if nfeatures got a single one removed 249 self.failUnlessEqual(wdata.nfeatures, sdata.nfeatures+Nremove, 250 msg="We had to remove just a single feature") 251 252 self.failUnlessEqual(tdata.nfeatures, stdata.nfeatures+Nremove, 253 msg="We had to remove just a single feature in testing as well") 254 255 self.failUnlessEqual(len(fe.sensitivity), wdata_nfeatures, 256 msg="Sensitivity have to have # of features equal to original") 257 258 self.failUnlessEqual(len(fe.selected_ids), sdata.nfeatures, 259 msg="# of selected features must be equal the one in the result dataset")
260 261
263 sens_ana = SillySensitivityAnalyzer() 264 265 wdata = self.getData() 266 wdata_nfeatures = wdata.nfeatures 267 tdata = self.getDataT() 268 tdata_nfeatures = tdata.nfeatures 269 270 # test silly one first ;-) 271 self.failUnlessEqual(sens_ana(wdata)[0], -int(wdata_nfeatures/2)) 272 273 # OLD: first remove 25% == 6, and then 4, total removing 10 274 # NOW: test should be independent of the numerical number of features 275 feature_selections = [SensitivityBasedFeatureSelection( 276 sens_ana, 277 FractionTailSelector(0.25)), 278 SensitivityBasedFeatureSelection( 279 sens_ana, 280 FixedNElementTailSelector(4)) 281 ] 282 283 # create a FeatureSelection pipeline 284 feat_sel_pipeline = FeatureSelectionPipeline( 285 feature_selections=feature_selections, 286 enable_states=['nfeatures', 'selected_ids']) 287 288 sdata, stdata = feat_sel_pipeline(wdata, tdata) 289 290 self.failUnlessEqual(len(feat_sel_pipeline.feature_selections), 291 len(feature_selections), 292 msg="Test the property feature_selections") 293 294 desired_nfeatures = int(N.ceil(wdata_nfeatures*0.75)) 295 self.failUnlessEqual(feat_sel_pipeline.nfeatures, 296 [wdata_nfeatures, desired_nfeatures], 297 msg="Test if nfeatures get assigned properly." 298 " Got %s!=%s" % (feat_sel_pipeline.nfeatures, 299 [wdata_nfeatures, desired_nfeatures])) 300 301 self.failUnlessEqual(list(feat_sel_pipeline.selected_ids), 302 range(int(wdata_nfeatures*0.25)+4, wdata_nfeatures))
    # TODO: should later on work for any clfs_with_sens
    @sweepargs(clf=clfswh['has_sensitivity', '!meta'][:1])
    def testRFE(self, clf):
        """Run RFE with `clf` serving as both sensitivity analyzer and
        transfer-error classifier, and verify the reported conditional
        states (errors, nfeatures, history) are mutually consistent."""

        # sensitivity analyser and transfer error quantifier use the SAME clf!
        sens_ana = clf.getSensitivityAnalyzer()
        trans_error = TransferError(clf)
        # because the clf is already trained when computing the sensitivity
        # map, prevent retraining for transfer error calculation
        # Use absolute of the svm weights as sensitivity
        rfe = RFE(sens_ana,
                  trans_error,
                  feature_selector=FixedNElementTailSelector(1),
                  train_clf=False)

        wdata = self.getData()
        wdata_nfeatures = wdata.nfeatures
        tdata = self.getDataT()
        tdata_nfeatures = tdata.nfeatures

        sdata, stdata = rfe(wdata, tdata)

        # fail if orig datasets are changed
        self.failUnless(wdata.nfeatures == wdata_nfeatures)
        self.failUnless(tdata.nfeatures == tdata_nfeatures)

        # check that the features set with the least error is selected
        if len(rfe.errors):
            e = N.array(rfe.errors)
            # one feature is removed per step, so the step with minimal
            # error determines how many features survive
            self.failUnless(sdata.nfeatures == wdata_nfeatures - e.argmin())
        else:
            self.failUnless(sdata.nfeatures == wdata_nfeatures)

        # silly check if nfeatures is in decreasing order
        nfeatures = N.array(rfe.nfeatures).copy()
        nfeatures.sort()
        self.failUnless( (nfeatures[::-1] == rfe.nfeatures).all() )

        # check if history has elements for every step
        self.failUnless(Set(rfe.history)
                        == Set(range(len(N.array(rfe.errors)))))

        # Last (the largest number) can be present multiple times even
        # if we remove 1 feature at a time -- just need to stop well
        # in advance when we have more than 1 feature left ;)
        self.failUnless(rfe.nfeatures[-1]
                        == len(N.where(rfe.history
                                       ==max(rfe.history))[0]))

        # XXX add a test where sensitivity analyser and transfer error do not
        # use the same classifier
    def testJamesProblem(self):
        """Regression test: CrossValidatedTransferError must cope with a
        FeatureSelectionClassifier whose feature selection is an RFE."""
        percent = 80
        dataset = datasets['uni2small']
        rfesvm_split = LinearCSVMC()
        fs = \
            RFE(sensitivity_analyzer=rfesvm_split.getSensitivityAnalyzer(),
                transfer_error=TransferError(rfesvm_split),
                feature_selector=FractionTailSelector(
                    percent / 100.0,
                    mode='select', tail='upper'), update_sensitivity=True)

        clf = FeatureSelectionClassifier(
            clf = LinearCSVMC(),
            # on features selected via RFE
            feature_selection = fs)
        # update sensitivity at each step (since we're not using the
        # same CLF as sensitivity analyzer)
        clf.states.enable('feature_ids')

        cv = CrossValidatedTransferError(
            TransferError(clf),
            NFoldSplitter(cvtype=1),
            enable_states=['confusion'],
            expose_testdataset=True)
        #cv = SplitClassifier(clf)
        # any exception raised inside CV counts as a test failure, not
        # an error -- hence the explicit try/except wrapper
        try:
            error = cv(dataset)
        except Exception, e:
            self.fail('CrossValidation cannot handle classifier with RFE '
                      'feature selection. Got exception: %s' % e)
        self.failUnless(error < 0.2)
389 390
    def __testMatthiasQuestion(self):
        """RFE driven by a SplitClassifier's confusion-based error, with
        MC permutation testing.

        NOTE: the leading double underscore keeps unittest from
        collecting this test -- presumably disabled for its runtime
        (1000 permutations); confirm before re-enabling.
        """
        rfe_clf = LinearCSVMC(C=1)

        rfesvm_split = SplitClassifier(rfe_clf)
        clf = \
            FeatureSelectionClassifier(
            clf = LinearCSVMC(C=1),
            feature_selection = RFE(
                sensitivity_analyzer = rfesvm_split.getSensitivityAnalyzer(
                    combiner=FirstAxisMean,
                    transformer=N.abs),
                transfer_error=ConfusionBasedError(
                    rfesvm_split,
                    confusion_state="confusion"),
                stopping_criterion=FixedErrorThresholdStopCrit(0.20),
                feature_selector=FractionTailSelector(
                    0.2, mode='discard', tail='lower'),
                update_sensitivity=True))

        splitter = NFoldSplitter(cvtype=1)
        no_permutations = 1000

        cv = CrossValidatedTransferError(
            TransferError(clf),
            splitter,
            null_dist=MCNullDist(permutations=no_permutations,
                                 tail='left'),
            enable_states=['confusion'])
        error = cv(datasets['uni2small'])
        self.failUnless(error < 0.4)
        self.failUnless(cv.states.null_prob < 0.05)
422
def suite():
    """Assemble and return the TestSuite covering all RFE tests."""
    rfe_suite = unittest.makeSuite(RFETests)
    return rfe_suite


if __name__ == '__main__':
    import runner
425 426 427 if __name__ == '__main__': 428 import runner 429