forked from 170010011/fr
444 lines
15 KiB
Python
444 lines
15 KiB
Python
|
#!/usr/bin/env python
|
||
|
|
||
|
from __future__ import division, print_function, absolute_import
|
||
|
|
||
|
import numpy as np
|
||
|
from itertools import combinations
|
||
|
from numpy.testing import assert_allclose, assert_, assert_raises, assert_equal
|
||
|
|
||
|
import pywt
|
||
|
# Expected dtype behaviour of the transforms, expressed as two parallel lists
# (input dtype at index i maps to output dtype at index i):
#   * float32, float64, complex64 and complex128 are preserved.
#   * float16 is promoted to float32.
#   * Other real types (here: int8) are converted to float64.
#   * complex256, where available, is converted to complex128.
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
             np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
              np.complex128]

# test complex256 as well if it is available (np.complex256 is not defined
# on every platform; the attribute lookup fails before either list is
# modified, so the two lists always stay the same length)
try:
    dtypes_in += [np.complex256, ]
    dtypes_out += [np.complex128, ]
except AttributeError:
    pass
|
||
|
|
||
|
|
||
|
def test_dwtn_input():
    """dwtn accepts array-likes and rejects other objects and scalars."""
    # a plain Python list is array-like and must be accepted
    pywt.dwtn([1, 2, 3, 4], 'haar')

    # a mapping is not array-like and must be rejected
    assert_raises(TypeError, pywt.dwtn, dict(), 'haar')

    # the input must be at least 1D, so a bare scalar is an error
    assert_raises(ValueError, pywt.dwtn, 2, 'haar')
|
||
|
|
||
|
|
||
|
def test_3D_reconstruct():
    """Round-trip a 3D integer array through dwtn/idwtn in every mode."""
    cube = np.array([
        [[0, 4, 1, 5, 1, 4],
         [0, 5, 26, 3, 2, 1],
         [5, 8, 2, 33, 4, 9],
         [2, 5, 19, 4, 19, 1]],
        [[1, 5, 1, 2, 3, 4],
         [7, 12, 6, 52, 7, 8],
         [2, 12, 3, 52, 6, 8],
         [5, 2, 6, 78, 12, 2]]])
    haar = pywt.Wavelet('haar')

    for mode in pywt.Modes.modes:
        coeffs = pywt.dwtn(cube, haar, mode=mode)
        reconstructed = pywt.idwtn(coeffs, haar, mode=mode)
        assert_allclose(cube, reconstructed, rtol=1e-13, atol=1e-13)
|
||
|
|
||
|
|
||
|
def test_dwdtn_idwtn_allwavelets():
    """Round-trip 2D random data through dwtn/idwtn for every discrete
    wavelet in pywt.wavelist() and every extension mode."""
    r = np.random.RandomState(1234).randn(16, 16)

    # only the 2D case is tested, for all wavelet types.
    # 'dmey' is excluded; 'cmor', 'shan' and 'fbsp' are CWT families that
    # are skipped to avoid warnings.
    excluded = ('dmey', 'cmor', 'shan', 'fbsp')
    names = [name for name in pywt.wavelist() if name not in excluded]

    for name in names:
        if not isinstance(pywt.DiscreteContinuousWavelet(name), pywt.Wavelet):
            # only discrete wavelets can be used with dwtn/idwtn
            continue
        for mode in pywt.Modes.modes:
            coeffs = pywt.dwtn(r, name, mode=mode)
            assert_allclose(pywt.idwtn(coeffs, name, mode=mode),
                            r, rtol=1e-7, atol=1e-7)
