# fr/fr_env/lib/python3.8/site-packages/pywt/tests/test_swt.py
#
# 634 lines
# 24 KiB
# Python

#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import warnings
from copy import deepcopy
from itertools import combinations, permutations
import numpy as np
import pytest
from numpy.testing import (assert_allclose, assert_, assert_equal,
assert_raises, assert_array_equal, assert_warns)
import pywt
from pywt._extensions._swt import swt_axis
# Input dtypes used by the dtype tests, paired position-by-position (via zip)
# with the dtype the SWT functions are expected to return.  float32 and
# complex64 are preserved, float16 is promoted to float32, int8 is converted
# to float64, and float64 / complex128 are returned unchanged.
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
             np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
              np.complex128]
# tolerances used in accuracy comparisons (single / double precision)
tol_single = 1e-6
tol_double = 1e-13
####
# 1d multilevel swt tests
####
def test_swt_decomposition():
    """Check a 3-level 'db1' SWT of a short signal against reference values.

    Also checks that resuming the decomposition from the level-1
    approximation (``start_level=1``) reproduces the level-2 coefficients,
    and that the default number of levels for this signal is 3.
    """
    x = [3, 7, 1, 3, -2, 6, 4, 6]
    db1 = pywt.Wavelet('db1')
    atol = tol_double
    # coefficients are unpacked from the coarsest level (3) to the finest (1)
    (cA3, cD3), (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=3)
    expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678,
                    2.82842712, 7.07106781, 7.07106781, 6.36396103]
    assert_allclose(cA1, expected_cA1, rtol=1e-8, atol=atol)
    expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
                    -5.65685425, 1.41421356, -1.41421356, 2.12132034]
    assert_allclose(cD1, expected_cD1, rtol=1e-8, atol=atol)
    expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5]
    assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
    expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5]
    assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
    expected_cA3 = [9.89949494, ] * 8
    assert_allclose(cA3, expected_cA3, rtol=1e-8, atol=atol)
    expected_cD3 = [0.00000000, -3.53553391, -4.24264069, -2.12132034,
                    0.00000000, 3.53553391, 4.24264069, 2.12132034]
    assert_allclose(cD3, expected_cD3, rtol=1e-8, atol=atol)

    # level=1, start_level=1 decomposition should match level=2
    res = pywt.swt(cA1, db1, level=1, start_level=1)
    cA2, cD2 = res[0]
    assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
    assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)

    # with no explicit level, the maximum possible level is used
    coeffs = pywt.swt(x, db1)
    assert_equal(len(coeffs), 3)
    # assert_(value, msg) only checks the truthiness of ``value`` (the 3 was
    # being used as the failure message), so compare explicitly.
    assert_equal(pywt.swt_max_level(len(x)), 3)
def test_swt_max_level():
    """Check ``swt_max_level`` for odd and even signal lengths."""
    # An odd length admits no decomposition level at all, and pywt warns.
    assert_warns(UserWarning, pywt.swt_max_level, 11)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', UserWarning)
        assert_equal(pywt.swt_max_level(11), 0)

    # When at least one level is possible there is no warning; the expected
    # level is the largest k such that 2**k divides the length.
    cases = [(2, 1), (4*3, 2), (16, 4), (16*3, 4)]
    for length, expected_level in cases:
        assert_equal(pywt.swt_max_level(length), expected_level)
def test_swt_axis():
    """SWT of tiled 2D data along either axis must match the 1D SWT.

    Both C- and Fortran-ordered inputs are exercised.  An out-of-range
    ``axis`` must raise ValueError.
    """
    signal = [3, 7, 1, 3, -2, 6, 4, 6]
    wav = pywt.Wavelet('db1')
    (ref_a2, ref_d2), (ref_a1, ref_d1) = pywt.swt(signal, wav, level=2)

    def _with_layout(arr, layout):
        # force the requested memory order so both code paths are tested
        if layout == 'C':
            return np.ascontiguousarray(arr)
        if layout == 'F':
            return np.asfortranarray(arr)
        return arr

    for layout in ('C', 'F'):
        # five identical rows, transformed along the default axis (-1)
        tiled = np.concatenate((np.asarray(signal).reshape((1, -1)), ) * 5,
                               axis=0)
        tiled = _with_layout(tiled, layout)
        (a2, d2), (a1, d1) = pywt.swt(tiled, wav, level=2)
        for band in (a2, d2, a1, d1):
            assert_(band.shape == tiled.shape)
        pairs = ((a1, ref_a1), (a2, ref_a2), (d1, ref_d1), (d2, ref_d2))
        for band, ref in pairs:
            for row in band:
                assert_array_equal(row, ref)

        # five identical columns, transformed along axis 0
        tiled = np.concatenate((np.asarray(signal).reshape((-1, 1)), ) * 5,
                               axis=1)
        tiled = _with_layout(tiled, layout)
        (a2, d2), (a1, d1) = pywt.swt(tiled, wav, level=2, axis=0)
        for band in (a2, d2, a1, d1):
            assert_(band.shape == tiled.shape)
        pairs = ((a1, ref_a1), (a2, ref_a2), (d1, ref_d1), (d2, ref_d2))
        for band, ref in pairs:
            for column in band.T:
                assert_array_equal(column, ref)

    # the 1D signal has no axis 5
    assert_raises(ValueError, pywt.swt, signal, wav, level=2, axis=5)
