fr/fr_env/lib/python3.8/site-packages/pywt/tests/test_multilevel.py

1034 lines
38 KiB
Python

#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import warnings
from itertools import combinations
import numpy as np
import pytest
from numpy.testing import (assert_almost_equal, assert_allclose, assert_,
assert_equal, assert_raises, assert_raises_regex,
assert_array_equal, assert_warns)
import pywt
# Input dtypes and, position by position, the dtype the tests expect back.
# As listed below: float32, float64, complex64 and complex128 are preserved,
# float16 is expected to come back as float32 and int8 as float64.
# complex256 (when the platform provides it) is expected as complex128.
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
             np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
              np.complex128]

# tolerances used in accuracy comparisons
tol_single = 1e-6
tol_double = 1e-13
dtypes_and_tolerances = [(np.float16, tol_single), (np.float32, tol_single),
                         (np.float64, tol_double), (np.int8, tol_double),
                         (np.complex64, tol_single),
                         (np.complex128, tol_double)]

# test complex256 as well if it is available
try:
    _complex256 = np.complex256
except AttributeError:
    # numpy does not expose complex256 here; leave the tables unchanged
    pass
else:
    dtypes_in.append(_complex256)
    dtypes_out.append(np.complex128)
    dtypes_and_tolerances.append((_complex256, tol_double))
# determine which wavelets to test
wavelist = pywt.wavelist()
if 'dmey' in wavelist:
    # accuracy is very low for dmey, so omit it
    wavelist.remove('dmey')

# Collect every name for which DiscreteContinuousWavelet does not return a
# pywt.Wavelet instance, then drop those names from the list under test.
# FutureWarnings raised while constructing the wavelets are silenced.
del_list = []
with warnings.catch_warnings():
    warnings.simplefilter('ignore', FutureWarning)
    for wavelet in wavelist:
        candidate = pywt.DiscreteContinuousWavelet(wavelet)
        if not isinstance(candidate, pywt.Wavelet):
            del_list.append(wavelet)
for del_ind in del_list:
    wavelist.remove(del_ind)
####
# 1d multilevel dwt tests
####
def test_wavedec():
    """Compare ``pywt.wavedec`` ('db1', 8 samples) with reference values."""
    signal = [3, 7, 1, 1, -2, 5, 4, 6]
    wav = pywt.Wavelet('db1')

    approx3, detail3, detail2, detail1 = pywt.wavedec(signal, wav)

    assert_almost_equal(approx3, [8.83883476])
    assert_almost_equal(detail3, [-0.35355339])
    assert_allclose(detail2, [4., -3.5])
    assert_allclose(detail1, [-2.82842712, 0, -4.94974747, -1.41421356])

    # three detail bands were unpacked above; the maximum level is 3 as well
    assert_(pywt.dwt_max_level(len(signal), wav) == 3)
def test_waverec_invalid_inputs():
    """``pywt.waverec`` raises ValueError for malformed coefficient input."""
    # input must be list or tuple
    assert_raises(ValueError, pywt.waverec, np.ones(8), 'haar')

    # input list cannot be empty
    assert_raises(ValueError, pywt.waverec, [], 'haar')

    # array_to_coeffs is called here without 'output_format'; waverec is
    # expected to reject the resulting coefficient list
    signal = [3, 7, 1, 1, -2, 5, 4, 6]
    packed, packed_slices = pywt.coeffs_to_array(pywt.wavedec(signal, 'db1'))
    coeffs_from_arr = pywt.array_to_coeffs(packed, packed_slices)
    assert_raises_regex(ValueError, "Unexpected detail coefficient type",
                        pywt.waverec, coeffs_from_arr, 'haar')
def test_waverec_accuracies():
    """1D wavedec/waverec round trip within the per-dtype tolerance."""
    rng = np.random.RandomState(1234)
    base = rng.randn(8)
    for dtype, tolerance in dtypes_and_tolerances:
        signal = base.astype(dtype)
        if np.iscomplexobj(signal):
            # give complex inputs a random imaginary part of matching precision
            signal += 1j*rng.randn(8).astype(signal.real.dtype)
        restored = pywt.waverec(pywt.wavedec(signal, 'db1'), 'db1')
        assert_allclose(restored, signal, atol=tolerance, rtol=tolerance)
def test_waverec_none():
    """``pywt.waverec`` accepts ``None`` entries and keeps the signal length.

    The original check used ``assert_(size, len(x))``, which treats
    ``len(x)`` as the failure message and therefore only verified that the
    size was non-zero. ``assert_equal`` compares the two values.
    """
    x = [3, 7, 1, 1, -2, 5, 4, 6]
    coeffs = pywt.wavedec(x, 'db1')

    # set some coefficients to None
    coeffs[2] = None
    coeffs[0] = None

    assert_equal(pywt.waverec(coeffs, 'db1').size, len(x))
def test_waverec_odd_length():
    """Round trip of a 6-sample signal through wavedec/waverec ('db1')."""
    signal = [3, 7, 1, 1, -2, 5]
    restored = pywt.waverec(pywt.wavedec(signal, 'db1'), 'db1')
    assert_allclose(restored, signal, rtol=1e-12)
def test_waverec_complex():
    """Round trip of a complex-valued 1D signal through wavedec/waverec."""
    signal = np.array([3, 7, 1, 1, -2, 5, 4, 6]) + 1j
    restored = pywt.waverec(pywt.wavedec(signal, 'db1'), 'db1')
    assert_allclose(restored, signal, rtol=1e-12)
def test_multilevel_dtypes_1d():
    """Check only the dtype returned by wavedec and waverec."""
    haar = pywt.Wavelet('haar')
    for dt_in, dt_out in zip(dtypes_in, dtypes_out):
        errmsg = "wrong dtype returned for {0} input".format(dt_in)

        # wavedec: every coefficient array carries the expected dtype
        coeffs = pywt.wavedec(np.ones(8, dtype=dt_in), haar, level=2)
        for band in coeffs:
            assert_(band.dtype == dt_out, "wavedec: " + errmsg)

        # waverec: so does the reconstruction
        roundtrip = pywt.waverec(coeffs, haar)
        assert_(roundtrip.dtype == dt_out, "waverec: " + errmsg)
def test_waverec_all_wavelets_modes():
    """1D round trip for every wavelet in ``wavelist`` and every mode."""
    rng = np.random.RandomState(1234)
    signal = rng.randn(80)
    for wav in wavelist:
        for ext_mode in pywt.Modes.modes:
            coeffs = pywt.wavedec(signal, wav, mode=ext_mode)
            restored = pywt.waverec(coeffs, wav, mode=ext_mode)
            assert_allclose(restored, signal,
                            rtol=tol_single, atol=tol_single)
