forked from 170010011/fr
241 lines
6.8 KiB
Python
241 lines
6.8 KiB
Python
|
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||
|
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||
|
# <https://github.com/PyWavelets/pywt>
|
||
|
# See COPYING for license details.
|
||
|
|
||
|
"""
|
||
|
Other wavelet related functions.
|
||
|
"""
|
||
|
|
||
|
from __future__ import division, print_function, absolute_import
|
||
|
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
from numpy.fft import fft
|
||
|
|
||
|
from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet
|
||
|
|
||
|
|
||
|
__all__ = [
    # Current public API.
    "integrate_wavelet",
    "central_frequency",
    "scale2frequency",
    "qmf",
    "orthogonal_filter_bank",
    # Deprecated aliases that warn and forward to the names above.
    "intwave",
    "centrfrq",
    "scal2frq",
    "orthfilt",
]
|
||
|
|
||
|
|
||
|
# Template shared by the deprecated aliases; ``{old}`` and ``{new}`` are
# filled in with ``str.format`` using the old and new function names.
_DEPRECATION_MSG = (
    "`{old}` has been renamed to `{new}` and will "
    "be removed in a future version of pywt."
)
|
||
|
|
||
|
|
||
|
def _integrate(arr, step):
    """Cumulative rectangle-rule integral of ``arr`` sampled with spacing ``step``.

    Parameters
    ----------
    arr : array_like
        Samples to integrate.
    step : scalar
        Sample spacing (width of each rectangle).

    Returns
    -------
    integral : ndarray
        ``integral[i] == step * sum(arr[:i + 1])``.
    """
    integral = np.cumsum(arr)
    if np.issubdtype(integral.dtype, np.inexact):
        # Floating/complex samples: scale in place, keeping the input dtype.
        integral *= step
    else:
        # Integer (or bool) samples: an in-place ``*=`` raises a casting
        # error for a fractional ``step`` (e.g. a user-supplied ``(psi, x)``
        # pair with integer samples), so scale out of place instead.
        integral = integral * step
    return integral
|
||
|
|
||
|
|
||
|
def intwave(*args, **kwargs):
    """Deprecated alias of :func:`integrate_wavelet`.

    Emits a ``DeprecationWarning`` and forwards all arguments unchanged.
    """
    msg = _DEPRECATION_MSG.format(old='intwave', new='integrate_wavelet')
    # stacklevel=2 attributes the warning to the caller's line instead of
    # this wrapper, so users can find the deprecated call site.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return integrate_wavelet(*args, **kwargs)
|
||
|
|
||
|
|
||
|
def centrfrq(*args, **kwargs):
    """Deprecated alias of :func:`central_frequency`.

    Emits a ``DeprecationWarning`` and forwards all arguments unchanged.
    """
    msg = _DEPRECATION_MSG.format(old='centrfrq', new='central_frequency')
    # stacklevel=2 attributes the warning to the caller's line instead of
    # this wrapper, so users can find the deprecated call site.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return central_frequency(*args, **kwargs)
|
||
|
|
||
|
|
||
|
def scal2frq(*args, **kwargs):
    """Deprecated alias of :func:`scale2frequency`.

    Emits a ``DeprecationWarning`` and forwards all arguments unchanged.
    """
    msg = _DEPRECATION_MSG.format(old='scal2frq', new='scale2frequency')
    # stacklevel=2 attributes the warning to the caller's line instead of
    # this wrapper, so users can find the deprecated call site.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return scale2frequency(*args, **kwargs)
|
||
|
|
||
|
|
||
|
def orthfilt(*args, **kwargs):
    """Deprecated alias of :func:`orthogonal_filter_bank`.

    Emits a ``DeprecationWarning`` and forwards all arguments unchanged.
    """
    msg = _DEPRECATION_MSG.format(old='orthfilt', new='orthogonal_filter_bank')
    # stacklevel=2 attributes the warning to the caller's line instead of
    # this wrapper, so users can find the deprecated call site.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return orthogonal_filter_bank(*args, **kwargs)