def test_swt_iswt_integration():
    """Round-trip swt/iswt for every discrete wavelet except 'dmey'.

    'dmey' is skipped because it is a discrete approximation of a continuous
    wavelet and (likely for that reason) does not reconstruct precisely.
    Each wavelet is tested up to 3 levels.  This only verifies that swt and
    iswt are inverses of each other, not that either is correct by itself.
    """
    max_level = 3
    names = [name for name in pywt.wavelist(kind='discrete')
             if name != 'dmey']
    for name in names:
        wav = pywt.Wavelet(name)
        power = int(np.ceil(np.log2(max(wav.dec_len, wav.rec_len))))
        signal = np.arange(2**(power + max_level - 1))
        for norm in (True, False):
            # skip norm=True for non-orthogonal wavelets (it warns)
            if norm and not wav.orthogonal:
                continue
            for trim_approx in (True, False):
                coeffs = pywt.swt(signal, wav, max_level,
                                  trim_approx=trim_approx, norm=norm)
                reconstructed = pywt.iswt(coeffs, wav, norm=norm)
                assert_allclose(reconstructed, signal, rtol=1e-5, atol=1e-7)
def test_swt_dtypes():
    """swt and swt2 must return coefficients of the expected output dtype."""
    haar = pywt.Wavelet('haar')
    for in_type, out_type in zip(dtypes_in, dtypes_out):
        errmsg = "wrong dtype returned for {0} input".format(in_type)

        signal = np.ones(8, dtype=in_type)
        (a2, d2), (a1, d1) = pywt.swt(signal, haar, level=2)
        swt_ok = all(band.dtype == out_type for band in (a2, d2, a1, d1))
        assert_(swt_ok, "swt: " + errmsg)

        image = np.ones((8, 8), dtype=in_type)
        approx, (horiz, vert, diag) = pywt.swt2(image, haar, level=1)[0]
        swt2_ok = all(band.dtype == out_type
                      for band in (approx, horiz, vert, diag))
        assert_(swt2_ok, "swt2: " + errmsg)
def test_swt_roundtrip_dtypes():
    """swt/iswt and swt2/iswt2 reconstruct random data of every input dtype."""
    rng = np.random.RandomState(5)
    haar = pywt.Wavelet('haar')
    for in_type in dtypes_in:
        signal = rng.standard_normal((8, )).astype(in_type)
        restored = pywt.iswt(pywt.swt(signal, haar, level=2), haar)
        assert_allclose(signal, restored, rtol=1e-6, atol=1e-7)

        image = rng.standard_normal((8, 8)).astype(in_type)
        restored = pywt.iswt2(pywt.swt2(image, haar, level=2), haar)
        assert_allclose(image, restored, rtol=1e-6, atol=1e-7)
def test_swt_default_level_by_axis():
    """With level=None, the level count equals swt_max_level of that axis."""
    data = np.ones((2**3, 2**4, 2**5))
    for ax, axis_len in enumerate(data.shape):
        decomposition = pywt.swt(data, 'db2', level=None, start_level=0,
                                 axis=ax)
        assert_equal(len(decomposition), pywt.swt_max_level(axis_len))
def test_swt2_ndim_error():
    """swt2 must reject 1D input with ValueError."""
    signal = np.ones(8)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        with assert_raises(ValueError):
            pywt.swt2(signal, 'haar', level=1)