|
||
|
|
||
|
|
||
|
def test_stride():
    """dwtn on a non-contiguous, negatively strided view must give the same
    coefficients as on an equivalent contiguous array, in every mode."""
    wavelet = pywt.Wavelet('haar')

    for dtype in ('float32', 'float64'):
        data = np.array([[0, 4, 1, 5, 1, 4],
                         [0, 5, 6, 3, 2, 1],
                         [2, 5, 19, 4, 19, 1]],
                        dtype=dtype)

        for mode in pywt.Modes.modes:
            # `mode` must be passed explicitly; otherwise every iteration
            # would only exercise the default mode.
            expected = pywt.dwtn(data, wavelet, mode=mode)

            # Embed `data` in a wider buffer so that the view below has a
            # reversed axis 0 and a step of 2 along axis 1.
            strided = np.ones((3, 12), dtype=data.dtype)
            strided[::-1, ::2] = data
            strided_dwtn = pywt.dwtn(strided[::-1, ::2], wavelet, mode=mode)

            for key in expected:
                assert_allclose(strided_dwtn[key], expected[key])
|
||
|
|
||
|
|
||
|
def test_byte_offset():
    """dwtn on one field of a structured array (elements separated by pad
    bytes) must match the result on a plain contiguous array, in every
    mode."""
    wavelet = pywt.Wavelet('haar')
    for dtype in ('float32', 'float64'):
        data = np.array([[0, 4, 1, 5, 1, 4],
                         [0, 5, 6, 3, 2, 1],
                         [2, 5, 19, 4, 19, 1]],
                        dtype=dtype)

        for mode in pywt.Modes.modes:
            # `mode` must be passed explicitly; otherwise every iteration
            # would only exercise the default mode.
            expected = pywt.dwtn(data, wavelet, mode=mode)

            # A 'pad' field placed after each 'data' value makes
            # padded['data'] a view whose stride exceeds its itemsize.
            padded_dtype = np.dtype({'data': (data.dtype, 0),
                                     'pad': ('byte', data.dtype.itemsize)},
                                    align=True)
            padded = np.ones((3, 6), dtype=padded_dtype)
            padded[:] = data
            padded_dwtn = pywt.dwtn(padded['data'], wavelet, mode=mode)

            for key in expected:
                assert_allclose(padded_dwtn[key], expected[key])
|
||
|
|
||
|
|
||
|
def test_3D_reconstruct_complex():
    """Round-trip complex 3D data through dwtn/idwtn using the default mode."""
    # Every dimension has even length.
    real_part = np.array([
        [[0, 4, 1, 5, 1, 4],
         [0, 5, 26, 3, 2, 1],
         [5, 8, 2, 33, 4, 9],
         [2, 5, 19, 4, 19, 1]],
        [[1, 5, 1, 2, 3, 4],
         [7, 12, 6, 52, 7, 8],
         [2, 12, 3, 52, 6, 8],
         [5, 2, 6, 78, 12, 2]]])
    # a constant imaginary part makes the input complex
    data = real_part + 1j

    haar = pywt.Wavelet('haar')
    coeffs = pywt.dwtn(data, haar)

    # idwtn produces even-length shapes (2x the dwtn size), so crop the
    # reconstruction back to the input shape before comparing
    crop = tuple(slice(None, n) for n in data.shape)
    assert_allclose(data, pywt.idwtn(coeffs, haar)[crop],
                    rtol=1e-13, atol=1e-13)
|
||
|
|
||
|
|
||
|
def test_idwtn_idwt2():
    """idwt2 and idwtn must agree on the same 2D coefficients in every mode."""
    data = np.array([
        [0, 4, 1, 5, 1, 4],
        [0, 5, 6, 3, 2, 1],
        [2, 5, 19, 4, 19, 1]])
    wavelet = pywt.Wavelet('haar')

    cA, (cH, cV, cD) = pywt.dwt2(data, wavelet)
    # the same coefficients, keyed the way idwtn expects them
    as_dict = dict(aa=cA, da=cH, ad=cV, dd=cD)

    for mode in pywt.Modes.modes:
        via_2d = pywt.idwt2((cA, (cH, cV, cD)), wavelet, mode=mode)
        via_nd = pywt.idwtn(as_dict, wavelet, mode=mode)
        assert_allclose(via_2d, via_nd, rtol=1e-14, atol=1e-14)
