1
2
3
4
5
6
7
8
9 """Misc. plotting helpers."""
10
11 __docformat__ = 'restructuredtext'
12
13 import pylab as P
14 import numpy as N
15
16 from mvpa.datasets.splitter import NFoldSplitter
17 from mvpa.clfs.distance import squared_euclidean_distance
18
19
20
def errLinePlot(data, errtype='ste', curves=None, linestyle='--', fmt='o'):
    """Make a line plot with errorbars on the data points.

    The column-wise mean of `data` is drawn with errorbars at the integer
    x-positions ``0 .. n-1`` (``n`` being the number of columns). Optional
    reference curves are drawn first, on the same axes.

    :Parameters:
      data: sequence of sequences
        First axis separates samples and second axis will appear as
        x-axis in the plot. One-dimensional input is treated as a single
        sample.
      errtype: 'ste' | 'std'
        Type of error value to be computed per datapoint.
          'ste': standard error of the mean (std / sqrt(number of samples))
          'std': standard deviation
      curves: None | ndarray
        Each *row* of the array is plotted as an additional curve. The
        curves might have a different sampling frequency (i.e. number of
        samples) than the data array, as they will be scaled (along
        x-axis) to span exactly from the first to the last data point.
      linestyle: str
        matplotlib linestyle argument. Applied to the additional curves
        if any are given, otherwise to the line connecting the
        datapoints. Set to 'None' to disable the line completely.
      fmt: str
        matplotlib plot style argument to be applied to the data points
        and errorbars.

    :Raises:
      ValueError
        If `errtype` is neither 'ste' nor 'std'. The check is done before
        anything is drawn.

    :Example:

      Make dataset with 20 samples from a full sinus wave period,
      computed 30 times with individual noise pattern.

        >>> x = N.linspace(0, N.pi * 2, 20)
        >>> data = N.vstack([N.sin(x)] * 30)
        >>> data += N.random.normal(size=data.shape)

      Now, plot mean data points with error bars, plus a high-res
      version of the original sinus wave.

        >>> errLinePlot(data, curves=N.sin(N.linspace(0, N.pi * 2, 200)))
        >>> #P.show()
    """
    data = N.asanyarray(data)

    # a flat sequence is interpreted as one sample (a single row)
    if data.ndim < 2:
        data = N.atleast_2d(data)

    # compute the error values up front, so that an invalid `errtype` is
    # rejected before any drawing call could leave a partial figure behind
    if errtype == 'ste':
        err = data.std(axis=0) / N.sqrt(len(data))
    elif errtype == 'std':
        err = data.std(axis=0)
    else:
        raise ValueError("Unknown error type '%s'" % errtype)

    # mean across samples, one value per x-position
    md = data.mean(axis=0)

    # data points live on the integer positions 0 .. len(md) - 1
    x = N.arange(len(md))

    # plot the additional curves if present
    if curves is not None:
        # one curve per column after the transpose
        curves = N.array(curves, ndmin=2).T
        # stretch the curve samples over the very same range as the data
        # points: the last curve sample has to coincide with the last data
        # point at len(md) - 1 (linspace includes its endpoint)
        xaxis = N.linspace(0, len(md) - 1, len(curves))

        # draw every curve with its own plot() call
        for c in range(curves.shape[1]):
            P.plot(xaxis, curves[:, c], linestyle=linestyle)

        # the linestyle went to the curves: no line between data points
        linestyle = 'None'

    P.errorbar(x, md, err, fmt=fmt, linestyle=linestyle)
