Package mvpa :: Package mappers :: Module wavelet
[hide private]
[frames | no frames]

Source Code for Module mvpa.mappers.wavelet

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Wavelet mappers""" 
 10   
 11  import pywt 
 12  import numpy as N 
 13   
 14  from mvpa.mappers.base import Mapper 
 15  from mvpa.base.dochelpers import enhancedDocString 
 16   
 17  if __debug__: 
 18      from mvpa.base import debug 
 19   
 20  # WaveletPacket and WaveletTransformation mappers share lots of common 
 21  # functionality at the moment 
 22   
class _WaveletMapper(Mapper):
    """Generic class for Wavelet mappers (decomposition and packet)

    Subclasses implement `_forward` / `_reverse`; this base class does the
    argument validation and bookkeeping of input/output shapes.
    """

    def __init__(self, dim=1, wavelet='sym4', mode='per', maxlevel=None):
        """Initialize a wavelet mapper

        :Parameters:
          dim : int
            dimension to work across. Only a scalar value (i.e. a 1D
            transformation) is supported for now
          wavelet : basestring
            one of the wavelet families available within the pywt package
            (validated against `pywt.wavelist()`)
          mode : basestring
            periodization mode (validated against `pywt.MODES.modes`)
          maxlevel : int or None
            number of levels to use. If None - automatically selected by pywt

        :Raises:
          ValueError
            if `wavelet` or `mode` is not known to pywt
        """
        Mapper.__init__(self)

        self._dim = dim
        """Dimension to work along"""

        self._maxlevel = maxlevel
        """Maximal level of decomposition. None for automatic"""

        known_wavelets = pywt.wavelist()
        if wavelet not in known_wavelets:
            raise ValueError(
                "Unknown family of wavelets '%s'. Please use one "
                "available from the list %s" % (wavelet, known_wavelets))
        self._wavelet = wavelet
        """Wavelet family to use"""

        known_modes = pywt.MODES.modes
        if mode not in known_modes:
            raise ValueError(
                "Unknown periodization mode '%s'. Please use one "
                "available from the list %s" % (mode, known_modes))
        self._mode = mode
        """Periodization mode"""


    def forward(self, data):
        """Map `data` into the wavelet space along dimension `self._dim`

        Records the input shape, the length of `data` along `self._dim`
        (subclasses' `_reverse` may use it to trim reconstructed signals)
        and the output shape, then delegates to `_forward`.
        """
        data = N.asanyarray(data)
        self._inshape = data.shape
        self._intimepoints = data.shape[self._dim]
        res = self._forward(data)
        self._outshape = res.shape
        return res


    def reverse(self, data):
        """Map `data` from the wavelet space back via `_reverse`"""
        data = N.asanyarray(data)
        return self._reverse(data)


    def _forward(self, *args):
        """Actual forward transformation -- to be provided by subclasses"""
        raise NotImplementedError


    def _reverse(self, *args):
        """Actual reverse transformation -- to be provided by subclasses"""
        raise NotImplementedError


    def getInSize(self):
        """Returns the shape (tuple) of the original features

        It is the input shape without its first (samples) dimension. Only
        available after `forward` has been called.
        """
        return self._inshape[1:]


    def getOutSize(self):
        """Returns the shape (tuple) of the wavelet components

        It is the output shape without its first (samples) dimension. Only
        available after `forward` has been called.
        """
        return self._outshape[1:]


    def selectOut(self, outIds):
        """Choose a subset of components...

        just use MaskMapper on top?"""
        raise NotImplementedError("Please use in conjunction with MaskMapper")


    __doc__ = enhancedDocString('_WaveletMapper', locals(), Mapper)
104 105
106 -def _getIndexes(shape, dim):
107 """Generator for coordinate tuples providing slice for all in `dim` 108 109 XXX Somewhat sloppy implementation... but works... 110 """ 111 if len(shape) < dim: 112 raise ValueError, "Dimension %d is incorrect for a shape %s" % \ 113 (dim, shape) 114 n = len(shape) 115 curindexes = [0] * n 116 curindexes[dim] = slice(None) # all elements for dimension dim 117 while True: 118 yield tuple(curindexes) 119 for i in xrange(n): 120 if i == dim and dim == n-1: 121 return # we reached it -- thus time to go 122 if curindexes[i] == shape[i] - 1: 123 if i == n-1: 124 return 125 curindexes[i] = 0 126 else: 127 if i != dim: 128 curindexes[i] += 1 129 break
130 131
class WaveletPacketMapper(_WaveletMapper):
    """Convert signal into an overcomplete representation using Wavelet packet
    """

    def _forward(self, data):
        """Decompose `data` along `self._dim` using a wavelet packet

        For every 1D slice along `self._dim`, the data of all nodes of levels
        1..maxlevel of the packet are concatenated (level after level).  The
        result has the shape of `data` except along `self._dim`.

        Stores the total length of each level in `self.levels_length` and the
        per-node lengths of each level in `self.levels_lengths`.

        :Raises:
          RuntimeError
            if different slices decompose into differently sized levels
        """
        if __debug__:
            debug('MAP', "Converting signal using DWP")

        wp = None                # output array, allocated on first slice
        levels_length = None     # total length at each level
        levels_lengths = None    # list of lengths per each level
        for indexes in _getIndexes(data.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            packet = pywt.WaveletPacket(
                data[indexes],
                wavelet=self._wavelet,
                mode=self._mode, maxlevel=self._maxlevel)

            if levels_length is None:
                levels_length = [None] * packet.maxlevel
                levels_lengths = [None] * packet.maxlevel
            elif len(levels_length) != packet.maxlevel:
                # otherwise we would index past the lists or silently write
                # a shorter row into the output
                raise RuntimeError(
                    "Different samples should have same number of levels."
                    " Got %d, was %d" % (packet.maxlevel, len(levels_length)))

            levels_datas = []
            for level in range(packet.maxlevel):
                level_datas = [node.data
                               for node in packet.get_level(level + 1)]
                level_lengths = [len(x) for x in level_datas]
                level_length = N.sum(level_lengths)

                if levels_lengths[level] is None:
                    levels_lengths[level] = level_lengths
                elif levels_lengths[level] != level_lengths:
                    raise RuntimeError(
                        "ADs of same level of different samples should have same number of elements."
                        " Got %s, was %s" % (level_lengths, levels_lengths[level]))

                if levels_length[level] is None:
                    levels_length[level] = level_length
                elif levels_length[level] != level_length:
                    raise RuntimeError(
                        "Levels of different samples should have same number of elements."
                        " Got %d, was %d" % (level_length, levels_length[level]))

                levels_datas.append(N.hstack(level_datas))

            if wp is None:
                newdim = list(data.shape)
                newdim[self._dim] = N.sum(levels_length)
                wp = N.empty(tuple(newdim))
            wp[indexes] = N.hstack(levels_datas)

        self.levels_lengths, self.levels_length = levels_lengths, levels_length
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done convertion into wp. Total size %s" % str(wp.shape))
        return wp


    def _reverse(self, data):
        """Reverse transformation is not implemented for wavelet packets"""
        raise NotImplementedError
197 198
class WaveletTransformationMapper(_WaveletMapper):
    """Convert signal into wavelet representation
    """

    def _forward(self, data):
        """Decompose signal into wavelets's coefficients via dwt

        Every 1D slice of `data` along `self._dim` is decomposed with
        `pywt.wavedec` and its coefficient arrays are concatenated.  The
        result has the shape of `data` except along `self._dim`.

        Stores the length of each coefficient array in `self.lengths`;
        `_reverse` needs it to split the coefficients again.

        :Raises:
          RuntimeError
            if different slices yield differently sized coefficients
        """
        if __debug__:
            debug('MAP', "Converting signal using DWT")
        wd = None                # output array, allocated on first slice
        coeff_lengths = None
        for indexes in _getIndexes(data.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            coeffs = pywt.wavedec(
                data[indexes],
                wavelet=self._wavelet,
                mode=self._mode,
                level=self._maxlevel)
            lengths = N.array([len(x) for x in coeffs])
            if coeff_lengths is None:
                coeff_lengths = lengths
            elif not N.array_equal(coeff_lengths, lengths):
                # explicit check instead of `assert`: asserts vanish under -O,
                # and `==` on arrays of different size does not compare
                # elementwise
                raise RuntimeError(
                    "Coefficients of different samples should have same "
                    "number of elements. Got %s, was %s"
                    % (lengths, coeff_lengths))
            if wd is None:
                newdim = list(data.shape)
                newdim[self._dim] = N.sum(coeff_lengths)
                wd = N.empty(tuple(newdim))
            wd[indexes] = N.hstack(coeffs)
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done DWT. Total size %s" % str(wd.shape))
        self.lengths = coeff_lengths
        return wd


    def _reverse(self, wd):
        """Reconstruct signal from wavelet coefficients via inverse dwt

        Splits every 1D slice of `wd` along `self._dim` into coefficient
        arrays using `self.lengths` (set by `_forward`), runs
        `pywt.waverec` on them and keeps the first `self._intimepoints`
        values (set by `forward`).
        """
        if __debug__:
            debug('MAP', "Performing iDWT")
        signal = None
        wd_offsets = [0] + list(N.cumsum(self.lengths))
        # (start, stop) of each coefficient array within a slice
        bounds = list(zip(wd_offsets[:-1], wd_offsets[1:]))
        Ntime_points = self._intimepoints
        # unfortunately sometimes due to padding iDWT would return longer
        # sequences, thus we just limit to the right ones

        for indexes in _getIndexes(wd.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            wd_sample = wd[indexes]
            # need to compose original list of coefficient arrays
            wd_coeffs = [wd_sample[start:stop] for start, stop in bounds]
            time_points = pywt.waverec(
                wd_coeffs, wavelet=self._wavelet, mode=self._mode)
            if signal is None:
                newdim = list(wd.shape)
                newdim[self._dim] = Ntime_points
                signal = N.empty(newdim)
            signal[indexes] = time_points[:Ntime_points]
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done iDWT. Total size %s" % (signal.shape, ))
        return signal
275