1
2
3
4
5
6
7
8
9 """Wavelet mappers"""
10
11 import pywt
12 import numpy as N
13
14 from mvpa.mappers.base import Mapper
15 from mvpa.base.dochelpers import enhancedDocString
16
17 if __debug__:
18 from mvpa.base import debug
19
20
21
22
24 """Generic class for Wavelet mappers (decomposition and packet)
25 """
26
def __init__(self, dim=1, wavelet='sym4', mode='per', maxlevel=None):
    """Initialize a wavelet mapper (base for decomposition and packet mappers)

    :Parameters:
      dim : int
        dimension (axis) of the data to work along. Only a scalar value,
        i.e. a 1D transformation, is supported for now.
      wavelet : basestring
        one of the wavelet families available within the pywt package
        (as listed by `pywt.wavelist()`)
      mode : basestring
        periodization mode; must be one of `pywt.MODES.modes`
      maxlevel : int or None
        number of decomposition levels to use. If None, the level is
        selected automatically by pywt

    :Raises:
      ValueError
        if `dim` is not a scalar integer, or if `wavelet` or `mode` is
        not known to pywt
    """
    import numbers

    Mapper.__init__(self)

    # Only 1D transformations are implemented: `dim` is later used directly
    # as an index into `data.shape`, so reject non-scalar values (e.g. the
    # tuples the old docstring advertised) here with a clear message
    # instead of an obscure indexing failure at mapping time.
    if not isinstance(dim, numbers.Integral):
        raise ValueError(
            "Only a scalar integer dimension is supported for now, "
            "got dim=%r" % (dim,))

    self._dim = dim
    """Dimension to work along"""

    self._maxlevel = maxlevel
    """Maximal level of decomposition. None for automatic"""

    # Query the available families once; the list is reused in the message.
    known_wavelets = pywt.wavelist()
    if wavelet not in known_wavelets:
        raise ValueError(
            "Unknown family of wavelets '%s'. Please use one "
            "available from the list %s" % (wavelet, known_wavelets))
    self._wavelet = wavelet
    """Wavelet family to use"""

    known_modes = pywt.MODES.modes
    if mode not in known_modes:
        raise ValueError(
            "Unknown periodization mode '%s'. Please use one "
            "available from the list %s" % (mode, known_modes))
    self._mode = mode
    """Periodization mode"""
63
65 data = N.asanyarray(data)
66 self._inshape = data.shape
67 self._intimepoints = data.shape[self._dim]
68 res = self._forward(data)
69 self._outshape = res.shape
70 return res
71
72
76
77
79 raise NotImplementedError
80
81
83 raise NotImplementedError
84
85
87 """Returns the number of original features."""
88 return self._inshape[1:]
89
90
92 """Returns the number of wavelet components."""
93 return self._outshape[1:]
94
95
97 """Choose a subset of components...
98
99 just use MaskMapper on top?"""
100 raise NotImplementedError, "Please use in conjunction with MaskMapper"
101
102
103 __doc__ = enhancedDocString('_WaveletMapper', locals(), Mapper)
104
105
107 """Generator for coordinate tuples providing slice for all in `dim`
108
109 XXX Somewhat sloppy implementation... but works...
110 """
111 if len(shape) < dim:
112 raise ValueError, "Dimension %d is incorrect for a shape %s" % \
113 (dim, shape)
114 n = len(shape)
115 curindexes = [0] * n
116 curindexes[dim] = slice(None)
117 while True:
118 yield tuple(curindexes)
119 for i in xrange(n):
120 if i == dim and dim == n-1:
121 return
122 if curindexes[i] == shape[i] - 1:
123 if i == n-1:
124 return
125 curindexes[i] = 0
126 else:
127 if i != dim:
128 curindexes[i] += 1
129 break
130
131
133 """Convert signal into an overcomplete representaion using Wavelet packet
134 """
135
137 if __debug__:
138 debug('MAP', "Converting signal using DWP")
139
140 wp = None
141 levels_length = None
142 levels_lengths = None
143 for indexes in _getIndexes(data.shape, self._dim):
144 if __debug__:
145 debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
146 WP = pywt.WaveletPacket(
147 data[indexes],
148 wavelet=self._wavelet,
149 mode=self._mode, maxlevel=self._maxlevel)
150
151 if levels_length is None:
152 levels_length = [None] * WP.maxlevel
153 levels_lengths = [None] * WP.maxlevel
154
155 levels_datas = []
156 for level in xrange(WP.maxlevel):
157 level_nodes = WP.get_level(level+1)
158 level_datas = [node.data for node in level_nodes]
159
160 level_lengths = [len(x) for x in level_datas]
161 level_length = N.sum(level_lengths)
162
163 if levels_lengths[level] is None:
164 levels_lengths[level] = level_lengths
165 elif levels_lengths[level] != level_lengths:
166 raise RuntimeError, \
167 "ADs of same level of different samples should have same number of elements." \
168 " Got %s, was %s" % (level_lengths, levels_lengths[level])
169
170 if levels_length[level] is None:
171 levels_length[level] = level_length
172 elif levels_length[level] != level_length:
173 raise RuntimeError, \
174 "Levels of different samples should have same number of elements." \
175 " Got %d, was %d" % (level_length, levels_length[level])
176
177 level_data = N.hstack(level_datas)
178 levels_datas.append(level_data)
179
180
181
182 if wp is None:
183 newdim = list(data.shape)
184 newdim[self._dim] = N.sum(levels_length)
185 wp = N.empty( tuple(newdim) )
186 wp[indexes] = N.hstack(levels_datas)
187
188 self.levels_lengths, self.levels_length = levels_lengths, levels_length
189 if __debug__:
190 debug('MAP_', "")
191 debug('MAP', "Done convertion into wp. Total size %s" % str(wp.shape))
192 return wp
193
194
196 raise NotImplementedError
197
198
275