170 lines
6.4 KiB
Python
170 lines
6.4 KiB
Python
from __future__ import division, print_function, absolute_import
|
|
import numpy as np
|
|
from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
|
|
|
|
import pywt
|
|
|
|
|
|
# Real floating-point dtypes; the dtype-looping tests use this list to decide
# whether the input data needs an imaginary part.
real_dtypes = [np.float32, np.float64]
# Every dtype the dtype-looping tests iterate over: the real dtypes followed
# by the complex ones.
float_dtypes = real_dtypes + [np.complex64, np.complex128]
|
|
|
|
|
|
def _sign(x):
    """Matlab-like sign function for real or complex input.

    Returns ``x / |x|`` for every non-zero entry and ``0`` where ``x`` is
    zero. (NumPy's own ``np.sign`` has historically followed a different
    convention for complex values, hence this helper.)

    Parameters
    ----------
    x : array_like
        Real or complex values.

    Returns
    -------
    ndarray or numpy scalar
        Values of unit magnitude sharing the sign/phase of ``x``; zero where
        ``x`` is zero.
    """
    x = np.asarray(x)
    magnitude = np.abs(x)
    # Dividing by 1 instead of 0 leaves zero entries at 0 and avoids the
    # NaN / RuntimeWarning that a plain ``x / np.abs(x)`` gives for 0 / 0.
    divisor = np.where(magnitude == 0, 1, magnitude)
    return x / divisor
|
|
|
|
|
|
def _soft(x, thresh):
    """Reference soft-thresholding that also accepts complex values.

    The result is ``_sign(x)`` scaled by ``|x| - thresh``, with negative
    magnitudes clipped to zero.

    Notes
    -----
    Zeros in ``x`` are not special-cased here; the result at those entries
    is whatever ``_sign`` produces for zero.
    """
    excess = np.abs(x) - thresh
    shrunk_magnitude = np.maximum(excess, 0)
    return _sign(x) * shrunk_magnitude
|
|
|
|
|
|
def test_threshold():
    """Compare ``pywt.threshold`` with hand-computed results for each mode.

    Exercises the 'soft', 'hard', 'greater' and 'less' modes on a real ramp
    and on small nested lists, including complex inputs and a non-default
    ``substitute`` value, and checks that complex input to 'greater'/'less'
    and an unknown mode name raise ``ValueError``.
    """
    tol = 1e-12
    ramp = np.linspace(1, 4, 7)
    fill = 5  # non-default value handed to ``substitute``

    def check(values, cutoff, mode, expected, **kwargs):
        # Threshold ``values`` and compare the result with ``expected``.
        assert_allclose(pywt.threshold(values, cutoff, mode, **kwargs),
                        expected, rtol=tol)

    # ---- soft --------------------------------------------------------
    soft_expected = np.array([0., 0., 0., 0.5, 1., 1.5, 2.])
    soft_cases = [
        (ramp, 2, soft_expected),
        (-ramp, 2, -soft_expected),
        ([[1, 2]] * 2, 1, [[0, 1]] * 2),
        ([[1, 2]] * 2, 2, [[0, 0]] * 2),
        # complex-valued inputs
        ([[1j, 2j]] * 2, 1, [[0j, 1j]] * 2),
        ([[1+1j, 2+2j]] * 2, 6, [[0, 0]] * 2),
    ]
    for values, cutoff, expected in soft_cases:
        check(values, cutoff, 'soft', expected)

    # Complex data compared against the reference implementation ``_soft``.
    complex_rows = [[1+2j, 2+2j]] * 2
    for cutoff in (1, 2):
        assert_allclose(pywt.threshold(complex_rows, cutoff, 'soft'),
                        _soft(complex_rows, cutoff), rtol=tol)

    # Entries below the threshold are replaced by the substitute value.
    check([[1j, 2]] * 2, 1.5, 'soft', [[fill, 0.5]] * 2, substitute=fill)

    # All-zero input: meant to guard against divide-by-zero trouble; only
    # the returned values are checked here.
    check(np.zeros(16), 2, 'soft', np.zeros(16))

    # ---- hard --------------------------------------------------------
    hard_expected = np.array([0., 0., 2., 2.5, 3., 3.5, 4.])
    check(ramp, 2, 'hard', hard_expected)
    check(-ramp, 2, 'hard', -hard_expected)
    check([[1, 2]] * 2, 1, 'hard', [[1, 2]] * 2)
    check([[1, 2]] * 2, 2, 'hard', [[0, 2]] * 2)
    check([[1, 2]] * 2, 2, 'hard', [[fill, 2]] * 2, substitute=fill)
    check([[1+1j, 2+2j]] * 2, 2, 'hard', [[0, 2+2j]] * 2)

    # ---- greater -----------------------------------------------------
    greater_expected = np.array([0., 0., 2., 2.5, 3., 3.5, 4.])
    check(ramp, 2, 'greater', greater_expected)
    check([[1, 2]] * 2, 1, 'greater', [[1, 2]] * 2)
    check([[1, 2]] * 2, 2, 'greater', [[0, 2]] * 2)
    check([[1, 2]] * 2, 2, 'greater', [[fill, 2]] * 2, substitute=fill)
    # complex-valued input is rejected in 'greater' mode
    assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'greater')

    # ---- less --------------------------------------------------------
    check(ramp, 2, 'less', np.array([1., 1.5, 2., 0., 0., 0., 0.]))
    check([[1, 2]] * 2, 1, 'less', [[1, 0]] * 2)
    check([[1, 2]] * 2, 1, 'less', [[1, fill]] * 2, substitute=fill)
    check([[1, 2]] * 2, 2, 'less', [[1, 2]] * 2)
    # complex-valued input is rejected in 'less' mode
    assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'less')

    # ---- unknown mode name -------------------------------------------
    assert_raises(ValueError, pywt.threshold, ramp, 2, 'foo')