@pytest.mark.slow
def test_swt2_iswt2_integration(wavelets=None):
    """Round-trip swt2/iswt2 over (by default) every discrete wavelet.

    The 'dmey' wavelet is always excluded.  It is a discrete approximation
    of a continuous wavelet and (likely for that reason) does not
    reconstruct precisely.  Every wavelet is tested up to 3 levels, with and
    without ``trim_approx``, and, for orthogonal wavelets only, with and
    without ``norm``.  This checks only that swt2 and iswt2 are each other's
    inverse, not that either is correct on its own.

    Parameters
    ----------
    wavelets : iterable of str, optional
        Wavelet names to test.  Defaults to
        ``pywt.wavelist(kind='discrete')``.  The argument is not modified.
    """
    max_level = 3
    if wavelets is None:
        wavelets = pywt.wavelist(kind='discrete')
    # Build a filtered copy instead of calling ``wavelets.remove('dmey')``:
    # that mutated a caller-supplied list and failed for tuples.
    wavelets = [name for name in wavelets if name != 'dmey']
    for current_wavelet_str in wavelets:
        current_wavelet = pywt.Wavelet(current_wavelet_str)
        input_length_power = int(np.ceil(np.log2(max(
            current_wavelet.dec_len,
            current_wavelet.rec_len))))
        input_length = 2**(input_length_power + max_level - 1)
        X = np.arange(input_length**2).reshape(input_length, input_length)
        for norm in [True, False]:
            if norm and not current_wavelet.orthogonal:
                # skip norm=True for non-orthogonal wavelets (it warns)
                continue
            for trim_approx in [True, False]:
                coeffs = pywt.swt2(X, current_wavelet, max_level,
                                   trim_approx=trim_approx, norm=norm)
                Y = pywt.iswt2(coeffs, current_wavelet, norm=norm)
                assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
def test_swt2_iswt2_quick():
    """Fast variant of the swt2/iswt2 round trip using only 'db1'."""
    quick_set = ['db1']
    test_swt2_iswt2_integration(wavelets=quick_set)
def test_swt2_iswt2_non_square(wavelets=None):
    """swt2/iswt2 round trip on rectangular (n_rows x 32) images.

    The ``wavelets`` argument is accepted but unused; 'db1' is always used.
    """
    wavelet_name = 'db1'
    for n_rows in (8, 16, 48):
        image = np.arange(n_rows*32).reshape(n_rows, 32)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', FutureWarning)
            coeffs = pywt.swt2(image, wavelet_name, level=2)
            restored = pywt.iswt2(coeffs, wavelet_name)
        assert_allclose(restored, image, rtol=tol_single, atol=tol_single)
def test_swt2_axes():
    """Reversing ``axes`` in swt2 swaps the horizontal/vertical details.

    Duplicate axes and a single axis must both be rejected.
    """
    atol = 1e-14
    wav = pywt.Wavelet('db2')
    power = int(np.ceil(np.log2(max(wav.dec_len, wav.rec_len))))
    size = 2**power
    image = np.arange(size**2).reshape(size, size)

    cA1, (cH1, cV1, cD1) = pywt.swt2(image, wav, level=1)[0]
    cA2, (cH2, cV2, cD2) = pywt.swt2(image, wav, level=1, axes=(1, 0))[0]
    for lhs, rhs in ((cA1, cA2), (cH1, cV2), (cV1, cH2), (cD1, cD2)):
        assert_allclose(lhs, rhs, atol=atol)

    with assert_raises(ValueError):
        pywt.swt2(image, wav, 1, axes=(0, 0))  # duplicated axis
    with assert_raises(ValueError):
        pywt.swt2(image, wav, 1, axes=(0, ))  # only one axis given
def test_iswt2_2d_only():
    """iswt2 rejects coefficients that came from 3D data."""
    volume = np.ones((4, 4, 4))
    coeffs = pywt.swt2(volume, 'haar', level=1)
    with assert_raises(ValueError):
        pywt.iswt2(coeffs, 'haar')
def test_swtn_axes():
    """Check axis ordering and argument validation of swtn and swt_axis."""
    atol = 1e-14
    current_wavelet = pywt.Wavelet('db2')
    input_length_power = int(np.ceil(np.log2(max(
        current_wavelet.dec_len,
        current_wavelet.rec_len))))
    input_length = 2**(input_length_power)
    X = np.arange(input_length**2).reshape(input_length, input_length)
    coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
    # opposite axis order: the mixed 'ad' / 'da' keys trade places
    coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
    assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
    assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
    assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
    assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)

    # 0-level transform yields an empty list
    empty = pywt.swtn(X, current_wavelet, level=0)
    assert_equal(empty, [])

    # duplicate axes not allowed
    assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
    # zero-size input (np.asarray([]) is 1-D with no elements, not 0-D)
    assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
    # start_level too large
    assert_raises(ValueError, pywt.swtn, X, current_wavelet,
                  level=1, start_level=2)
    # level < 1 in swt_axis call
    assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
                  start_level=0)
    # odd-sized data not allowed.  A valid level (1) is used so that the
    # odd length along axis 0 -- not level=0, which is rejected on its own
    # -- is what this assertion actually exercises.
    assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=1,
                  start_level=0, axis=0)
