forked from 170010011/fr
110 lines
4.7 KiB
Python
110 lines
4.7 KiB
Python
#!/usr/bin/env python
|
|
from __future__ import division, print_function, absolute_import
|
|
|
|
import numpy as np
|
|
from numpy.testing import assert_raises, assert_equal, assert_allclose
|
|
|
|
import pywt
|
|
|
|
|
|
def test_available_modes():
    # The list of mode names exposed by pywt.Modes must match this list
    # exactly, including the order of the entries.
    expected_names = [
        'zero',
        'constant',
        'symmetric',
        'periodic',
        'smooth',
        'periodization',
        'reflect',
        'antisymmetric',
        'antireflect',
    ]
    assert_equal(pywt.Modes.modes, expected_names)

    # The name 'constant' is expected to resolve to the value 2.
    constant_code = pywt.Modes.from_object('constant')
    assert_equal(constant_code, 2)
|
|
|
|
|
|
def test_invalid_modes():
    # Mode arguments that must be rejected, paired with the exception each
    # one is expected to trigger: an unknown name and the integers -1 and 9
    # give ValueError, while None gives TypeError.
    data = np.arange(4)
    rejected = [
        (ValueError, 'unknown'),
        (ValueError, -1),
        (ValueError, 9),
        (TypeError, None),
    ]

    # Rejection when the mode is passed to the transform itself ...
    for error_type, bad_mode in rejected:
        assert_raises(error_type, pywt.dwt, data, 'db2', bad_mode)

    # ... and when it is passed directly to the mode lookup.
    for error_type, bad_mode in rejected:
        assert_raises(error_type, pywt.Modes.from_object, bad_mode)
|
|
|
|
|
|
def test_dwt_idwt_allmodes():
    # Test that :func:`dwt` and :func:`idwt` can be performed using every
    # mode listed in pywt.Modes.modes, and that idwt recovers the input.
    signal = [1, 2, 1, 5, -1, 8, 4, 6]
    wavelet = 'db2'

    # Reference output of dwt for each mode name, stored as a pair
    # (first returned array, second returned array).
    reference = {
        'zero': (
            [-0.03467518, 1.73309178, 3.40612438, 6.32928585, 6.95094948],
            [-0.12940952, -2.15599552, -5.95034847, -1.21545369, -1.8625013],
        ),
        'constant': (
            [1.28480404, 1.73309178, 3.40612438, 6.32928585, 7.51935555],
            [-0.48296291, -2.15599552, -5.95034847, -1.21545369, 0.25881905],
        ),
        'symmetric': (
            [1.76776695, 1.73309178, 3.40612438, 6.32928585, 7.77817459],
            [-0.61237244, -2.15599552, -5.95034847, -1.21545369, 1.22474487],
        ),
        'reflect': (
            [2.12132034, 1.73309178, 3.40612438, 6.32928585, 6.81224877],
            [-0.70710678, -2.15599552, -5.95034847, -1.21545369, -2.38013939],
        ),
        'periodic': (
            [6.9162743, 1.73309178, 3.40612438, 6.32928585, 6.9162743],
            [-1.99191082, -2.15599552, -5.95034847, -1.21545369, -1.99191082],
        ),
        'smooth': (
            [-0.51763809, 1.73309178, 3.40612438, 6.32928585, 7.45000519],
            [0, -2.15599552, -5.95034847, -1.21545369, 0],
        ),
        'periodization': (
            [4.053172, 3.05257099, 2.85381112, 8.42522221],
            [0.18946869, 4.18258152, 4.33737503, 2.60428326],
        ),
        'antisymmetric': (
            [-1.83711731, 1.73309178, 3.40612438, 6.32928585, 6.12372436],
            [0.353553391, -2.15599552, -5.95034847, -1.21545369, -4.94974747],
        ),
        'antireflect': (
            [0.44828774, 1.73309178, 3.40612438, 6.32928585, 8.22646233],
            [-0.25881905, -2.15599552, -5.95034847, -1.21545369, 2.89777748],
        ),
    }

    for mode in pywt.Modes.modes:
        approx, detail = pywt.dwt(signal, wavelet, mode)
        want_approx, want_detail = reference[mode]

        # Forward transform must reproduce the tabulated coefficients.
        assert_allclose(approx, want_approx, rtol=1e-7, atol=1e-8)
        assert_allclose(detail, want_detail, rtol=1e-7, atol=1e-8)

        # Inverse transform must give back the original signal.
        restored = pywt.idwt(approx, detail, wavelet, mode)
        assert_allclose(restored, signal, rtol=1e-10)
|
|
|
|
|
|
def test_dwt_short_input_allmodes():
    # Some test cases where the input is shorter than the DWT filter.
    short = [1, 3, 2]
    wavelet = 'db2'

    # Central region of the padded transform output; the coefficients
    # outside it are the ones discarded as edge effects.
    interior = slice(2, -2)

    # The short input extended by hand on each end by the filter size
    # (4 for 'db2' used here), one entry per mode under test.
    hand_padded = {
        'zero': [0, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0],
        'constant': [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2],
        'symmetric': [2, 2, 3, 1, 1, 3, 2, 2, 3, 1, 1],
        'reflect': [1, 3, 2, 3, 1, 3, 2, 3, 1, 3, 2],
        'periodic': [2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1],
        'smooth': [-7, -5, -3, -1, 1, 3, 2, 1, 0, -1, -2],
        'antisymmetric': [2, -2, -3, -1, 1, 3, 2, -2, -3, -1, 1],
        'antireflect': [1, -1, 0, -1, 1, 3, 2, 1, 3, 5, 4],
    }

    for mode, extended in hand_padded.items():
        # Transform the hand-padded array.  The edges are thrown away
        # below, so the mode chosen for this call ('symmetric') does not
        # matter.
        ref_approx, ref_detail = pywt.dwt(extended, wavelet, mode='symmetric')
        want_approx = ref_approx[interior]
        want_detail = ref_detail[interior]

        # Transforming the short input directly with the mode under test
        # must agree with the central part of the reference.
        approx, detail = pywt.dwt(short, wavelet, mode)
        assert_allclose(approx, want_approx, rtol=1e-7, atol=1e-8)
        assert_allclose(detail, want_detail, rtol=1e-7, atol=1e-8)
|
|
|
|
|
|
def test_default_mode():
    # The default mode should be 'symmetric': calling dwt without a mode
    # must give the same coefficients as asking for 'symmetric' explicitly.
    signal = [1, 2, 1, 5, -1, 8, 4, 6]
    wavelet = 'db2'

    approx_default, detail_default = pywt.dwt(signal, wavelet)
    approx_sym, detail_sym = pywt.dwt(signal, wavelet, mode='symmetric')

    assert_allclose(approx_default, approx_sym)
    assert_allclose(detail_default, detail_sym)

    # idwt called without a mode must also invert the default-mode dwt.
    restored = pywt.idwt(approx_default, detail_default, wavelet)
    assert_allclose(restored, signal)
|