####
# 2d multilevel dwt function tests
####
def test_waverec2_accuracies():
    """2D wavedec2/waverec2 round trip within the per-dtype tolerance."""
    rng = np.random.RandomState(1234)
    base = rng.randn(4, 4)
    for dtype, tolerance in dtypes_and_tolerances:
        image = base.astype(dtype)
        if np.iscomplexobj(image):
            # give complex inputs a random imaginary part of matching precision
            image += 1j*rng.randn(4, 4).astype(image.real.dtype)
        coeffs = pywt.wavedec2(image, 'db1')
        # approximation plus two detail levels
        assert_(len(coeffs) == 3)
        restored = pywt.waverec2(coeffs, 'db1')
        assert_allclose(restored, image, atol=tolerance, rtol=tolerance)
def test_multilevel_dtypes_2d():
    """Check only the dtype returned by wavedec2 and waverec2."""
    haar = pywt.Wavelet('haar')
    for dt_in, dt_out in zip(dtypes_in, dtypes_out):
        errmsg = "wrong dtype returned for {0} input".format(dt_in)
        image = np.ones((8, 8), dtype=dt_in)

        # wavedec2: approximation first, then both detail tuples
        cA, coeffsD2, coeffsD1 = pywt.wavedec2(image, haar, level=2)
        assert_(cA.dtype == dt_out, "wavedec2: " + errmsg)
        for band in tuple(coeffsD1) + tuple(coeffsD2):
            assert_(band.dtype == dt_out, "wavedec2: " + errmsg)

        # waverec2
        roundtrip = pywt.waverec2([cA, coeffsD2, coeffsD1], haar)
        assert_(roundtrip.dtype == dt_out, "waverec2: " + errmsg)
@pytest.mark.slow
def test_waverec2_all_wavelets_modes():
    """2D round trip for every wavelet in ``wavelist`` and every mode."""
    rng = np.random.RandomState(1234)
    image = rng.randn(80, 96)
    for wav in wavelist:
        for ext_mode in pywt.Modes.modes:
            coeffs = pywt.wavedec2(image, wav, mode=ext_mode)
            restored = pywt.waverec2(coeffs, wav, mode=ext_mode)
            assert_allclose(restored, image,
                            rtol=tol_single, atol=tol_single)
def test_wavedec2_complex():
    """Round trip of a complex 4x4 array through wavedec2/waverec2."""
    image = np.ones((4, 4)) + 1j
    coeffs = pywt.wavedec2(image, 'db1')
    # approximation plus two detail levels
    assert_(len(coeffs) == 3)
    restored = pywt.waverec2(coeffs, 'db1')
    assert_allclose(restored, image, rtol=1e-12)
def test_wavedec2_invalid_inputs():
    """``pywt.wavedec2`` raises ValueError for a 1D input array."""
    # input array has too few dimensions
    too_flat = np.ones(4)
    assert_raises(ValueError, pywt.wavedec2, too_flat, 'haar')
def test_waverec2_invalid_inputs():
    """``pywt.waverec2`` raises ValueError for malformed coefficient input."""
    # input must be list or tuple
    assert_raises(ValueError, pywt.waverec2, np.ones((8, 8)), 'haar')

    # input list cannot be empty
    assert_raises(ValueError, pywt.waverec2, [], 'haar')

    # coefficients from a different decomposition used as input
    expected_msg = "Unexpected detail coefficient type"
    for decompose in (pywt.wavedec, pywt.wavedecn):
        foreign_coeffs = decompose(np.ones((8, 8)), 'haar')
        assert_raises_regex(ValueError, expected_msg, pywt.waverec2,
                            foreign_coeffs, 'haar')
def test_waverec2_coeff_shape_mismatch():
    """``pywt.waverec2`` raises ValueError when a detail array is misshapen."""
    coeffs = list(pywt.wavedec2(np.ones((8, 8)), 'db1'))

    # introduce a shape mismatch in the coefficients: make the first detail
    # level mutable and swap its second array for a (16, 1) block of zeros
    detail_level = list(coeffs[1])
    detail_level[1] = np.zeros((16, 1))
    coeffs[1] = detail_level

    assert_raises(ValueError, pywt.waverec2, coeffs, 'db1')
def test_waverec2_odd_length():
    """Round trip of a (10, 6) array through wavedec2/waverec2 ('db1')."""
    image = np.ones((10, 6))
    restored = pywt.waverec2(pywt.wavedec2(image, 'db1'), 'db1')
    assert_allclose(restored, image, rtol=1e-12)
def test_waverec2_none_coeffs():
    """``pywt.waverec2`` keeps the input shape when a detail level is None."""
    image = np.arange(24).reshape(6, 4)
    coeffs = pywt.wavedec2(image, 'db1')
    # blank out every array of the first detail level
    coeffs[1] = (None, None, None)
    restored = pywt.waverec2(coeffs, 'db1')
    assert_(image.shape == restored.shape)
####
# nd multilevel dwt function tests
####
def test_waverecn():
    """wavedecn/waverecn round trip for 1D through 4D arrays of side 4."""
    rng = np.random.RandomState(1234)
    # test 1D through 4D cases
    for ndim in range(1, 5):
        data = rng.randn(*(4, )*ndim)
        coeffs = pywt.wavedecn(data, 'db1')
        # approximation plus two detail levels
        assert_(len(coeffs) == 3)
        restored = pywt.waverecn(coeffs, 'db1')
        assert_allclose(restored, data, rtol=tol_double)
def test_waverecn_empty_coeff():
    """``pywt.waverecn`` accepts empty detail dictionaries.

    The original built a coefficient list with a single ``'daa'`` entry at
    the second level and then overwrote it without ever reconstructing from
    it; that case is now asserted as well.
    """
    coeffs = [np.ones((2, 2, 2)), {}, {}]
    assert_equal(pywt.waverecn(coeffs, 'db1').shape, (8, 8, 8))
    # the same list is reconstructed a second time, as in the original test
    assert_equal(pywt.waverecn(coeffs, 'db1').shape, (8, 8, 8))

    # empty first level, one detail array at the second level
    coeffs = [np.ones((2, 2, 2)), {}, {'daa': np.ones((4, 4, 4))}]
    assert_equal(pywt.waverecn(coeffs, 'db1').shape, (8, 8, 8))

    # two empty levels, one detail array at the third level
    coeffs = [np.ones((2, 2, 2)), {}, {}, {'daa': np.ones((8, 8, 8))}]
    assert_equal(pywt.waverecn(coeffs, 'db1').shape, (16, 16, 16))
