forked from 170010011/fr
188 lines
5.7 KiB
Python
188 lines
5.7 KiB
Python
"""Utilities used to generate various figures in the documentation."""
|
|
from itertools import product
|
|
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
|
|
from ._dwt import pad
|
|
|
|
# Public helpers exported by this module.
__all__ = [
    'wavedec_keys',
    'wavedec2_keys',
    'draw_2d_wp_basis',
    'draw_2d_fswavedecn_basis',
    'boundary_mode_subplot',
]
|
|
|
|
|
|
def wavedec_keys(level):
    """Subband keys corresponding to a wavedec decomposition.

    Parameters
    ----------
    level : int
        Number of decomposition levels.

    Returns
    -------
    keys : list of str
        One key per subband, one character per level. Every level except
        the last contributes only its detail key (``'a' * lev + 'd'``); the
        last level contributes both its approximation and detail keys.
        An empty list is returned for ``level <= 0``.
    """
    keys = []
    for lev in range(level):
        prefix = 'a' * lev
        # Only the deepest approximation survives; shallower ones are split.
        bands = 'ad' if lev == level - 1 else 'd'
        keys.extend(prefix + band for band in bands)
    return keys
|
|
|
|
|
|
def wavedec2_keys(level):
    """Subband keys corresponding to a wavedec2 decomposition.

    Parameters
    ----------
    level : int
        Number of decomposition levels.

    Returns
    -------
    keys : list of str
        One key per subband, one character per level. Every level except
        the last contributes its ``'h'``, ``'v'`` and ``'d'`` detail keys
        (prefixed by ``'a' * lev``); the last level additionally contributes
        its approximation key first. An empty list is returned for
        ``level <= 0``.
    """
    keys = []
    for lev in range(level):
        prefix = 'a' * lev
        # Only the deepest approximation survives; shallower ones are split.
        bands = 'ahvd' if lev == level - 1 else 'hvd'
        keys.extend(prefix + band for band in bands)
    return keys
|
|
|
|
|
|
def _box(bl, ur):
    """(x, y) coordinates for the 4 lines making up a rectangular box.

    Parameters
    ----------
    bl : sequence of float
        The ``(x, y)`` coordinates of the bottom left corner of the box.
        Only the first two elements are used.
    ur : sequence of float
        The ``(x, y)`` coordinates of the upper right corner of the box.
        Only the first two elements are used.

    Returns
    -------
    coords : 2-tuple of list
        The first and second elements are the x and y coordinates of the
        box outline. Each list has 8 entries: the start and end points of
        the bottom, right, top and left edges, in that order, so the whole
        outline can be drawn with a single ``plot`` call.

    Notes
    -----
    The corners are used as given and are not reordered, so ``ur`` may lie
    below or to the left of ``bl``.
    """
    xl, yb = bl[0], bl[1]
    xr, yt = ur[0], ur[1]
    box_x = [xl, xr,   # bottom edge
             xr, xr,   # right edge
             xr, xl,   # top edge
             xl, xl]   # left edge
    box_y = [yb, yb,
             yb, yt,
             yt, yt,
             yt, yb]
    return (box_x, box_y)
|
|
|
|
|
|
def _2d_wp_basis_coords(shape, keys):
    """Coordinates of the lines to be drawn by draw_2d_wp_basis.

    Parameters
    ----------
    shape : sequence of int
        Size of the undecomposed data; ``shape[0]`` sets the x extent and
        ``shape[1]`` the y extent of the drawing.
    keys : iterable of str
        Node paths, one character per level. At depth ``n`` a ``'h'`` or
        ``'d'`` shifts the box by ``shape[0] // 2**(n + 1)`` along x and a
        ``'v'`` or ``'d'`` shifts it by ``shape[1] // 2**(n + 1)`` along y;
        any other character (e.g. ``'a'``) adds no offset.

    Returns
    -------
    coords : list of 2-tuple
        ``(x, y)`` outline coordinates (as returned by ``_box``), one entry
        per key. y is negated so that deeper offsets go downwards.
    centers : dict
        Maps each key to the ``(x, y)`` centre of its box, for labeling.
    """
    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key in keys:
        offset_x = offset_y = 0
        for n, char in enumerate(key):
            step = 2**(n + 1)
            if char in ('h', 'd'):
                offset_x += shape[0] // step
            if char in ('v', 'd'):
                offset_y += shape[1] // step
        # The box size depends only on the node depth. Use len(key) rather
        # than the loop index so the empty key (the undecomposed root node)
        # gets the full-size box instead of an unbound/stale index.
        depth = len(key)
        sx = shape[0] // 2**depth
        sy = shape[1] // 2**depth
        xc, yc = _box((offset_x, -offset_y),
                      (offset_x + sx, -offset_y - sy))
        coords.append((xc, yc))
        centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
    return coords, centers