|
||
|
|
||
|
|
||
|
def test_idwtn_idwt2_complex():
    """idwt2 and idwtn must agree on complex 2D coefficients in every mode."""
    real_part = np.array([
        [0, 4, 1, 5, 1, 4],
        [0, 5, 6, 3, 2, 1],
        [2, 5, 19, 4, 19, 1]])
    data = real_part + 1j
    wavelet = pywt.Wavelet('haar')

    cA, (cH, cV, cD) = pywt.dwt2(data, wavelet)
    # the same coefficients, keyed the way idwtn expects them
    as_dict = dict(aa=cA, da=cH, ad=cV, dd=cD)

    for mode in pywt.Modes.modes:
        via_2d = pywt.idwt2((cA, (cH, cV, cD)), wavelet, mode=mode)
        via_nd = pywt.idwtn(as_dict, wavelet, mode=mode)
        assert_allclose(via_2d, via_nd, rtol=1e-14, atol=1e-14)
|
||
|
|
||
|
|
||
|
def test_idwtn_missing():
    """Keys missing from idwtn's input behave like zero coefficients, which
    must match passing None for those arrays to idwt2."""
    data = np.array([
        [0, 4, 1, 5, 1, 4],
        [0, 5, 6, 3, 2, 1],
        [2, 5, 19, 4, 19, 1]])
    wavelet = pywt.Wavelet('haar')
    coefs = pywt.dwtn(data, wavelet)

    # removing no keys or all keys is not useful here, so drop 1..n-1 keys
    for n_drop in range(1, len(coefs)):
        for dropped in combinations(coefs.keys(), n_drop):
            kept = {key: val for key, val in coefs.items()
                    if key not in dropped}
            cA, cH, cV, cD = (kept.get(key, None)
                              for key in ('aa', 'da', 'ad', 'dd'))
            assert_allclose(pywt.idwt2((cA, (cH, cV, cD)), wavelet),
                            pywt.idwtn(kept, 'haar'), atol=1e-15)
|
||
|
|
||
|
|
||
|
def test_idwtn_all_coeffs_None():
    """idwtn raises ValueError when every coefficient array is None."""
    coefs = {key: None for key in ('aa', 'da', 'ad', 'dd')}
    assert_raises(ValueError, pywt.idwtn, coefs, 'haar')
|
||
|
|
||
|
|
||
|
def test_error_on_invalid_keys():
    """idwtn rejects unknown keys and keys of inconsistent length."""
    data = np.array([
        [0, 4, 1, 5, 1, 4],
        [0, 5, 6, 3, 2, 1],
        [2, 5, 19, 4, 19, 1]])
    wavelet = pywt.Wavelet('haar')
    LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)

    invalid_inputs = [
        # 'ff' is not an expected key
        {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH},
        # 'a' is shorter than the other keys
        {'a': LL, 'da': HL, 'ad': LH, 'dd': HH},
    ]
    for d in invalid_inputs:
        assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||
|
|
||
|
|
||
|
def test_error_mismatched_size():
    """idwtn raises ValueError when the coefficient arrays differ in size."""
    data = np.array([
        [0, 4, 1, 5, 1, 4],
        [0, 5, 6, 3, 2, 1],
        [2, 5, 19, 4, 19, 1]])
    wavelet = pywt.Wavelet('haar')
    LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)

    # Whether the mismatch is detected depends on the first coefficient
    # examined being shorter than the remaining ones, so three of the four
    # arrays are truncated by one column to maximise the chance of that.
    # The dict is rebuilt inside idwtn, so ordering this one would not help;
    # on interpreters without ordered dicts the error may not trigger on
    # every run.
    d = {'aa': LL[:, :-1], 'da': HL, 'ad': LH[:, :-1], 'dd': HH[:, :-1]}
    assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||