def test_waverecn_invalid_coeffs():
    """``pywt.waverecn`` raises ValueError for each malformed coefficient list."""
    cube4 = (4, 4, 4)
    bad_inputs = [
        # approximation coeffs as None and no valid detail coeffs
        [None, {}],
        # use of None for a coefficient value
        [np.ones((2, 2, 2)), {}, {'daa': None}, ],
        # invalid key names in coefficient list
        [np.ones(cube4), {'daa': np.ones(cube4), 'foo': np.ones(cube4)}],
        # mismatched key name lengths
        [np.ones(cube4), {'daa': np.ones(cube4), 'da': np.ones(cube4)}],
        # key name lengths don't match the array dimensions
        [[[[1.0]]], {'ad': [[[0.0]]], 'da': [[[0.0]]], 'dd': [[[0.0]]]}],
    ]
    for coeffs in bad_inputs:
        assert_raises(ValueError, pywt.waverecn, coeffs, 'db1')

    # input list cannot be empty
    assert_raises(ValueError, pywt.waverecn, [], 'haar')
def test_waverecn_invalid_inputs():
    """``pywt.waverecn`` rejects coefficients made by wavedec or wavedec2."""
    # coefficients from a different decomposition used as input
    expected_msg = "Unexpected detail coefficient type"
    for decompose in (pywt.wavedec, pywt.wavedec2):
        foreign_coeffs = decompose(np.ones((8, 8)), 'haar')
        assert_raises_regex(ValueError, expected_msg, pywt.waverecn,
                            foreign_coeffs, 'haar')
def test_waverecn_lists():
    """Coefficients given as nested lists are accepted by ``pywt.waverecn``."""
    # support coefficient arrays specified as lists instead of arrays
    details = {'ad': [[0.0]], 'da': [[0.0]], 'dd': [[0.0]]}
    coeffs = [[[1.0]], details]
    restored = pywt.waverecn(coeffs, 'db1')
    assert_equal(restored.shape, (2, 2))
def test_waverecn_invalid_coeffs2():
    """A 2D detail array paired with a 3D approximation raises ValueError."""
    # shape mismatch should raise an error
    approx = np.ones((4, 4, 4))
    details = {'ada': np.ones((4, 4))}
    assert_raises(ValueError, pywt.waverecn, [approx, details], 'db1')
def test_wavedecn_invalid_inputs():
    """``pywt.wavedecn`` raises ValueError for 0-d input or a negative level."""
    # input array has too few dimensions
    scalar = np.array(0)
    assert_raises(ValueError, pywt.wavedecn, scalar, 'haar')

    # invalid number of levels
    signal = np.ones(16)
    assert_raises(ValueError, pywt.wavedecn, signal, 'haar', level=-1)
def test_wavedecn_many_levels():
    """Round trip with ``level=20`` on 8x8 data for the 1D, 2D and nD APIs."""
    # perfect reconstruction even when level > pywt.dwt_max_level
    data = np.arange(64).reshape(8, 8)
    tol = 1e-12
    transform_pairs = [(pywt.wavedec, pywt.waverec),
                       (pywt.wavedec2, pywt.waverec2),
                       (pywt.wavedecn, pywt.waverecn)]
    with warnings.catch_warnings():
        # UserWarnings emitted during these calls are silenced
        warnings.simplefilter('ignore', UserWarning)
        for decompose, reconstruct in transform_pairs:
            for ext_mode in ('periodization', 'symmetric'):
                coeffs = decompose(data, 'haar', mode=ext_mode, level=20)
                restored = reconstruct(coeffs, 'haar', mode=ext_mode)
                assert_allclose(data, restored, atol=tol, rtol=tol)
def test_waverecn_accuracies():
    """3D wavedecn/waverecn round trip within the per-dtype tolerance."""
    # testing 3D only here
    rng = np.random.RandomState(1234)
    base = rng.randn(4, 4, 4)
    for dtype, tolerance in dtypes_and_tolerances:
        volume = base.astype(dtype)
        if np.iscomplexobj(volume):
            # give complex inputs a random imaginary part of matching precision
            volume += 1j*rng.randn(4, 4, 4).astype(volume.real.dtype)
        # the input is cast once more on the way in, as in the original test
        coeffs = pywt.wavedecn(volume.astype(dtype), 'db1')
        restored = pywt.waverecn(coeffs, 'db1')
        assert_allclose(restored, volume, atol=tolerance, rtol=tolerance)
def test_multilevel_dtypes_nd():
    """Check only the dtype returned by wavedecn and waverecn."""
    haar = pywt.Wavelet('haar')
    for dt_in, dt_out in zip(dtypes_in, dtypes_out):
        errmsg = "wrong dtype returned for {0} input".format(dt_in)
        data = np.ones((8, 8), dtype=dt_in)

        # wavedecn: approximation first, then both detail dictionaries
        cA, coeffsD2, coeffsD1 = pywt.wavedecn(data, haar, level=2)
        assert_(cA.dtype == dt_out, "wavedecn: " + errmsg)
        for band in list(coeffsD1.values()) + list(coeffsD2.values()):
            assert_(band.dtype == dt_out, "wavedecn: " + errmsg)

        # waverecn
        roundtrip = pywt.waverecn([cA, coeffsD2, coeffsD1], haar)
        assert_(roundtrip.dtype == dt_out, "waverecn: " + errmsg)
def test_wavedecn_complex():
    """Round trip of a complex 4x4x4 array through wavedecn/waverecn."""
    volume = np.ones((4, 4, 4)) + 1j
    restored = pywt.waverecn(pywt.wavedecn(volume, 'db1'), 'db1')
    assert_allclose(restored, volume, rtol=1e-12)
def test_waverecn_dtypes():
    """Round trip of an all-ones 4x4x4 array cast to each tested dtype."""
    ones = np.ones((4, 4, 4))
    for dtype, tolerance in dtypes_and_tolerances:
        coeffs = pywt.wavedecn(ones.astype(dtype), 'db1')
        restored = pywt.waverecn(coeffs, 'db1')
        # compared against the uncast array
        assert_allclose(restored, ones, atol=tolerance, rtol=tolerance)
@pytest.mark.slow
def test_waverecn_all_wavelets_modes():
    """nD API round trip on 2D data for every wavelet in ``wavelist`` and mode."""
    rng = np.random.RandomState(1234)
    image = rng.randn(80, 96)
    for wav in wavelist:
        for ext_mode in pywt.Modes.modes:
            coeffs = pywt.wavedecn(image, wav, mode=ext_mode)
            restored = pywt.waverecn(coeffs, wav, mode=ext_mode)
            assert_allclose(restored, image,
                            rtol=tol_single, atol=tol_single)
