223 lines
7.1 KiB
Python
223 lines
7.1 KiB
Python
"""Compatibility fixes for older version of python, numpy and scipy
|
|
|
|
If you add content to this file, please give the version of the package
|
|
at which the fixe is no longer needed.
|
|
"""
|
|
# Authors: Emmanuelle Gouillart <emmanuelle.gouillart@normalesup.org>
|
|
# Gael Varoquaux <gael.varoquaux@normalesup.org>
|
|
# Fabian Pedregosa <fpedregosa@acm.org>
|
|
# Lars Buitinck
|
|
#
|
|
# License: BSD 3 clause
|
|
|
|
from functools import update_wrapper
|
|
from distutils.version import LooseVersion
|
|
import functools
|
|
|
|
import numpy as np
|
|
import scipy.sparse as sp
|
|
import scipy
|
|
import scipy.stats
|
|
from scipy.sparse.linalg import lsqr as sparse_lsqr # noqa
|
|
from numpy.ma import MaskedArray as _MaskedArray # TODO: remove in 1.0
|
|
from .._config import config_context, get_config
|
|
|
|
from .deprecation import deprecated
|
|
|
|
try:
|
|
from pkg_resources import parse_version # type: ignore
|
|
except ImportError:
|
|
# setuptools not installed
|
|
parse_version = LooseVersion # type: ignore
|
|
|
|
|
|
np_version = parse_version(np.__version__)
|
|
sp_version = parse_version(scipy.__version__)
|
|
|
|
|
|
if sp_version >= parse_version('1.4'):
|
|
from scipy.sparse.linalg import lobpcg
|
|
else:
|
|
# Backport of lobpcg functionality from scipy 1.4.0, can be removed
|
|
# once support for sp_version < parse_version('1.4') is dropped
|
|
# mypy error: Name 'lobpcg' already defined (possibly by an import)
|
|
from ..externals._lobpcg import lobpcg # type: ignore # noqa
|
|
|
|
|
|
def _object_dtype_isnan(X):
|
|
return X != X
|
|
|
|
|
|
# TODO: replace by copy=False, when only scipy > 1.1 is supported.
|
|
def _astype_copy_false(X):
|
|
"""Returns the copy=False parameter for
|
|
{ndarray, csr_matrix, csc_matrix}.astype when possible,
|
|
otherwise don't specify
|
|
"""
|
|
if sp_version >= parse_version('1.1') or not sp.issparse(X):
|
|
return {'copy': False}
|
|
else:
|
|
return {}
|
|
|
|
|
|
def _joblib_parallel_args(**kwargs):
|
|
"""Set joblib.Parallel arguments in a compatible way for 0.11 and 0.12+
|
|
|
|
For joblib 0.11 this maps both ``prefer`` and ``require`` parameters to
|
|
a specific ``backend``.
|
|
|
|
Parameters
|
|
----------
|
|
|
|
prefer : str in {'processes', 'threads'} or None
|
|
Soft hint to choose the default backend if no specific backend
|
|
was selected with the parallel_backend context manager.
|
|
|
|
require : 'sharedmem' or None
|
|
Hard condstraint to select the backend. If set to 'sharedmem',
|
|
the selected backend will be single-host and thread-based even
|
|
if the user asked for a non-thread based backend with
|
|
parallel_backend.
|
|
|
|
See joblib.Parallel documentation for more details
|
|
"""
|
|
import joblib
|
|
|
|
if parse_version(joblib.__version__) >= parse_version('0.12'):
|
|
return kwargs
|
|
|
|
extra_args = set(kwargs.keys()).difference({'prefer', 'require'})
|
|
if extra_args:
|
|
raise NotImplementedError('unhandled arguments %s with joblib %s'
|
|
% (list(extra_args), joblib.__version__))
|
|
args = {}
|
|
if 'prefer' in kwargs:
|
|
prefer = kwargs['prefer']
|
|
if prefer not in ['threads', 'processes', None]:
|
|
raise ValueError('prefer=%s is not supported' % prefer)
|
|
args['backend'] = {'threads': 'threading',
|
|
'processes': 'multiprocessing',
|
|
None: None}[prefer]
|
|
|
|
if 'require' in kwargs:
|
|
require = kwargs['require']
|
|
if require not in [None, 'sharedmem']:
|
|
raise ValueError('require=%s is not supported' % require)
|
|
if require == 'sharedmem':
|
|
args['backend'] = 'threading'
|
|
return args
|
|
|
|
|
|
class loguniform(scipy.stats.reciprocal):
|
|
"""A class supporting log-uniform random variables.
|
|
|
|
Parameters
|
|
----------
|
|
low : float
|
|
The minimum value
|
|
high : float
|
|
The maximum value
|
|
|
|
Methods
|
|
-------
|
|
rvs(self, size=None, random_state=None)
|
|
Generate log-uniform random variables
|
|
|
|
The most useful method for Scikit-learn usage is highlighted here.
|
|
For a full list, see
|
|
`scipy.stats.reciprocal
|
|
<https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.reciprocal.html>`_.
|
|
This list includes all functions of ``scipy.stats`` continuous
|
|
distributions such as ``pdf``.
|
|
|
|
Notes
|
|
-----
|
|
This class generates values between ``low`` and ``high`` or
|
|
|
|
low <= loguniform(low, high).rvs() <= high
|
|
|
|
The logarithmic probability density function (PDF) is uniform. When
|
|
``x`` is a uniformly distributed random variable between 0 and 1, ``10**x``
|
|
are random variables that are equally likely to be returned.
|
|
|
|
This class is an alias to ``scipy.stats.reciprocal``, which uses the
|
|
reciprocal distribution:
|
|
https://en.wikipedia.org/wiki/Reciprocal_distribution
|
|
|
|
Examples
|
|
--------
|
|
|
|
>>> from sklearn.utils.fixes import loguniform
|
|
>>> rv = loguniform(1e-3, 1e1)
|
|
>>> rvs = rv.rvs(random_state=42, size=1000)
|
|
>>> rvs.min() # doctest: +SKIP
|
|
0.0010435856341129003
|
|
>>> rvs.max() # doctest: +SKIP
|
|
9.97403052786026
|
|
"""
|
|
|
|
|
|
@deprecated(
|
|
'MaskedArray is deprecated in version 0.23 and will be removed in version '
|
|
'1.0 (renaming of 0.25). Use numpy.ma.MaskedArray instead.'
|
|
)
|
|
class MaskedArray(_MaskedArray):
|
|
pass # TODO: remove in 1.0
|
|
|
|
|
|
def _take_along_axis(arr, indices, axis):
|
|
"""Implements a simplified version of np.take_along_axis if numpy
|
|
version < 1.15"""
|
|
if np_version >= parse_version('1.15'):
|
|
return np.take_along_axis(arr=arr, indices=indices, axis=axis)
|
|
else:
|
|
if axis is None:
|
|
arr = arr.flatten()
|
|
|
|
if not np.issubdtype(indices.dtype, np.intp):
|
|
raise IndexError('`indices` must be an integer array')
|
|
if arr.ndim != indices.ndim:
|
|
raise ValueError(
|
|
"`indices` and `arr` must have the same number of dimensions")
|
|
|
|
shape_ones = (1,) * indices.ndim
|
|
dest_dims = (
|
|
list(range(axis)) +
|
|
[None] +
|
|
list(range(axis+1, indices.ndim))
|
|
)
|
|
|
|
# build a fancy index, consisting of orthogonal aranges, with the
|
|
# requested index inserted at the right location
|
|
fancy_index = []
|
|
for dim, n in zip(dest_dims, arr.shape):
|
|
if dim is None:
|
|
fancy_index.append(indices)
|
|
else:
|
|
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:]
|
|
fancy_index.append(np.arange(n).reshape(ind_shape))
|
|
|
|
fancy_index = tuple(fancy_index)
|
|
return arr[fancy_index]
|
|
|
|
|
|
# remove when https://github.com/joblib/joblib/issues/1071 is fixed
|
|
def delayed(function):
|
|
"""Decorator used to capture the arguments of a function."""
|
|
@functools.wraps(function)
|
|
def delayed_function(*args, **kwargs):
|
|
return _FuncWrapper(function), args, kwargs
|
|
return delayed_function
|
|
|
|
|
|
class _FuncWrapper:
|
|
""""Load the global configuration before calling the function."""
|
|
def __init__(self, function):
|
|
self.function = function
|
|
self.config = get_config()
|
|
update_wrapper(self, self.function)
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
with config_context(**self.config):
|
|
return self.function(*args, **kwargs)
|