@pytest.mark.slow
def test_swtn_iswtn_integration(wavelets=None):
    """Round-trip swtn/iswtn over several axis subsets.

    The following combinations are covered:

    1. 1 out of 2 axes of a 2D array
    2. 2 out of 3 axes of a 3D array

    To keep test time down, only wavelets with ``dec_len <= 8`` are run, and
    'dmey' is always excluded.  This confirms only that iswtn (almost)
    perfectly inverts swtn, and that iswtn leaves its input coefficients
    unmodified.  It does not validate either transform on its own.

    Parameters
    ----------
    wavelets : iterable of str, optional
        Wavelet names to test.  Defaults to
        ``pywt.wavelist(kind='discrete')``.  The argument is not modified.
    """
    max_level = 3
    if wavelets is None:
        wavelets = pywt.wavelist(kind='discrete')
    # Build a filtered copy instead of calling ``wavelets.remove('dmey')``:
    # that mutated a caller-supplied list and failed for tuples.
    wavelets = [name for name in wavelets if name != 'dmey']
    for ndim_transform in range(1, 3):
        ndim = ndim_transform + 1
        for axes in combinations(range(ndim), ndim_transform):
            for current_wavelet_str in wavelets:
                wav = pywt.Wavelet(current_wavelet_str)
                if wav.dec_len > 8:
                    continue  # avoid excessive test duration
                input_length_power = int(np.ceil(np.log2(max(
                    wav.dec_len,
                    wav.rec_len))))
                N = 2**(input_length_power + max_level - 1)
                X = np.arange(N**ndim).reshape((N, )*ndim)
                for norm in [True, False]:
                    if norm and not wav.orthogonal:
                        # skip norm=True for non-orthogonal wavelets (warns)
                        continue
                    for trim_approx in [True, False]:
                        coeffs = pywt.swtn(X, wav, max_level, axes=axes,
                                           trim_approx=trim_approx, norm=norm)
                        coeffs_copy = deepcopy(coeffs)
                        Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
                        assert_allclose(Y, X, rtol=1e-5, atol=1e-5)

                        # verify the inverse transform didn't modify coeffs
                        for c, c2 in zip(coeffs, coeffs_copy):
                            for k, v in c.items():
                                assert_array_equal(c2[k], v)
def test_swtn_iswtn_quick():
    """Fast variant of the swtn/iswtn round trip using only 'db1'."""
    quick_set = ['db1']
    test_swtn_iswtn_integration(wavelets=quick_set)
def test_iswtn_errors():
    """iswtn validates its ``axes`` argument and the coefficient shapes."""
    wav = pywt.Wavelet('db1')
    transformed_axes = (0, 1)
    volume = np.arange(8**3).reshape(8, 8, 8)
    coeffs = pywt.swtn(volume, wav, 2, axes=transformed_axes)

    # asking for more axes than were transformed
    assert_raises(ValueError, pywt.iswtn, coeffs, wav, axes=(0, 1, 2))
    # repeated axis
    assert_raises(ValueError, pywt.iswtn, coeffs, wav, axes=(0, 0))

    # shrink one detail array so its shape disagrees with the others
    level0 = coeffs[0]
    level0['da'] = level0['da'][:-1, :]
    assert_raises(RuntimeError, pywt.iswtn, coeffs, wav,
                  axes=transformed_axes)
def test_swtn_iswtn_unique_shape_per_axis():
    """Regression test for gh-460: every axis has a different length."""
    rng = np.random.RandomState(0)
    for shape in permutations((1, 48, 32)):
        # leave the singleton axis untransformed
        axes = [i for i, length in enumerate(shape) if length != 1]
        data = rng.standard_normal(shape)
        coeffs = pywt.swtn(data, 'sym2', 3, axes=axes)
        restored = pywt.iswtn(coeffs, 'sym2', axes=axes)
        assert_allclose(data, restored, rtol=1e-10, atol=1e-10)