def test_coeffs_to_array():
    """Basic behaviour and input validation of ``pywt.coeffs_to_array``.

    The "None detail" and "invalid second element" cases originally passed
    the whole ``a_coeffs`` list as the first element instead of the
    approximation array. They now pass the array, so the first element is
    valid and only the second element is malformed.
    """
    approx = np.arange(8).reshape(2, 4)

    # single element list returns the first element
    a_coeffs = [approx, ]
    arr, arr_slices = pywt.coeffs_to_array(a_coeffs)
    assert_allclose(arr, a_coeffs[0])
    assert_allclose(arr, arr[arr_slices[0]])

    # empty coefficient list
    assert_raises(ValueError, pywt.coeffs_to_array, [])

    # invalid second element: array as in wavedec, but not 1D
    assert_raises(ValueError, pywt.coeffs_to_array, [approx, ] * 2)

    # invalid second element: tuple as in wavedec2, but not a 3-tuple
    assert_raises(ValueError, pywt.coeffs_to_array, [approx, (approx, )])

    # coefficients as None is not supported
    assert_raises(ValueError, pywt.coeffs_to_array, [None, ])
    assert_raises(ValueError, pywt.coeffs_to_array,
                  [approx, (None, None, None)])

    # invalid type for second coefficient list element
    assert_raises(ValueError, pywt.coeffs_to_array, [approx, None])

    # use an invalid key name in the coef dictionary
    coeffs = [np.array([0]), dict(d=np.array([0]), c=np.array([0]))]
    assert_raises(ValueError, pywt.coeffs_to_array, coeffs)
def test_wavedecn_coeff_reshape_even():
    """Round trip through coeffs_to_array / array_to_coeffs (side length 28).

    Sequence: decompose -> coeffs_to_array -> array_to_coeffs -> reconstruct,
    run for the wavedec, wavedec2 and wavedecn coefficient formats.
    """
    rng = np.random.RandomState(1234)
    # (output_format name, data dimensionality, decomposition, reconstruction)
    cases = [('wavedec', 1, pywt.wavedec, pywt.waverec),
             ('wavedec2', 2, pywt.wavedec2, pywt.waverec2),
             ('wavedecn', 3, pywt.wavedecn, pywt.waverecn)]
    N = 28
    for fmt, ndim, decompose, reconstruct in cases:
        data = rng.randn(*([N] * ndim))
        for ext_mode in pywt.Modes.modes:
            for name in wavelist:
                wav = pywt.Wavelet(name)
                if pywt.dwt_max_level(np.min(data.shape), wav.dec_len) == 0:
                    # no decomposition level is possible for this wavelet
                    continue

                coeffs = decompose(data, wav, mode=ext_mode)
                packed, packed_slices = pywt.coeffs_to_array(coeffs)
                unpacked = pywt.array_to_coeffs(packed, packed_slices,
                                                output_format=fmt)
                restored = reconstruct(unpacked, wav, mode=ext_mode)

                assert_allclose(data, restored, rtol=1e-4, atol=1e-4)
def test_wavedecn_coeff_reshape_axes_subset():
    """coeffs_to_array round trip when only some axes are transformed.

    Sequence: wavedecn -> coeffs_to_array -> array_to_coeffs -> waverecn on
    16x16x16 data with 'db2' and symmetric extension.
    """
    rng = np.random.RandomState(1234)
    ext_mode = 'symmetric'
    wav = pywt.Wavelet('db2')
    N = 16
    ndim = 3
    axes_choices = [(-1, ), (0, ), (1, ), (0, 1), (1, 2), (0, 2), None]
    for axes in axes_choices:
        data = rng.randn(*([N] * ndim))
        coeffs = pywt.wavedecn(data, wav, mode=ext_mode, axes=axes)
        packed, packed_slices = pywt.coeffs_to_array(coeffs, axes=axes)
        if axes is not None:
            # if axes is not None, it must be provided to coeffs_to_array
            assert_raises(ValueError, pywt.coeffs_to_array, coeffs)

        # mismatched axes size
        for wrong_axes in ((0, 1, 2, 3), ()):
            assert_raises(ValueError, pywt.coeffs_to_array, coeffs,
                          axes=wrong_axes)

        unpacked = pywt.array_to_coeffs(packed, packed_slices)
        restored = pywt.waverecn(unpacked, wav, mode=ext_mode, axes=axes)

        assert_allclose(data, restored, rtol=1e-4, atol=1e-4)
def test_coeffs_to_array_padding():
    """Exercise the ``padding`` argument of ``pywt.coeffs_to_array``."""
    rng = np.random.RandomState(1234)
    data = rng.randn(32, 32)
    ext_mode = 'symmetric'
    coeffs = pywt.wavedecn(data, 'db2', mode=ext_mode)

    # padding=None raises a ValueError when tight packing is not possible
    assert_raises(ValueError, pywt.coeffs_to_array, coeffs, padding=None)

    # set padded values to nan
    packed, _ = pywt.coeffs_to_array(coeffs, padding=np.nan)
    npad = np.sum(np.isnan(packed))
    assert_(npad > 0)

    # pad with zeros: no nan left and as many zeros as there were nans
    packed, _ = pywt.coeffs_to_array(coeffs, padding=0)
    assert_(np.sum(np.isnan(packed)) == 0)
    assert_(np.sum(packed == 0) == npad)

    # Haar case with N as a power of 2 can be tightly packed
    coeffs_haar = pywt.wavedecn(data, 'haar', mode=ext_mode)
    packed, _ = pywt.coeffs_to_array(coeffs_haar, padding=None)
    # shape of the packed array will match in this case, but not in general
    assert_equal(packed.shape, data.shape)
def test_waverecn_coeff_reshape_odd():
    """coeffs_to_array round trip for odd-sized (35, 33) data with 'haar'.

    Sequence: wavedecn -> coeffs_to_array -> array_to_coeffs -> waverecn.
    """
    rng = np.random.RandomState(1234)
    data = rng.randn(35, 33)
    crop = tuple([slice(s) for s in data.shape])
    for ext_mode in pywt.Modes.modes:
        for name in ('haar', ):
            wav = pywt.Wavelet(name)
            if pywt.dwt_max_level(np.min(data.shape), wav.dec_len) == 0:
                continue
            coeffs = pywt.wavedecn(data, wav, mode=ext_mode)
            packed, packed_slices = pywt.coeffs_to_array(coeffs)
            unpacked = pywt.array_to_coeffs(packed, packed_slices)
            restored = pywt.waverecn(unpacked, wav, mode=ext_mode)

            # truncate reconstructed values to original shape
            restored = restored[crop]
            assert_allclose(data, restored, rtol=1e-4, atol=1e-4)