|
|
||
|
|
||
|
def test_dwt2_idwt2_dtypes():
    """dwt2 and idwt2 must map each input dtype in the module-level
    `dtypes_in` table to the matching entry of `dtypes_out`."""
    wavelet = pywt.Wavelet('haar')
    for dt_in, dt_out in zip(dtypes_in, dtypes_out):
        x = np.ones((4, 4), dtype=dt_in)
        errmsg = "wrong dtype returned for {0} input".format(dt_in)

        cA, (cH, cV, cD) = pywt.dwt2(x, wavelet)
        assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype,
                "dwt2: " + errmsg)
        # All four coefficients share one dtype; it must also be the
        # expected output dtype, as test_dwtn_idwtn_dtypes checks for dwtn.
        assert_(cA.dtype == dt_out, "dwt2: " + errmsg)

        x_roundtrip = pywt.idwt2((cA, (cH, cV, cD)), wavelet)
        assert_(x_roundtrip.dtype == dt_out, "idwt2: " + errmsg)
|
||
|
|
||
|
|
||
|
def test_dwtn_axes():
    """Transforming along explicit axes matches row-wise 1D dwt calls."""
    real_part = np.array([[0, 1, 2, 3],
                          [1, 1, 1, 1],
                          [1, 4, 2, 8]])
    data = real_part + 1j*real_part  # test with complex data

    def approx_rows(rows):
        # approximation coefficients of a 1D haar dwt of each row
        return [pywt.dwt(row, 'haar')[0] for row in rows]

    def detail_rows(rows):
        # detail coefficients of a 1D haar dwt of each row
        return [pywt.dwt(row, 'haar')[1] for row in rows]

    # one level along axis 1
    coefs = pywt.dwtn(data, 'haar', axes=(1,))
    expected_a = approx_rows(data)
    assert_equal(coefs['a'], expected_a)
    assert_equal(coefs['d'], detail_rows(data))

    # axis 1 listed twice: the second pass acts on the first pass's output
    coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
    assert_equal(coefs['aa'], approx_rows(expected_a))
    assert_equal(coefs['ad'], detail_rows(expected_a))
|
||
|
|
||
|
|
||
|
def test_idwtn_axes():
    """idwtn inverts dwtn when given the same (repeated) axes."""
    real_part = np.array([[0, 1, 2, 3],
                          [1, 1, 1, 1],
                          [1, 4, 2, 8]])
    data = real_part + 1j*real_part  # test with complex data
    axes = (1, 1)

    coefs = pywt.dwtn(data, 'haar', axes=axes)
    reconstructed = pywt.idwtn(coefs, 'haar', axes=axes)
    assert_allclose(reconstructed, data, atol=1e-14)
|
||
|
|
||
|
|
||
|
def test_idwt2_none_coeffs():
    """Passing None for a coefficient to idwt2 equals passing zeros."""
    real_part = np.array([[0, 1, 2, 3],
                          [1, 1, 1, 1],
                          [1, 4, 2, 8]])
    data = real_part + 1j*real_part  # test with complex data
    axes = (1, 1)
    cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=axes)

    zeroed = pywt.idwt2((cA, (cH, cV, np.zeros_like(cD))), 'haar', axes=axes)
    omitted = pywt.idwt2((cA, (cH, cV, None)), 'haar', axes=axes)

    assert_equal(zeroed, omitted)
|
||
|
|
||
|
|
||
|
def test_idwtn_none_coeffs():
    """Setting a coefficient to None for idwtn equals setting it to zeros."""
    real_part = np.array([[0, 1, 2, 3],
                          [1, 1, 1, 1],
                          [1, 4, 2, 8]])
    data = real_part + 1j*real_part  # test with complex data
    axes = (1, 1)
    coefs = pywt.dwtn(data, 'haar', axes=axes)

    # reconstruct once with 'dd' zeroed, then once with 'dd' set to None
    results = []
    for replacement in (np.zeros_like(coefs['dd']), None):
        coefs['dd'] = replacement
        results.append(pywt.idwtn(coefs, 'haar', axes=axes))

    result_zeros, result_none = results
    assert_equal(result_zeros, result_none)