|
||
|
|
||
|
|
||
|
def integrate_wavelet(wavelet, precision=8):
    """
    Integrate `psi` wavelet function from -Inf to x using the rectangle
    integration method.

    Parameters
    ----------
    wavelet : Wavelet instance, ContinuousWavelet instance, str or (psi, x)
        Wavelet to integrate. If a string, should be the name of a wavelet.
        Passing a ``(psi, x)`` tuple or list of samples and their grid is
        deprecated and emits a ``DeprecationWarning``; in that case the
        spacing ``x[1] - x[0]`` is used for every sample.
    precision : int, optional
        Precision that will be used for wavelet function
        approximation computed with the wavefun(level=precision)
        Wavelet's method (default: 8). Ignored for ``(psi, x)`` input.

    Returns
    -------
    [int_psi, x] :
        for orthogonal and continuous wavelets, and for ``(psi, x)`` input
    [int_psi_d, int_psi_r, x] :
        for other (biorthogonal) wavelets


    Examples
    --------
    >>> from pywt import Wavelet, integrate_wavelet
    >>> wavelet1 = Wavelet('db2')
    >>> [int_psi, x] = integrate_wavelet(wavelet1, precision=5)
    >>> wavelet2 = Wavelet('bior1.3')
    >>> [int_psi_d, int_psi_r, x] = integrate_wavelet(wavelet2, precision=5)

    """
    # FIXME: this function should really use scipy.integrate.quad

    if type(wavelet) in (tuple, list):
        msg = ("Integration of a general signal is deprecated "
               "and will be removed in a future version of pywt.")
        # stacklevel=2 points the warning at the caller's line.
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1])
        return _integrate(psi, x[1] - x[0]), x

    if not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)

    functions_approximations = wavelet.wavefun(precision)

    if len(functions_approximations) == 2:  # continuous wavelet: (psi, x)
        psi, x = functions_approximations
    elif len(functions_approximations) == 3:  # orthogonal: (phi, psi, x)
        _, psi, x = functions_approximations
    else:  # biorthogonal: (phi_d, psi_d, phi_r, psi_r, x)
        _, psi_d, _, psi_r, x = functions_approximations
        step = x[1] - x[0]
        return _integrate(psi_d, step), _integrate(psi_r, step), x

    step = x[1] - x[0]
    return _integrate(psi, step), x
|
||
|
|
||
|
|
||
|
def central_frequency(wavelet, precision=8):
    """
    Computes the central frequency of the `psi` wavelet function.

    Parameters
    ----------
    wavelet : Wavelet instance, ContinuousWavelet instance or str
        Wavelet whose central frequency is computed. If a string, should be
        the name of a wavelet. Unlike ``integrate_wavelet``, ``(psi, x)``
        sample tuples are not handled specially here.
    precision : int, optional
        Precision that will be used for wavelet function
        approximation computed with the wavefun(level=precision)
        Wavelet's method (default: 8).

    Returns
    -------
    scalar
        The dominant non-DC FFT bin of ``psi`` (folded into the lower half
        of the spectrum) divided by the span ``x[-1] - x[0]`` of the
        sampling grid returned by ``wavefun``.

    Raises
    ------
    ValueError
        If the sampling grid returned by ``wavefun`` does not span a positive
        interval.

    """
    if not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)

    functions_approximations = wavelet.wavefun(precision)

    if len(functions_approximations) == 2:
        psi, x = functions_approximations
    else:
        # (psi, x) for (phi, psi, x)
        # (psi_d, x) for (phi_d, psi_d, phi_r, psi_r, x)
        psi, x = functions_approximations[1], functions_approximations[-1]

    domain = float(x[-1] - x[0])
    # Explicit check rather than ``assert`` so it survives ``python -O``;
    # a degenerate grid would otherwise end in ZeroDivisionError or a
    # meaningless (negative) frequency. ``not >`` also rejects NaN.
    if not domain > 0:
        raise ValueError("wavelet sampling grid must span a positive "
                         "interval, got domain = {}".format(domain))

    # Dominant FFT bin, skipping the DC term. ``index`` is 1-based:
    # the 0-based bin k of the full spectrum corresponds to index k + 1.
    index = np.argmax(abs(fft(psi)[1:])) + 2
    if index > len(psi) / 2:
        # Mirror a bin from the upper half of the spectrum to the lower half.
        index = len(psi) - index + 2

    return 1.0 / (domain / (index - 1))