def test_array_to_coeffs_invalid_inputs():
    """``pywt.array_to_coeffs`` raises ValueError for bad slices or format."""
    packed, packed_slices = pywt.coeffs_to_array(
        pywt.wavedecn(np.ones(2), 'haar'))

    # empty list of array slices
    assert_raises(ValueError, pywt.array_to_coeffs, packed, [])

    # invalid format name
    assert_raises(ValueError, pywt.array_to_coeffs, packed, packed_slices,
                  'foo')
def test_wavedecn_coeff_ravel():
    """Round trip through ravel_coeffs / unravel_coeffs (side length 12).

    Sequence: decompose -> ravel_coeffs -> unravel_coeffs -> reconstruct,
    run for the wavedec, wavedec2 and wavedecn coefficient formats.
    """
    rng = np.random.RandomState(1234)
    # (output_format name, data dimensionality, decomposition, reconstruction)
    cases = [('wavedec', 1, pywt.wavedec, pywt.waverec),
             ('wavedec2', 2, pywt.wavedec2, pywt.waverec2),
             ('wavedecn', 3, pywt.wavedecn, pywt.waverecn)]
    N = 12
    for fmt, ndim, decompose, reconstruct in cases:
        data = rng.randn(*([N] * ndim))
        for ext_mode in pywt.Modes.modes:
            for name in wavelist:
                wav = pywt.Wavelet(name)
                if pywt.dwt_max_level(np.min(data.shape), wav.dec_len) == 0:
                    # no decomposition level is possible for this wavelet
                    continue

                coeffs = decompose(data, wav, mode=ext_mode)
                flat, slices, shapes = pywt.ravel_coeffs(coeffs)
                unpacked = pywt.unravel_coeffs(flat, slices, shapes,
                                               output_format=fmt)
                restored = reconstruct(unpacked, wav, mode=ext_mode)

                assert_allclose(data, restored, rtol=1e-4, atol=1e-4)
def test_wavedecn_coeff_ravel_zero_level():
    """ravel_coeffs / unravel_coeffs round trip for ``level=0`` transforms.

    Sequence: decompose -> ravel_coeffs -> unravel_coeffs -> reconstruct,
    run for the wavedec, wavedec2 and wavedecn coefficient formats.
    """
    rng = np.random.RandomState(1234)
    # (output_format name, data dimensionality, decomposition, reconstruction)
    cases = [('wavedec', 1, pywt.wavedec, pywt.waverec),
             ('wavedec2', 2, pywt.wavedec2, pywt.waverec2),
             ('wavedecn', 3, pywt.wavedecn, pywt.waverecn)]
    N = 16
    for fmt, ndim, decompose, reconstruct in cases:
        data = rng.randn(*([N] * ndim))
        for ext_mode in pywt.Modes.modes:
            wav = pywt.Wavelet('db2')
            coeffs = decompose(data, wav, mode=ext_mode, level=0)
            flat, slices, shapes = pywt.ravel_coeffs(coeffs)
            unpacked = pywt.unravel_coeffs(flat, slices, shapes,
                                           output_format=fmt)
            restored = reconstruct(unpacked, wav, mode=ext_mode)

            assert_allclose(data, restored, rtol=1e-4, atol=1e-4)
def test_waverecn_coeff_ravel_odd():
    """ravel_coeffs round trip for odd-sized (35, 33) data with 'haar'.

    Sequence: wavedecn -> ravel_coeffs -> unravel_coeffs -> waverecn.
    """
    rng = np.random.RandomState(1234)
    data = rng.randn(35, 33)
    crop = tuple([slice(s) for s in data.shape])
    for ext_mode in pywt.Modes.modes:
        for name in ('haar', ):
            wav = pywt.Wavelet(name)
            if pywt.dwt_max_level(np.min(data.shape), wav.dec_len) == 0:
                continue
            coeffs = pywt.wavedecn(data, wav, mode=ext_mode)
            flat, slices, shapes = pywt.ravel_coeffs(coeffs)
            unpacked = pywt.unravel_coeffs(flat, slices, shapes)
            restored = pywt.waverecn(unpacked, wav, mode=ext_mode)

            # truncate reconstructed values to original shape
            restored = restored[crop]
            assert_allclose(data, restored, rtol=1e-4, atol=1e-4)
def test_ravel_wavedec2_with_lists():
    """``pywt.ravel_coeffs`` accepts wavedec2 detail levels given as lists."""
    data = np.ones((8, 8))
    haar = pywt.Wavelet('haar')
    coeffs = pywt.wavedec2(data, haar)

    # list [cHn, cVn, cDn] instead of tuple is okay
    coeffs[1:] = [list(level) for level in coeffs[1:]]
    flat, slices, shapes = pywt.ravel_coeffs(coeffs)
    unpacked = pywt.unravel_coeffs(flat, slices, shapes,
                                   output_format='wavedec2')
    restored = pywt.waverec2(unpacked, haar)
    assert_allclose(data, restored, rtol=1e-4, atol=1e-4)

    # wrong length list will cause a ValueError
    # (the last array of every detail level is dropped)
    coeffs[1:] = [list(level[:-1]) for level in coeffs[1:]]
    assert_raises(ValueError, pywt.ravel_coeffs, coeffs)
def test_ravel_invalid_input():
    """``pywt.ravel_coeffs`` raises ValueError when coefficients are None."""
    # wavedec ravel does not support any coefficient arrays being set to None
    coeffs = pywt.wavedec(np.ones(8), 'haar')
    coeffs[1] = None
    assert_raises(ValueError, pywt.ravel_coeffs, coeffs)

    # wavedec2 ravel cannot have None or a tuple/list of None
    coeffs = pywt.wavedec2(np.ones((8, 8)), 'haar')
    for bad_level in ((None, None, None), [None, None, None], None):
        coeffs[1] = bad_level
        assert_raises(ValueError, pywt.ravel_coeffs, coeffs)

    # wavedecn ravel cannot have any dictionary elements as None
    coeffs = pywt.wavedecn(np.ones((8, 8, 8)), 'haar')
    coeffs[1]['ddd'] = None
    assert_raises(ValueError, pywt.ravel_coeffs, coeffs)