def test_per_axis_wavelets():
    """swtn/iswtn (and swt2/iswt2) accept a separate wavelet for each axis."""
    rng = np.random.RandomState(1234)
    data = rng.randn(16, 16, 16)
    level = 3
    # Wavelet objects and wavelet names may be mixed.
    wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')

    # one wavelet per axis, then a 1-tuple (also accepted)
    for wav_seq in (wavelets, wavelets[:1]):
        coefs = pywt.swtn(data, wav_seq, level=level)
        assert_allclose(pywt.iswtn(coefs, wav_seq), data, atol=1e-14)

    # number of wavelets doesn't match the number of transformed axes
    assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
    assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        # swt2/iswt2 support per-axis wavelets too
        image = data[..., 0]
        coefs2 = pywt.swt2(image, wavelets[:2], level)
        assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), image, atol=1e-14)
def test_error_on_continuous_wavelet():
    """Every SWT forward/inverse function rejects continuous wavelets."""
    data = np.ones((16, 16))
    transform_pairs = [(pywt.swt, pywt.iswt),
                       (pywt.swt2, pywt.iswt2),
                       (pywt.swtn, pywt.iswtn)]
    for forward, inverse in transform_pairs:
        continuous = ['morl', pywt.DiscreteContinuousWavelet('morl')]
        for cwave in continuous:
            with assert_raises(ValueError):
                forward(data, wavelet=cwave, level=3)
            coeffs = forward(data, 'db1', level=3)
            with assert_raises(ValueError):
                inverse(coeffs, wavelet=cwave)
def test_iswt_mixed_dtypes():
    """iswt returns double precision when coefficient precisions differ."""
    real_signal = np.arange(16).astype(np.float64)
    complex_signal = real_signal + 1j*real_signal
    wav = 'sym2'
    dtype_pairs = [(np.float64, np.float32),
                   (np.float32, np.float64),
                   (np.float16, np.float64),
                   (np.complex128, np.complex64),
                   (np.complex64, np.complex128)]
    for approx_dtype, detail_dtype in dtype_pairs:
        is_complex = approx_dtype in (np.complex64, np.complex128)
        x = complex_signal if is_complex else real_signal
        expected_dtype = np.complex128 if is_complex else np.float64

        coeffs = pywt.swt(x, wav, 2)
        # give the first level's approximation and detail arrays
        # different precisions
        approx, detail = coeffs[0]
        coeffs[0] = [approx.astype(approx_dtype), detail.astype(detail_dtype)]
        y = pywt.iswt(coeffs, wav)
        assert_equal(expected_dtype, y.dtype)
        assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_iswt2_mixed_dtypes():
    """iswt2 returns double precision when coefficient precisions differ."""
    rng = np.random.RandomState(0)
    real_image = rng.randn(8, 8)
    complex_image = real_image + 1j*real_image
    wav = 'sym2'
    dtype_pairs = [(np.float64, np.float32),
                   (np.float32, np.float64),
                   (np.float16, np.float64),
                   (np.complex128, np.complex64),
                   (np.complex64, np.complex128)]
    for approx_dtype, detail_dtype in dtype_pairs:
        is_complex = approx_dtype in (np.complex64, np.complex128)
        x = complex_image if is_complex else real_image
        expected_dtype = np.complex128 if is_complex else np.float64

        coeffs = pywt.swt2(x, wav, 2)
        # give the first level's approximation and detail arrays
        # different precisions
        approx, details = coeffs[0]
        coeffs[0] = [approx.astype(approx_dtype),
                     tuple(band.astype(detail_dtype) for band in details)]
        y = pywt.iswt2(coeffs, wav)
        assert_equal(expected_dtype, y.dtype)
        assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_iswtn_mixed_dtypes():
    """iswtn returns double precision when coefficient precisions differ."""
    rng = np.random.RandomState(0)
    real_volume = rng.randn(8, 8, 8)
    complex_volume = real_volume + 1j*real_volume
    wav = 'sym2'
    dtype_pairs = [(np.float64, np.float32),
                   (np.float32, np.float64),
                   (np.float16, np.float64),
                   (np.complex128, np.complex64),
                   (np.complex64, np.complex128)]
    for approx_dtype, detail_dtype in dtype_pairs:
        is_complex = approx_dtype in (np.complex64, np.complex128)
        x = complex_volume if is_complex else real_volume
        expected_dtype = np.complex128 if is_complex else np.float64

        coeffs = pywt.swtn(x, wav, 2)
        # give the first level's approximation and detail arrays
        # different precisions
        level0 = coeffs[0]
        approx_key = 'a' * x.ndim
        approx = level0.pop(approx_key).astype(approx_dtype)
        recast = {key: band.astype(detail_dtype)
                  for key, band in level0.items()}
        recast[approx_key] = approx
        coeffs[0] = recast
        y = pywt.iswtn(coeffs, wav)
        assert_equal(expected_dtype, y.dtype)
        assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_swt_zero_size_axes():
    """SWT functions must reject input arrays with no elements."""
    with assert_raises(ValueError):
        pywt.swt([], 'db2')
    # n-D input takes a different code path, so check it separately
    empty_2d = np.ones((1, 4))[0:0, :]  # shape (0, 4)
    with assert_raises(ValueError):
        pywt.swtn(empty_2d, 'db2', level=1, axes=(0,))