|
|
|
|
|
|
def test_nonnegative_garotte():
    """Check 'garotte' thresholding against 'hard' and 'soft' per dtype.

    For every dtype in ``float_dtypes`` the garotte output must keep the
    input dtype, be zero where ``|data| < thresh``, and have a magnitude
    strictly between the soft and hard results where ``|data| > thresh``.
    """
    cutoff = 0.3
    base = np.linspace(-1, 1, 100)
    for dtype in float_dtypes:
        # Complex dtypes get a constant imaginary part added to the ramp.
        samples = base if dtype in real_dtypes else base + 0.1j
        data = np.asarray(samples, dtype=dtype)

        hard, soft, garotte = [pywt.threshold(data, cutoff, mode)
                               for mode in ('hard', 'soft', 'garotte')]

        # every mode must preserve the input dtype
        for result in (hard, soft, garotte):
            assert_equal(result.dtype, data.dtype)

        # below the threshold: garotte output is exactly zero
        below = np.abs(data) < cutoff
        assert_(np.all(garotte[below] == 0))

        # above the threshold: soft < garotte < hard in magnitude
        above = np.abs(data) > cutoff
        garotte_mag = np.abs(garotte[above])
        assert_(np.all(garotte_mag < np.abs(hard[above])))
        assert_(np.all(garotte_mag > np.abs(soft[above])))
|
|
|
|
|
|
def test_threshold_firm():
    """Check ``pywt.threshold_firm`` against 'hard' and 'soft' per dtype.

    For every dtype in ``float_dtypes`` the firm output must keep the input
    dtype, be zero where ``|data| < thresh``, match hard thresholding in
    magnitude where ``|data| >= thresh2``, and lie strictly between the soft
    and hard magnitudes where ``thresh < |data| < thresh2``.
    """
    lower = 0.2
    upper = 3 * lower
    base = np.linspace(-1, 1, 100)
    for dtype in float_dtypes:
        # Complex dtypes get a constant imaginary part added to the ramp.
        samples = base if dtype in real_dtypes else base + 0.1j
        data = np.asarray(samples, dtype=dtype)

        # Looser tolerance when the underlying real precision is float32.
        tol = 1e-6 if data.real.dtype == np.float32 else 1e-14

        hard = pywt.threshold(data, lower, 'hard')
        soft = pywt.threshold(data, lower, 'soft')
        firm = pywt.threshold_firm(data, lower, upper)

        # every result must preserve the input dtype
        for result in (hard, soft, firm):
            assert_equal(result.dtype, data.dtype)

        # below the lower threshold: firm output is exactly zero
        below = np.abs(data) < lower
        assert_(np.all(firm[below] == 0))

        # at or above the upper threshold: same magnitude as hard thresholding
        saturated = np.abs(data) >= upper
        assert_allclose(np.abs(hard[saturated]), np.abs(firm[saturated]),
                        rtol=tol, atol=tol)

        # strictly between the two thresholds: soft < firm < hard in magnitude
        between = np.logical_and(np.abs(data) > lower,
                                 np.abs(data) < upper)
        firm_mag = np.abs(firm[between])
        assert_(np.all(firm_mag < np.abs(hard[between])))
        assert_(np.all(firm_mag > np.abs(soft[between])))
|