def test_unravel_invalid_inputs():
    """``pywt.unravel_coeffs`` raises ValueError for inconsistent arguments."""
    flat, slices, shapes = pywt.ravel_coeffs(
        pywt.wavedecn(np.ones(2), 'haar'))

    # empty list for slices or shapes
    assert_raises(ValueError, pywt.unravel_coeffs, flat, slices, [])
    assert_raises(ValueError, pywt.unravel_coeffs, flat, [], shapes)

    # unequal length for slices/shapes
    assert_raises(ValueError, pywt.unravel_coeffs, flat, slices[:-1],
                  shapes)

    # invalid format name
    assert_raises(ValueError, pywt.unravel_coeffs, flat, slices, shapes,
                  'foo')
def test_wavedecn_shapes_and_size():
    """wavedecn_shapes and wavedecn_size agree with actual wavedecn output."""
    wav = pywt.Wavelet('db2')
    for data_shape in [(33, ), (64, 32), (1, 15, 30)]:
        for axes in [None, 0, -1]:
            for ext_mode in pywt.Modes.modes:
                coeffs = pywt.wavedecn(np.ones(data_shape), wav,
                                       mode=ext_mode, axes=axes)

                # verify that the shapes match the coefficient shapes
                shapes = pywt.wavedecn_shapes(data_shape, wav,
                                              mode=ext_mode, axes=axes)
                assert_equal(coeffs[0].shape, shapes[0])

                # accumulate the total coefficient count while comparing
                # every detail array with its predicted shape
                expected_size = coeffs[0].size
                for level, details in enumerate(coeffs[1:], start=1):
                    for key, band in details.items():
                        expected_size += band.size
                        assert_equal(shapes[level][key], band.shape)

                # size can be determined from either the shapes or coeffs
                assert_equal(pywt.wavedecn_size(shapes), expected_size)
                assert_equal(pywt.wavedecn_size(coeffs), expected_size)
def test_dwtn_max_level():
    """``pywt.dwtn_max_level`` equals the number of levels wavedecn returns."""
    # predicted and empirical dwtn_max_level match
    shapes_under_test = [(33, ), (64, 32), (1, 15, 30)]
    for wav in [pywt.Wavelet('db2'), 'sym8']:
        for data_shape in shapes_under_test:
            for axes in [None, 0, -1]:
                for ext_mode in pywt.Modes.modes:
                    coeffs = pywt.wavedecn(np.ones(data_shape), wav,
                                           mode=ext_mode, axes=axes)
                    predicted = pywt.dwtn_max_level(data_shape, wav, axes)
                    # coeffs[0] is the approximation; the rest are levels
                    assert_equal(len(coeffs[1:]), predicted)
def test_waverec_axes_subsets():
    """wavedec/waverec round trip along each single axis of 3D data."""
    rng = np.random.RandomState(0)
    data = rng.standard_normal((8, 8, 8))
    # test all combinations of 1 out of 3 axes transformed
    for axis in range(3):
        coeffs = pywt.wavedec(data, 'haar', axis=axis)
        restored = pywt.waverec(coeffs, 'haar', axis=axis)
        assert_allclose(restored, data, atol=1e-14)
def test_waverec_axis_db2():
    """wavedec/waverec round trip with 'db2' along each axis of 2D data."""
    # test for fix to issue gh-293
    rng = np.random.RandomState(0)
    data = rng.standard_normal((16, 16))
    for axis in range(2):
        coeffs = pywt.wavedec(data, 'db2', axis=axis)
        restored = pywt.waverec(coeffs, 'db2', axis=axis)
        assert_allclose(restored, data, atol=1e-14)
def test_waverec2_axes_subsets():
    """wavedec2/waverec2 round trip over each pair of axes of 3D data."""
    rng = np.random.RandomState(0)
    data = rng.standard_normal((8, 8, 8))
    # test all combinations of 2 out of 3 axes transformed
    for axis_pair in combinations((0, 1, 2), 2):
        coeffs = pywt.wavedec2(data, 'haar', axes=axis_pair)
        restored = pywt.waverec2(coeffs, 'haar', axes=axis_pair)
        assert_allclose(restored, data, atol=1e-14)
def test_waverecn_axes_subsets():
    """wavedecn/waverecn round trip over each axis triple of 4D data."""
    rng = np.random.RandomState(0)
    data = rng.standard_normal((8, 8, 8, 8))
    # test all combinations of 3 out of 4 axes transformed
    for axis_triple in combinations((0, 1, 2, 3), 3):
        coeffs = pywt.wavedecn(data, 'haar', axes=axis_triple)
        restored = pywt.waverecn(coeffs, 'haar', axes=axis_triple)
        assert_allclose(restored, data, atol=1e-14)
def test_waverecn_int_axis():
    """wavedecn/waverecn round trip with ``axes`` given as a plain integer."""
    # waverecn should also work for axes as an integer
    rng = np.random.RandomState(0)
    data = rng.standard_normal((8, 8))
    for axis in range(2):
        coeffs = pywt.wavedecn(data, 'haar', axes=axis)
        restored = pywt.waverecn(coeffs, 'haar', axes=axis)
        assert_allclose(restored, data, atol=1e-14)
def test_wavedec_axis_error():
    """``pywt.wavedec`` raises ValueError for ``axis=1`` on 1D data."""
    signal = np.ones(4)
    # out of range axis not allowed
    assert_raises(ValueError, pywt.wavedec, signal, 'haar', axis=1)
def test_waverec_axis_error():
    """``pywt.waverec`` raises ValueError for ``axis=1`` on 1D coefficients."""
    coeffs = pywt.wavedec(np.ones(4), 'haar')
    # out of range axis not allowed
    assert_raises(ValueError, pywt.waverec, coeffs, 'haar', axis=1)
def test_waverec_shape_mismatch_error():
    """``pywt.waverec`` raises ValueError for a truncated detail array.

    The original call also passed ``axis=1``, which is out of range for
    these 1D coefficients, so the truncated array was not the only invalid
    part of the call. The default axis is used here so the truncation is
    what is under test.
    """
    c = pywt.wavedec(np.ones(16), 'haar')

    # truncate a detail coefficient to an incorrect shape
    c[3] = c[3][:-1]

    assert_raises(ValueError, pywt.waverec, c, 'haar')
def test_wavedec2_axes_errors():
    """``pywt.wavedec2`` rejects integer, repeated and out-of-range axes."""
    image = np.ones((4, 4))

    # integer axes not allowed
    assert_raises(TypeError, pywt.wavedec2, image, 'haar', axes=1)

    # non-unique axes, then an out of range axis: neither is allowed
    for bad_axes in ((0, 0), (0, 2)):
        assert_raises(ValueError, pywt.wavedec2, image, 'haar',
                      axes=bad_axes)
