fr/fr_env/lib/python3.8/site-packages/pywt/tests/test_concurrent.py

106 lines
3.9 KiB
Python
Raw Normal View History

2021-02-17 12:26:31 +05:30
"""
Tests used to verify running PyWavelets transforms in parallel via
concurrent.futures.ThreadPoolExecutor does not raise errors.
"""
from __future__ import division, print_function, absolute_import
import warnings
import numpy as np
from functools import partial
from numpy.testing import assert_array_equal, assert_allclose
from pywt._pytest import uses_futures, futures, max_workers
import pywt
def _assert_all_coeffs_equal(coefs1, coefs2):
# return True only if all coefficients of SWT or DWT match over all levels
if len(coefs1) != len(coefs2):
return False
for (c1, c2) in zip(coefs1, coefs2):
if isinstance(c1, tuple):
# for swt, swt2, dwt, dwt2, wavedec, wavedec2
for a1, a2 in zip(c1, c2):
assert_array_equal(a1, a2)
elif isinstance(c1, dict):
# for swtn, dwtn, wavedecn
for k, v in c1.items():
assert_array_equal(v, c2[k])
else:
return False
return True
@uses_futures
def test_concurrent_swt():
# tests error-free concurrent operation (see gh-288)
# swt on 1D data calls the Cython swt
# other cases call swt_axes
with warnings.catch_warnings():
# can remove catch_warnings once the swt2 FutureWarning is removed
warnings.simplefilter('ignore', FutureWarning)
for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(swt_func, wavelet='haar', level=3)
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal(expected_result, results[-1])
@uses_futures
def test_concurrent_wavedec():
# wavedec on 1D data calls the Cython dwt_single
# other cases call dwt_axis
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(wavedec_func, wavelet='haar', level=1)
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal(expected_result, results[-1])
@uses_futures
def test_concurrent_dwt():
# dwt on 1D data calls the Cython dwt_single
# other cases call dwt_axis
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(dwt_func, wavelet='haar')
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
@uses_futures
def test_concurrent_cwt():
atol = rtol = 1e-14
time, sst = pywt.data.nino()
dt = time[1]-time[0]
transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor1.5-1',
sampling_period=dt)
for _ in range(10):
arrs = [sst.copy() for _ in range(50)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(sst)
for a1, a2 in zip(expected_result, results[-1]):
assert_allclose(a1, a2, atol=atol, rtol=rtol)