|
||
|
|
||
|
|
||
|
def test_idwt2_axes():
    """idwt2 inverts dwt2 along repeated axes and rejects three axes."""
    data = np.array([[0, 1, 2, 3],
                     [1, 1, 1, 1],
                     [1, 4, 2, 8]])
    axes = (1, 1)
    coefs = pywt.dwt2(data, 'haar', axes=axes)
    assert_allclose(pywt.idwt2(coefs, 'haar', axes=axes), data, atol=1e-14)

    # three axes are too many for idwt2
    assert_raises(ValueError, pywt.idwt2, coefs, 'haar', axes=(0, 1, 1))
|
||
|
|
||
|
|
||
|
def test_idwt2_axes_subsets():
    """dwt2/idwt2 round-trip over every pair of axes of a 3D array."""
    # Use a seeded generator, like the other random tests in this module,
    # so any failure is reproducible.
    rstate = np.random.RandomState(1234)
    data = rstate.standard_normal((4, 4, 4))
    # test all combinations of 2 out of 3 axes transformed
    for axes in combinations((0, 1, 2), 2):
        coefs = pywt.dwt2(data, 'haar', axes=axes)
        assert_allclose(pywt.idwt2(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||
|
|
||
|
|
||
|
def test_idwtn_axes_subsets():
    """dwtn/idwtn round-trip over every triple of axes of a 4D array."""
    # Use a seeded generator, like the other random tests in this module,
    # so any failure is reproducible.
    rstate = np.random.RandomState(1234)
    data = rstate.standard_normal((4, 4, 4, 4))
    # test all combinations of 3 out of 4 axes transformed
    for axes in combinations((0, 1, 2, 3), 3):
        coefs = pywt.dwtn(data, 'haar', axes=axes)
        assert_allclose(pywt.idwtn(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||
|
|
||
|
|
||
|
def test_negative_axes():
    """Negative axis indices behave like their positive equivalents."""
    data = np.array([[0, 1, 2, 3],
                     [1, 1, 1, 1],
                     [1, 4, 2, 8]])
    positive, negative = (1, 1), (-1, -1)

    coefs_pos = pywt.dwtn(data, 'haar', axes=positive)
    coefs_neg = pywt.dwtn(data, 'haar', axes=negative)
    assert_equal(coefs_pos, coefs_neg)

    # both reconstructions start from the positive-axes coefficients
    rec_pos = pywt.idwtn(coefs_pos, 'haar', axes=positive)
    rec_neg = pywt.idwtn(coefs_pos, 'haar', axes=negative)
    assert_equal(rec_pos, rec_neg)
|
||
|
|
||
|
|
||
|
def test_dwtn_idwtn_dtypes():
    """dwtn and idwtn must map each input dtype in the module-level
    `dtypes_in` table to the matching entry of `dtypes_out`."""
    wavelet = pywt.Wavelet('haar')
    for dt_in, dt_out in zip(dtypes_in, dtypes_out):
        ones = np.ones((4, 4), dtype=dt_in)
        errmsg = "wrong dtype returned for {0} input".format(dt_in)

        coeffs = pywt.dwtn(ones, wavelet)
        for arr in coeffs.values():
            assert_(arr.dtype == dt_out, "dwtn: " + errmsg)

        roundtrip = pywt.idwtn(coeffs, wavelet)
        assert_(roundtrip.dtype == dt_out, "idwtn: " + errmsg)
|
||
|
|
||
|
|
||
|
def test_idwtn_mixed_complex_dtype():
    """idwtn accepts coefficients of mixed complex precision and returns a
    complex128 result."""
    real_part = np.random.RandomState(0).randn(8, 8, 8)
    x = real_part + 1j*real_part
    coeffs = pywt.dwtn(x, 'db2')

    assert_allclose(pywt.idwtn(coeffs, 'db2'), x, rtol=1e-10)

    # downcasting only the approximation coefficients is allowed
    approx_key = 'a' * x.ndim
    coeffs[approx_key] = coeffs[approx_key].astype(np.complex64)
    mixed_roundtrip = pywt.idwtn(coeffs, 'db2')
    assert_allclose(mixed_roundtrip, x, rtol=1e-7, atol=1e-7)
    assert_(mixed_roundtrip.dtype == np.complex128)
|
||
|
|
||
|
|
||
|
def test_idwt2_size_mismatch_error():
    """idwt2 raises ValueError when approximation and detail shapes differ."""
    approx = np.zeros((6, 6))
    detail = np.zeros((5, 5))
    # the same detail array is reused for all three detail slots
    coeffs = (approx, (detail, detail, detail))
    assert_raises(ValueError, pywt.idwt2, coeffs, wavelet='haar')
|
||
|
|
||
|
|
||
|
def test_dwt2_dimension_error():
    """dwt2 rejects 1D input and more than two transform axes."""
    signal_1d = np.ones(16)
    wavelet = pywt.Wavelet('haar')

    # wrong number of input dimensions
    assert_raises(ValueError, pywt.dwt2, signal_1d, wavelet)

    # too many axes
    image = np.ones((8, 8))
    assert_raises(ValueError, pywt.dwt2, image, wavelet, axes=(0, 1, 1))
|
||
|
|
||
|
|
||
|
def test_per_axis_wavelets_and_modes():
    """dwtn/idwtn (and dwt2/idwt2) accept a separate wavelet and edge mode
    for each transformed axis, and reject sequences of the wrong length."""
    rstate = np.random.RandomState(1234)
    data = rstate.randn(16, 16, 16)

    # wavelet can be a string or wavelet object
    wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')

    # mode can be a string or a Modes enum member; use the public
    # pywt.Modes rather than the private pywt._extensions._pywt path
    modes = ('symmetric', 'periodization', pywt.Modes.reflect)

    coefs = pywt.dwtn(data, wavelets, modes)
    assert_allclose(pywt.idwtn(coefs, wavelets, modes), data, atol=1e-14)

    # single-element wavelet / mode sequences are also accepted
    coefs = pywt.dwtn(data, wavelets[:1], modes)
    assert_allclose(pywt.idwtn(coefs, wavelets[:1], modes), data, atol=1e-14)

    coefs = pywt.dwtn(data, wavelets, modes[:1])
    assert_allclose(pywt.idwtn(coefs, wavelets, modes[:1]), data, atol=1e-14)

    # length of wavelets or modes doesn't match the length of axes
    assert_raises(ValueError, pywt.dwtn, data, wavelets[:2])
    assert_raises(ValueError, pywt.dwtn, data, wavelets, mode=modes[:2])
    assert_raises(ValueError, pywt.idwtn, coefs, wavelets[:2])
    assert_raises(ValueError, pywt.idwtn, coefs, wavelets, mode=modes[:2])

    # dwt2/idwt2 also support per-axis wavelets/modes
    data2 = data[..., 0]
    coefs2 = pywt.dwt2(data2, wavelets[:2], modes[:2])
    assert_allclose(pywt.idwt2(coefs2, wavelets[:2], modes[:2]), data2,
                    atol=1e-14)
|
||
|
|
||
|
|
||
|
def test_error_on_continuous_wavelet():
    """Discrete 2D/nD transforms raise ValueError when given a continuous
    wavelet, whether by name or as a wavelet object."""
    data = np.ones((16, 16))
    continuous_wavelets = ['morl', pywt.DiscreteContinuousWavelet('morl')]

    for dec_fun, rec_fun in zip([pywt.dwt2, pywt.dwtn],
                                [pywt.idwt2, pywt.idwtn]):
        # Valid coefficients to feed the inverse transform. They do not
        # depend on the continuous wavelet under test, so compute them once
        # per transform pair instead of once per wavelet.
        c = dec_fun(data, 'db1')
        for cwave in continuous_wavelets:
            assert_raises(ValueError, dec_fun, data, wavelet=cwave)
            assert_raises(ValueError, rec_fun, c, wavelet=cwave)
|