def test_waverec2_axes_errors():
    """``pywt.waverec2`` rejects integer, repeated and out-of-range axes."""
    coeffs = pywt.wavedec2(np.ones((4, 4)), 'haar')

    # integer axes not allowed
    assert_raises(TypeError, pywt.waverec2, coeffs, 'haar', axes=1)

    # non-unique axes, then an out of range axis: neither is allowed
    for bad_axes in ((0, 0), (0, 2)):
        assert_raises(ValueError, pywt.waverec2, coeffs, 'haar',
                      axes=bad_axes)
def test_wavedecn_axes_errors():
    """``pywt.wavedecn`` rejects repeated and out-of-range axes."""
    volume = np.ones((8, 8, 8))
    # repeated axes, then an out of range axis: neither is allowed
    for bad_axes in ((1, 1), (0, 1, 3)):
        assert_raises(ValueError, pywt.wavedecn, volume, 'haar',
                      axes=bad_axes)
def test_waverecn_axes_errors():
    """``pywt.waverecn`` rejects repeated and out-of-range axes."""
    coeffs = pywt.wavedecn(np.ones((8, 8, 8)), 'haar')
    # repeated axes, then an out of range axis: neither is allowed
    for bad_axes in ((1, 1), (0, 1, 3)):
        assert_raises(ValueError, pywt.waverecn, coeffs, 'haar',
                      axes=bad_axes)
def test_per_axis_wavelets_and_modes():
    """Separate wavelet and edge mode per axis for wavedecn and wavedec2."""
    rng = np.random.RandomState(1234)
    data = rng.randn(24, 24, 16)

    # wavelet can be a string or wavelet object
    wavelets = (pywt.Wavelet('haar'), 'sym2', 'db2')

    # The default number of levels should be the minimum over this list
    max_levels = [pywt._dwt.dwt_max_level(length, wav)
                  for length, wav in zip(data.shape, wavelets)]

    # mode can be a string or a Modes enum
    modes = ('symmetric', 'periodization',
             pywt._extensions._pywt.Modes.reflect)

    # one wavelet and one mode per axis
    coefs = pywt.wavedecn(data, wavelets, modes)
    assert_allclose(pywt.waverecn(coefs, wavelets, modes), data, atol=1e-14)
    assert_equal(min(max_levels), len(coefs[1:]))

    # a length-1 wavelet tuple together with per-axis modes
    single_wavelet = wavelets[:1]
    coefs = pywt.wavedecn(data, single_wavelet, modes)
    assert_allclose(pywt.waverecn(coefs, single_wavelet, modes), data,
                    atol=1e-14)

    # per-axis wavelets together with a length-1 mode tuple
    single_mode = modes[:1]
    coefs = pywt.wavedecn(data, wavelets, single_mode)
    assert_allclose(pywt.waverecn(coefs, wavelets, single_mode), data,
                    atol=1e-14)

    # length of wavelets or modes doesn't match the length of axes
    assert_raises(ValueError, pywt.wavedecn, data, wavelets[:2])
    assert_raises(ValueError, pywt.wavedecn, data, wavelets, mode=modes[:2])
    assert_raises(ValueError, pywt.waverecn, coefs, wavelets[:2])
    assert_raises(ValueError, pywt.waverecn, coefs, wavelets, mode=modes[:2])

    # wavedec2/waverec2 also support per-axis wavelets/modes
    data2 = data[..., 0]
    coefs2 = pywt.wavedec2(data2, wavelets[:2], modes[:2])
    assert_allclose(pywt.waverec2(coefs2, wavelets[:2], modes[:2]), data2,
                    atol=1e-14)
    assert_equal(min(max_levels[:2]), len(coefs2[1:]))
# Tests for fully separable multi-level transforms
def test_fswavedecn_fswaverecn_roundtrip():
    """fswavedecn/fswaverecn round trip for 1D-4D data and all tested dtypes."""
    # verify proper round trip result for 1D through 4D data
    # same DWT as wavedecn/waverecn so don't need to test all modes/wavelets
    rng = np.random.RandomState(0)
    low_precision = [np.float32, np.float16]
    for ndim in range(1, 5):
        for dt_in, dt_out in zip(dtypes_in, dtypes_out):
            for levels in (1, None):
                data = rng.standard_normal((8, )*ndim).astype(dt_in)
                result = pywt.fswavedecn(data, 'haar', levels=levels)
                restored = pywt.fswaverecn(result)

                # looser tolerance when the real part is float16/float32
                if data.real.dtype in low_precision:
                    tol = 1e-6
                else:
                    tol = 1e-14
                assert_allclose(restored, data, rtol=tol, atol=tol)

                assert_(result.coeffs.dtype == dt_out)
                assert_(restored.dtype == dt_out)
def test_fswavedecn_fswaverecn_zero_levels():
    """With ``levels=0`` the coefficients equal the input, as does the inverse."""
    # zero level transform gives coefs matching the original data
    rng = np.random.RandomState(0)
    ndim = 2
    data = rng.standard_normal((8, )*ndim)

    result = pywt.fswavedecn(data, 'haar', levels=0)
    assert_array_equal(result.coeffs, data)

    restored = pywt.fswaverecn(result)
    assert_array_equal(result.coeffs, restored)
def test_fswavedecn_fswaverecn_variable_levels():
    """fswavedecn with a different number of levels on each axis."""
    # test with differing number of transform levels per axis
    rng = np.random.RandomState(0)
    ndim = 3
    data = rng.standard_normal((16, )*ndim)
    result = pywt.fswavedecn(data, 'haar', levels=(1, 2, 3))
    restored = pywt.fswaverecn(result)
    assert_allclose(restored, data, atol=1e-14)

    # levels doesn't match number of axes
    for bad_levels in ((1, 1), (1, 1, 1, 1)):
        assert_raises(ValueError, pywt.fswavedecn, data, 'haar',
                      levels=bad_levels)

    # levels too large for array size
    too_many = int(np.log2(np.min(data.shape)))+1
    assert_warns(UserWarning, pywt.fswavedecn, data, 'haar',
                 levels=too_many)