def test_swt_variance_and_energy_preservation():
    """Verify that the 1D SWT partitions variance among the coefficients.

    With ``norm=True`` and an orthogonal wavelet, the coefficient variances
    sum to the signal variance and the L2 norm is preserved.  A
    non-orthogonal wavelet with ``norm=True`` must warn.
    """
    rng = np.random.RandomState(5)
    signal = rng.randn(256)
    coeffs = pywt.swt(signal, 'db2', trim_approx=True, norm=True)

    total_variance = np.sum([np.var(band) for band in coeffs])
    assert_allclose(total_variance, np.var(signal))

    coeff_energy = np.linalg.norm(np.concatenate(coeffs))
    assert_allclose(np.linalg.norm(signal), coeff_energy)

    with assert_warns(UserWarning):
        pywt.swt(signal, 'bior2.2', norm=True)
def test_swt2_variance_and_energy_preservation():
    """Verify that the 2D SWT partitions variance among the coefficients.

    With ``norm=True`` and an orthogonal wavelet, the coefficient variances
    sum to the image variance and the L2 norm is preserved.  A
    non-orthogonal wavelet with ``norm=True`` must warn.
    """
    rng = np.random.RandomState(5)
    image = rng.randn(64, 64)
    coeffs = pywt.swt2(image, 'db2', level=4, trim_approx=True, norm=True)

    # coeffs[0] is the approximation array; later entries are detail tuples
    flat = [coeffs[0].ravel()]
    flat.extend(band.ravel() for details in coeffs[1:] for band in details)

    total_variance = np.sum([np.var(band) for band in flat])
    assert_allclose(total_variance, np.var(image))

    coeff_energy = np.linalg.norm(np.concatenate(flat))
    assert_allclose(np.linalg.norm(image), coeff_energy)

    with assert_warns(UserWarning):
        pywt.swt2(image, 'bior2.2', level=4, norm=True)
def test_swtn_variance_and_energy_preservation():
    """Verify that the nD SWT partitions variance among the coefficients.

    With ``norm=True`` and an orthogonal wavelet, the coefficient variances
    sum to the data variance and the L2 norm is preserved.  A
    non-orthogonal wavelet with ``norm=True`` must warn.
    """
    rng = np.random.RandomState(5)
    data = rng.randn(64, 64)
    coeffs = pywt.swtn(data, 'db2', level=4, trim_approx=True, norm=True)

    # coeffs[0] is the approximation array; later entries are detail dicts
    flat = [coeffs[0].ravel()]
    flat.extend(band.ravel() for details in coeffs[1:]
                for band in details.values())

    total_variance = np.sum([np.var(band) for band in flat])
    assert_allclose(total_variance, np.var(data))

    coeff_energy = np.linalg.norm(np.concatenate(flat))
    assert_allclose(np.linalg.norm(data), coeff_energy)

    with assert_warns(UserWarning):
        pywt.swtn(data, 'bior2.2', level=4, norm=True)
def test_swt_ravel_and_unravel():
    """With trim_approx=True, SWT output round-trips through ravel_coeffs."""
    cases = [(1, pywt.swt, pywt.iswt, 'swt'),
             (2, pywt.swt2, pywt.iswt2, 'swt2'),
             (3, pywt.swtn, pywt.iswtn, 'swtn')]
    for ndim, forward, inverse, fmt in cases:
        data = np.ones((16, ) * ndim)
        coeffs = forward(data, 'sym2', level=3, trim_approx=True)
        flat, slices, shapes = pywt.ravel_coeffs(coeffs)
        restored = pywt.unravel_coeffs(flat, slices, shapes,
                                       output_format=fmt)
        assert_allclose(data, inverse(restored, 'sym2'))