forked from 170010011/fr
223 lines
6.6 KiB
Python
223 lines
6.6 KiB
Python
"""LU decomposition functions."""
|
|
|
|
from warnings import warn
|
|
|
|
from numpy import asarray, asarray_chkfinite
|
|
|
|
# Local imports
|
|
from .misc import _datacopied, LinAlgWarning
|
|
from .lapack import get_lapack_funcs
|
|
from .flinalg import get_flinalg_funcs
|
|
|
|
__all__ = ['lu', 'lu_solve', 'lu_factor']
|
|
|
|
|
|
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, M) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, M) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (M,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].

    Raises
    ------
    ValueError
        If `a` is not a square two-dimensional array, or if the internal
        ``getrf`` call reports an illegal argument (negative ``info``).

    Warns
    -----
    LinAlgWarning
        If ``getrf`` returns a positive ``info``, i.e. a diagonal entry of U
        is exactly zero and the matrix is singular. The factorization is
        still returned.

    See also
    --------
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> piv_py = [2, 0, 3, 1]
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[piv_py] - L @ U, np.zeros((4, 4)))
    True
    """
    a1 = asarray_chkfinite(a) if check_finite else asarray(a)
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')

    # When the conversion above already handed us a private copy (as reported
    # by _datacopied), the LAPACK routine may scribble on it without touching
    # the caller's data, so in-place operation is allowed regardless of the
    # flag the caller passed.
    overwrite_a = overwrite_a or _datacopied(a1, a)

    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)

    if info < 0:
        # Negative info: the (-info)-th argument passed to getrf was invalid.
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # Positive info is not fatal: the factors are returned, but U has an
        # exact zero on its diagonal, so warn the caller.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu, piv
|
|
|
|
|
|
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    lu_and_piv : tuple
        The pair ``(lu, piv)``: factorization of the coefficient matrix a,
        as given by lu_factor
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If the leading dimensions of ``lu`` and `b` differ, or if the
        internal ``getrs`` call reports a nonzero ``info``.

    See also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    lu, piv = lu_and_piv
    b1 = asarray_chkfinite(b) if check_finite else asarray(b)

    # When the conversion above already produced a private copy of b (as
    # reported by _datacopied), it can be overwritten without affecting the
    # caller's array.
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if lu.shape[0] != b1.shape[0]:
        raise ValueError("Shapes of lu {} and b {} are incompatible"
                         .format(lu.shape, b1.shape))

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info != 0:
        # The routine invoked here is getrs; name it in the message so the
        # failing argument index can be looked up in the right signature.
        raise ValueError('illegal value in %dth argument of internal getrs'
                         % -info)
    return x
|
|
|
|
|
|
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Array to decompose
    permute_l : bool, optional
        Perform the multiplication P*L (Default: do not permute)
    overwrite_a : bool, optional
        Whether to overwrite data in a (may improve performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    **(If permute_l == False)**

    p : (M, M) ndarray
        Permutation matrix
    l : (M, K) ndarray
        Lower triangular or trapezoidal matrix with unit diagonal.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    **(If permute_l == True)**

    pl : (M, K) ndarray
        Permuted L matrix.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    Raises
    ------
    ValueError
        If `a` is not two-dimensional, or if the internal routine reports
        an illegal argument (negative ``info``).

    Notes
    -----
    This is a LU factorization routine written for SciPy.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> p, l, u = lu(A)
    >>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
    True

    """
    a1 = asarray_chkfinite(a) if check_finite else asarray(a)
    if a1.ndim != 2:
        raise ValueError('expected matrix')

    # When the conversion above already produced a private copy (as reported
    # by _datacopied), the backend may work in place without touching the
    # caller's data.
    overwrite_a = overwrite_a or _datacopied(a1, a)

    flu, = get_flinalg_funcs(('lu',), (a1,))
    p, l, u, info = flu(a1, permute_l=permute_l, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal lu.getrf' % -info)

    # With permute_l the backend's second output is returned as the permuted
    # L and the separate permutation output is dropped.
    if permute_l:
        return l, u
    return p, l, u
|