95
96
def plotFeatureHist(dataset, xlim=None, noticks=True, perchunk=False,
                    **kwargs):
    """Plot histograms of feature values for each label.

    The dataset is walked label by label with a splitter on the 'labels'
    attribute; each iteration yields a pair of which only the second
    element is plotted (presumably the samples of exactly one label --
    confirm against `NFoldSplitter`).

    :Parameters:
      dataset: Dataset
        Source of the samples; its `uniquelabels` and `uniquechunks`
        determine the subplot layout.
      xlim: None | 2-tuple
        Common x-axis limits for all histograms.
      noticks: boolean
        If True, no axis ticks will be plotted. This is useful to save
        space in large plots.
      perchunk: boolean
        If True, one histogram will be plotted per label and chunk,
        resulting in a histogram grid (labels x chunks) and the samples
        of each cell are flattened before plotting. If False, a single
        row with one histogram per label is drawn and the samples array
        is handed to hist() as is.
      **kwargs:
        Any additional arguments are passed to matplotlib's hist().
    """
    by_label = NFoldSplitter(1, attr='labels')
    by_chunk = NFoldSplitter(1, attr='chunks')

    n_labels = len(dataset.uniquelabels)
    n_chunks = len(dataset.uniquechunks)

    def _draw(values):
        """Draw one histogram into the current subplot and style its axes."""
        P.hist(values, **kwargs)

        if xlim is not None:
            P.xlim(xlim)

        if noticks:
            P.yticks([])
            P.xticks([])

    # running (1-based) subplot index, bumped right before each subplot
    panel = 0

    for row, (_, label_ds) in enumerate(by_label(dataset)):
        if not perchunk:
            # single row: one panel per label
            panel += 1
            P.subplot(1, n_labels, panel)
            _draw(label_ds.samples)
            P.title('L:' + str(label_ds.uniquelabels[0]))
            continue

        # grid mode: one panel per (label, chunk) combination
        for col, (_, cell_ds) in enumerate(by_chunk(label_ds)):
            panel += 1
            P.subplot(n_labels, n_chunks, panel)
            _draw(cell_ds.samples.ravel())

            # chunk captions on the top row only
            if row == 0:
                P.title('C:' + str(cell_ds.uniquechunks[0]))
            # label captions on the left column only
            if col == 0:
                P.ylabel('L:' + str(cell_ds.uniquelabels[0]))
153
154
156 """Plot the euclidean distances between all samples of a dataset.
157
158 :Parameters:
159 dataset: Dataset
160 Providing the samples.
161 sortbyattr: None | str
162 If None, the sample distances will be in the same order as their
163 appearance in the dataset. Alternatively, the name of a samples
164 attribute can be given, which will then be used to sort/group the
165 samples, e.g. to investigate the similarity of samples by label or by
166 chunks.
167 """
168 if sortbyattr is not None:
169 slicer = []
170 for attr in dataset.__getattribute__('unique' + sortbyattr):
171 slicer += \
172 dataset.__getattribute__('idsby' + sortbyattr)(attr).tolist()
173 samples = dataset.samples[slicer]
174 else:
175 samples = dataset.samples
176
177 ed = N.sqrt(squared_euclidean_distance(samples))
178
179 P.imshow(ed)
180 P.colorbar()
181
182
def plotBars(data, labels=None, title=None, ylim=None, ylabel=None,
             width=0.2, offset=0.2, color='0.6', distance=1.0,
             yerr='ste', **kwargs):
    """Make bar plots with automatically computed error bars.

    Candlestick plot (multiple interleaved barplots) can be done,
    by calling this function multiple times with appropriately modified
    `offset` argument.

    :Parameters:
      data: array (nbars x nobservations) | other sequence type
        Source data for the barplot. The bar height is the mean of each
        row; the error measure is computed along the second axis.
      labels: list | None
        If not None (and not empty), a label from this list is placed
        on each bar. Arrays are accepted as well.
      title: str
        An optional title of the barplot.
      ylim: 2-tuple
        Y-axis range.
      ylabel: str
        An optional label for the y-axis.
      width: float
        Width of a bar. The value should be in a reasonable relation to
        `distance`.
      offset: float
        Constant offset of all bars along the x-axis. Can be used to
        create candlestick plots.
      color: matplotlib color spec
        Color of the bars.
      distance: float
        Distance of two adjacent bars.
      yerr: 'ste' | 'std' | None
        Type of error for the errorbars. If `None` no errorbars are
        plotted. Any other value is handed to `bar()` unchanged.
      **kwargs:
        Any additional arguments are passed to matplotlib's `bar()`
        function. Unless given explicitly, `align='edge'` is used.

    :Returns:
      Whatever matplotlib's `bar()` returned for the plotted bars.
    """
    # Tick positions (x + width / 2) and the x-axis limit computed below
    # assume that bars are anchored at their left edge. Newer matplotlib
    # versions center bars by default, hence pin the alignment unless the
    # caller explicitly asked for something else.
    kwargs.setdefault('align', 'edge')

    # left edge of each bar
    xlocations = (N.arange(len(data)) * distance) + offset

    if yerr == 'ste':
        yerr = [N.std(d) / N.sqrt(len(d)) for d in data]
    elif yerr == 'std':
        yerr = [N.std(d) for d in data]
    else:
        # anything else (e.g. None) is passed on to bar() as is
        pass

    # plot bars
    plot = P.bar(xlocations,
                 [N.mean(d) for d in data],
                 yerr=yerr,
                 width=width,
                 color=color,
                 ecolor='black',
                 **kwargs)

    if ylim:
        P.ylim(*(ylim))
    if title:
        P.title(title)

    # explicit None/length test: plain truth-testing fails for arrays with
    # more than one element (e.g. a dataset's unique labels)
    if labels is not None and len(labels) > 0:
        P.xticks(xlocations + width / 2, labels)

    if ylabel:
        P.ylabel(ylabel)

    # leave some space after the last bar; nothing to scale without bars
    if len(xlocations) > 0:
        P.xlim(0, xlocations[-1] + width + offset)

    return plot