|
||
|
|
||
|
|
||
|
def scale2frequency(wavelet, scale, precision=8):
    """Convert a wavelet scale to the corresponding frequency.

    The result is the wavelet's central frequency divided by ``scale``.

    Parameters
    ----------
    wavelet : Wavelet instance or str
        Wavelet whose central frequency is used. If a string, should be the
        name of a wavelet.
    scale : scalar
        Scale to convert; the central frequency is divided by it.
    precision : int, optional
        Precision that will be used for wavelet function approximation computed
        with ``wavelet.wavefun(level=precision)``. Default is 8.

    Returns
    -------
    freq : scalar
        ``central_frequency(wavelet, precision=precision) / scale``.

    """
    center = central_frequency(wavelet, precision=precision)
    return center / scale
|
||
|
|
||
|
|
||
|
def qmf(filt):
    """
    Returns the Quadrature Mirror Filter(QMF).

    The magnitude response of QMF is mirror image about `pi/2` of that of the
    input filter. The result is the input reversed, with every odd-indexed
    tap of the reversed sequence negated.

    Parameters
    ----------
    filt : array_like
        Input filter for which QMF needs to be computed. It is copied, never
        modified.

    Returns
    -------
    qm_filter : ndarray
        Quadrature mirror of the input filter.

    """
    # np.array copies, so the in-place sign flip below cannot touch ``filt``.
    reversed_taps = np.array(filt)[::-1]
    odd_taps = slice(1, None, 2)
    reversed_taps[odd_taps] = np.negative(reversed_taps[odd_taps])
    return reversed_taps
|
||
|
|
||
|
|
||
|
def orthogonal_filter_bank(scaling_filter):
    """
    Returns the orthogonal filter bank.

    The orthogonal filter bank consists of the high-pass (HPF) and low-pass
    (LPF) filters at the decomposition and reconstruction stages for the
    input scaling filter. The reconstruction LPF is the scaling filter
    rescaled so that its taps sum to ``sqrt(2)``.

    Parameters
    ----------
    scaling_filter : array_like
        Input scaling filter (father wavelet). Its length must be even and
        its taps must not sum to zero.

    Returns
    -------
    orth_filt_bank : tuple of 4 ndarrays
        The orthogonal filter bank of the input scaling filter, in the order
        ``(dec_lo, dec_hi, rec_lo, rec_hi)``:

        1. Decomposition LPF
        2. Decomposition HPF
        3. Reconstruction LPF
        4. Reconstruction HPF

        The four arrays are independent; modifying one does not affect the
        others.

    Raises
    ------
    ValueError
        If the filter length is odd or its taps sum to zero (in which case it
        cannot be normalized).

    """
    if not (len(scaling_filter) % 2 == 0):
        raise ValueError("`scaling_filter` length has to be even.")

    scaling_filter = np.asarray(scaling_filter, dtype=np.float64)

    total = np.sum(scaling_filter)
    if total == 0:
        # Normalizing by a zero sum would silently yield inf/nan filters.
        raise ValueError("`scaling_filter` must not sum to zero.")

    rec_lo = np.sqrt(2) * scaling_filter / total
    # Copy the reversed filters: a bare ``[::-1]`` is a view, which would let
    # in-place edits of one returned filter silently change another.
    dec_lo = rec_lo[::-1].copy()

    rec_hi = qmf(rec_lo)
    dec_hi = rec_hi[::-1].copy()

    orth_filt_bank = (dec_lo, dec_hi, rec_lo, rec_hi)
    return orth_filt_bank
|