def test_fswavedecn_fswaverecn_variable_wavelets_and_modes():
    """fswavedecn with a different wavelet and mode on each axis."""
    rng = np.random.RandomState(0)
    ndim = 3
    data = rng.standard_normal((16, )*ndim)
    wavelets = ('haar', 'db2', 'sym3')
    modes = ('periodic', 'symmetric', 'periodization')
    result = pywt.fswavedecn(data, wavelet=wavelets, mode=modes)

    for ax, wav in enumerate(wavelets):
        # expect approx + dwt_max_level detail coeffs along each axis
        expected_bands = pywt.dwt_max_level(data.shape[ax], wav)+1
        assert_equal(len(result.coeff_slices[ax]), expected_bands)

    restored = pywt.fswaverecn(result)
    assert_allclose(restored, data, atol=1e-14)

    # number of wavelets doesn't match number of axes
    assert_raises(ValueError, pywt.fswavedecn, data, wavelets[:2])

    # number of modes doesn't match number of axes
    assert_raises(ValueError, pywt.fswavedecn, data, wavelets[0],
                  mode=modes[:2])
def test_fswavedecn_fswaverecn_axes_subsets():
    """Fully separable DWT over only a subset of axes"""
    rng = np.random.RandomState(0)
    # use anisotropic data to result in unique number of levels per axis
    data = rng.standard_normal((4, 8, 16, 32))

    # test all combinations of 3 out of 4 axes transformed
    for axis_triple in combinations((0, 1, 2, 3), 3):
        result = pywt.fswavedecn(data, 'haar', axes=axis_triple)
        restored = pywt.fswaverecn(result)
        assert_allclose(restored, data, atol=1e-14)

    # some axes exceed data dimensions
    assert_raises(ValueError, pywt.fswavedecn, data, 'haar', axes=(1, 5))
def test_fswavedecnresult():
    """Item access and assignment on the object returned by fswavedecn."""
    data = np.ones((32, 32))
    levels = (1, 2)
    result = pywt.fswavedecn(data, 'sym2', levels=levels)

    # can access the lowpass band via .approx or via __getitem__
    approx_key = (0, ) * data.ndim
    assert_array_equal(result[approx_key], result.approx)

    # the approximation key shouldn't be present in the detail_keys
    detail_keys = result.detail_keys()
    assert_(approx_key not in detail_keys)

    # can access all detail coefficients and they have matching ndim
    for key in detail_keys:
        band = result[key]
        assert_equal(band.ndim, data.ndim)

        # can assign modified coefficients
        result[key] = np.zeros_like(band)

        # assigning a differently sized array raises a ValueError
        oversized = np.zeros(tuple([s + 1 for s in band.shape]))
        assert_raises(ValueError, result.__setitem__, key, oversized)

        # warns on assigning with a non-matching dtype
        single_precision = np.zeros_like(band).astype(np.float32)
        assert_warns(UserWarning, result.__setitem__, key, single_precision)

    # all coefficients are stacked into result.coeffs (same ndim)
    assert_equal(result.coeffs.ndim, data.ndim)
def test_error_on_continuous_wavelet():
    """Multilevel transforms raise ValueError for the 'morl' wavelet."""
    # A ValueError is raised if a Continuous wavelet is selected
    data = np.ones((16, 16))
    transform_pairs = zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
                          [pywt.waverec, pywt.waverec2, pywt.waverecn])
    for decompose, reconstruct in transform_pairs:
        # the wavelet is given once by name and once as an object
        for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
            assert_raises(ValueError, decompose, data, wavelet=cwave)

            # valid 'db1' coefficients, but reconstruction with 'morl'
            valid_coeffs = decompose(data, 'db1')
            assert_raises(ValueError, reconstruct, valid_coeffs,
                          wavelet=cwave)
def test_default_level():
    """Default level equals the smallest ``dwt_max_level`` over the axes used."""
    # default level is the maximum permissible for the transformed axes
    data = np.ones((128, 32, 4))
    wavelet = ('db8', 'db1')

    # two transformed axes, one wavelet per axis
    for decompose in [pywt.wavedec2, pywt.wavedecn]:
        for axes in [(0, 1), (2, 1), (0, 2)]:
            coeffs = decompose(data, wavelet, axes=axes)
            per_axis = [pywt.dwt_max_level(data.shape[ax], wav)
                        for ax, wav in zip(axes, wavelet)]
            assert_equal(len(coeffs[1:]), np.min(per_axis))

    # a single transformed axis
    for ax in [0, 1]:
        coeffs = pywt.wavedecn(data, wavelet[ax], axes=(ax, ))
        expected = pywt.dwt_max_level(data.shape[ax], wavelet[ax])
        assert_equal(len(coeffs[1:]), expected)
def test_waverec_mixed_precision():
    """Reconstruction from coefficients whose precisions differ.

    Run for the 1D, 2D and nD transform pairs; the result is expected to be
    double precision in every case.
    """
    rng = np.random.RandomState(0)
    cases = [(pywt.wavedec, pywt.waverec, (8, )),
             (pywt.wavedec2, pywt.waverec2, (8, 8)),
             (pywt.wavedecn, pywt.waverecn, (8, 8, 8))]
    for decompose, reconstruct, shape in cases:
        x = rng.randn(*shape)

        # real: single precision approx, double precision details
        coeffs_real = decompose(x, 'db1')
        coeffs_real[0] = coeffs_real[0].astype(np.float32)
        r = reconstruct(coeffs_real, 'db1')
        assert_allclose(r, x, rtol=1e-7, atol=1e-7)
        assert_equal(r.dtype, np.float64)

        # complex: single precision approx, double precision details
        x = x + 1j*x
        coeffs = decompose(x, 'db1')
        coeffs[0] = coeffs[0].astype(np.complex64)
        r = reconstruct(coeffs, 'db1')
        assert_allclose(r, x, rtol=1e-7, atol=1e-7)
        assert_equal(r.dtype, np.complex128)

        # complex: double precision approx, single precision details.
        # The container of the first detail level depends on the transform:
        # array (1D), tuple of arrays (2D) or dict of arrays (nD).
        if x.ndim == 1:
            coeffs[0] = coeffs[0].astype(np.complex128)
            coeffs[1] = coeffs[1].astype(np.complex64)
        elif x.ndim == 2:
            coeffs[0] = coeffs[0].astype(np.complex128)
            coeffs[1] = tuple([v.astype(np.complex64) for v in coeffs[1]])
        elif x.ndim == 3:
            coeffs[0] = coeffs[0].astype(np.complex128)
            coeffs[1] = {k: v.astype(np.complex64)
                         for k, v in coeffs[1].items()}
        r = reconstruct(coeffs, 'db1')
        assert_allclose(r, x, rtol=1e-7, atol=1e-7)
        assert_equal(r.dtype, np.complex128)