|
|
|
|
|
|
def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs=None, ax=None,
                     label_levels=0):
    """Plot a 2D representation of a WaveletPacket2D basis.

    Parameters
    ----------
    shape : sequence of int
        Size of the undecomposed data; sets the extent of the drawing.
    keys : iterable of str
        Node paths making up the basis; each is drawn as one box.
    fmt : str, optional
        Matplotlib format string used for the box outlines.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to every ``ax.plot`` call.
    ax : matplotlib Axes, optional
        Axes to draw into. If None, a new figure and axes are created.
    label_levels : int, optional
        Boxes whose key has at most this many characters are labeled with
        the key at their centre. 0 (the default) disables labels.

    Returns
    -------
    fig : matplotlib Figure
        The figure containing ``ax``.
    ax : matplotlib Axes
        The axes that were drawn into.
    """
    # None instead of a shared mutable {} default.
    if plot_kwargs is None:
        plot_kwargs = {}
    coords, centers = _2d_wp_basis_coords(shape, keys)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    for xs, ys in coords:
        ax.plot(xs, ys, fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, (cx, cy) in centers.items():
            if len(key) <= label_levels:
                ax.text(cx, cy, key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax
|
|
|
|
|
|
def _2d_fswavedecn_coords(shape, levels):
    """Box outlines and label centres for a fully separable 2D decomposition.

    Every pair of 1D ``wavedec`` subband keys (one for each axis) becomes
    one box.

    Parameters
    ----------
    shape : sequence of int
        Size of the undecomposed data; ``shape[0]`` sets the x extent and
        ``shape[1]`` the y extent of the drawing.
    levels : int
        Number of decomposition levels passed to ``wavedec_keys``.

    Returns
    -------
    coords : list of 2-tuple
        ``(x, y)`` outline coordinates (as returned by ``_box``) for every
        key pair. y is negated so that deeper offsets go downwards.
    centers : dict
        Maps each ``(key0, key1)`` pair to the centre of its box.
    """
    def _span(key, size):
        # A 'd' at depth n shifts the band by size // 2**(n + 1); the band
        # itself spans size // 2**len(key).
        start = sum(size // 2**(depth + 1)
                    for depth, char in enumerate(key) if char == 'd')
        return start, size // 2**len(key)

    boxes = []
    mids = {}  # box centres, used for labeling
    for key0, key1 in product(wavedec_keys(levels), repeat=2):
        x0, wx = _span(key0, shape[0])
        y0, wy = _span(key1, shape[1])
        xc, yc = _box((x0, -y0), (x0 + wx, -y0 - wy))
        boxes.append((xc, yc))
        mids[(key0, key1)] = (x0 + wx / 2, -y0 - wy / 2)
    return boxes, mids
|
|
|
|
|
|
def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs=None,
                             ax=None, label_levels=0):
    """Plot a 2D representation of a fully separable wavelet decomposition.

    One box is drawn for every pair of 1D ``wavedec`` subband keys, one key
    per axis, each with ``levels`` decomposition levels.

    Parameters
    ----------
    shape : sequence of int
        Size of the undecomposed data; sets the extent of the drawing.
    levels : int
        Number of decomposition levels along each axis.
    fmt : str, optional
        Matplotlib format string used for the box outlines.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to every ``ax.plot`` call.
    ax : matplotlib Axes, optional
        Axes to draw into. If None, a new figure and axes are created.
    label_levels : int, optional
        A box is labeled with its ``(key0, key1)`` pair when the longer of
        the two keys has at most this many characters. 0 (the default)
        disables labels.

    Returns
    -------
    fig : matplotlib Figure
        The figure containing ``ax``.
    ax : matplotlib Axes
        The axes that were drawn into.
    """
    # None instead of a shared mutable {} default.
    if plot_kwargs is None:
        plot_kwargs = {}
    coords, centers = _2d_fswavedecn_coords(shape, levels)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    for xs, ys in coords:
        ax.plot(xs, ys, fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, (cx, cy) in centers.items():
            lev = max(len(k) for k in key)
            if lev <= label_levels:
                ax.text(cx, cy, key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax
|
|
|
|
|
|
def boundary_mode_subplot(x, mode, ax, symw=True):
    """Plot an illustration of the boundary mode in a subplot axis.

    The signal is extended by twice its length on each side with ``pad``.
    The extended signal is drawn as black dots, the caller's samples in red,
    and vertical black lines mark the points of symmetry / boundary
    extension.

    Parameters
    ----------
    x : 1D array_like
        The signal to extend.
    mode : str
        Boundary extension mode passed to ``pad``; also used as the title.
    ax : matplotlib Axes
        Axes to draw into.
    symw : bool, optional
        If True, the vertical lines sit on sample positions spaced
        ``len(x) - 1`` apart; otherwise they sit half a sample to the left
        of sample positions, spaced ``len(x)`` apart.
    """
    # Remember how many samples the caller supplied, before any padding.
    n_orig = len(x)

    # if odd-length, periodization replicates the last sample to make it even
    if mode == 'periodization' and n_orig % 2 == 1:
        x = np.concatenate((x, (x[-1], )))

    npad = 2 * len(x)
    t = np.arange(len(x) + 2 * npad)
    xp = pad(x, (npad, npad), mode=mode)

    ax.plot(t, xp, 'k.')
    ax.set_title(mode)

    # Plot the original signal in red. Only the caller's samples are
    # highlighted: the sample replicated for odd-length periodization is
    # excluded, but an even-length signal keeps its last sample.
    ax.plot(t[npad:npad + n_orig], x[:n_orig], 'r.')

    # add vertical bars indicating points of symmetry or boundary extension
    o2 = np.ones(2)
    if symw:
        left = npad
        step = len(x) - 1
    else:
        left = npad - 0.5
        step = len(x)
    if mode in ('smooth', 'constant', 'zero'):
        # These modes do not replicate the signal, so only its two edges
        # are marked.
        rng = range(0, 2)
    else:
        rng = range(-2, 4)
    y_extent = [xp.min() - .5, xp.max() + .5]
    for rep in rng:
        ax.plot((left + rep * step) * o2, y_extent, 'k-')
|