254
255
257 """Create a new colormap that is the reversed version of the named colormap
258
259 """
260 import matplotlib._cm as _cm
261 import matplotlib as mpl
262 try:
263 cmap_data = eval('_cm._%s_data' % cmap_name)
264 except:
265 raise ValueError, "Cannot obtain data for the colormap %s" % cmap_name
266 new_data = dict( [(k, [(v[i][0], v[-(i+1)][1], v[-(i+1)][2])
267 for i in xrange(len(v))])
268 for k,v in cmap_data.iteritems()] )
269 return mpl.colors.LinearSegmentedColormap('%s_rev' % cmap_name,
270 new_data, _cm.LUTSIZE)
271
272
274 """Quick plot to see chunk structure in dataset with 2 features
275
276 If clf_labels is provided for the predicted labels, then
277 incorrectly labeled samples will be marked with an 'x'
278 """
279 if ds.nfeatures != 2:
280 raise ValueError, "Can plot only in 2D, ie for datasets with 2 features"
281 if P.matplotlib.get_backend() == 'TkAgg':
282 P.ioff()
283 if clf_labels is not None and len(clf_labels) != ds.nsamples:
284 clf_labels = None
285 colors = ('b', 'g', 'r', 'c', 'm', 'y', 'k', 'w')
286 labels = ds.uniquelabels
287 labels_map = dict(zip(labels, colors[:len(labels)]))
288 for chunk in ds.uniquechunks:
289 chunk_text = str(chunk)
290 ids = ds.where(chunks=chunk)
291 ds_chunk = ds[ids]
292 for i in xrange(ds_chunk.nsamples):
293 s = ds_chunk.samples[i]
294 l = ds_chunk.labels[i]
295 format = ''
296 if clf_labels != None:
297 if clf_labels[i] != ds_chunk.labels[i]:
298 P.plot([s[0]], [s[1]], 'x' + labels_map[l])
299 P.text(s[0], s[1], chunk_text, color=labels_map[l],
300 horizontalalignment='center',
301 verticalalignment='center',
302 )
303 dss = ds.samples
304 P.axis((1.1 * N.min(dss[:, 0]),
305 1.1 * N.max(dss[:, 1]),
306 1.1 * N.max(dss[:, 0]),
307 1.1 * N.min(dss[:, 1])))
308 P.draw()
309 P.ion()
310