# Forked from 170010011/fr (175 lines, 5.6 KiB).
from datetime import datetime
import platform
from unittest.mock import MagicMock
import matplotlib.pyplot as plt
from matplotlib.testing.decorators import check_figures_equal, image_comparison
import matplotlib.units as munits
import numpy as np
import pytest
# Basic class that wraps numpy array and has units
class Quantity:
    """Minimal unit-aware wrapper around a scalar or NumPy array.

    The wrapped value is stored in ``magnitude`` and its unit name (a plain
    string) in ``units``.  Any attribute not found on the wrapper itself is
    looked up on ``magnitude``, so the object behaves like a thin facade
    over the underlying data.
    """

    # Multipliers keyed by ``(from_units, to_units)``.  Only these pairs can
    # be converted; any other pair raises ``KeyError`` in :meth:`to`.
    _FACTORS = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60,
                ('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
                ('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}

    def __init__(self, data, units):
        self.magnitude = data
        self.units = units

    def to(self, new_units):
        """Return a new `Quantity` expressed in *new_units*.

        Raises
        ------
        KeyError
            If the units differ and no factor is known for the pair.
        """
        if self.units != new_units:
            mult = self._FACTORS[self.units, new_units]
            return Quantity(mult * self.magnitude, new_units)
        return Quantity(self.magnitude, self.units)

    def __copy__(self):
        # Explicit shallow copy: the generic copy protocol probes attributes
        # on a not-yet-initialised instance, which would otherwise be routed
        # through __getattr__ below.
        return Quantity(self.magnitude, self.units)

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails.  If ``magnitude``
        # itself is missing (instance created without __init__), delegating
        # to ``self.magnitude`` would recurse forever, so fail cleanly.
        if attr == 'magnitude':
            raise AttributeError(attr)
        return getattr(self.magnitude, attr)

    def __getitem__(self, item):
        # Scalars cannot be indexed; indexing them returns an equal Quantity.
        if np.iterable(self.magnitude):
            return Quantity(self.magnitude[item], self.units)
        return Quantity(self.magnitude, self.units)

    def __array__(self, dtype=None, copy=None):
        # Accept the optional ``dtype``/``copy`` arguments of the NumPy
        # ``__array__`` protocol; calling with no arguments behaves as before.
        arr = np.asarray(self.magnitude, dtype=dtype)
        return arr.copy() if copy else arr
@pytest.fixture
def quantity_converter():
    """Return a `munits.ConversionInterface` for `Quantity` with mocked hooks.

    ``convert``, ``axisinfo`` and ``default_units`` are replaced by
    `MagicMock` objects whose side effects perform the real work, so tests
    can both rely on the conversion and check that each hook was called.
    """
    # Create an instance of the conversion interface and
    # mock so we can check methods called
    qc = munits.ConversionInterface()

    def convert(value, unit, axis):
        # A unit-carrying object converts itself.
        if hasattr(value, 'units'):
            return value.to(unit).magnitude
        elif np.iterable(value):
            try:
                # Sequence of unit-carrying objects.
                return [v.to(unit).magnitude for v in value]
            except AttributeError:
                # Plain numbers: interpret them in the axis' current units.
                return [Quantity(v, axis.get_units()).to(unit).magnitude
                        for v in value]
        else:
            return Quantity(value, axis.get_units()).to(unit).magnitude

    def default_units(value, axis):
        if hasattr(value, 'units'):
            return value.units
        elif np.iterable(value):
            # Use the units of the first element that carries any.
            for v in value:
                if hasattr(v, 'units'):
                    return v.units
            return None

    qc.convert = MagicMock(side_effect=convert)
    qc.axisinfo = MagicMock(side_effect=lambda u, a: munits.AxisInfo(label=u))
    qc.default_units = MagicMock(side_effect=default_units)
    return qc
# Tests that the conversion machinery works properly for classes that
# work as a facade over numpy arrays (like pint)
@image_comparison(['plot_pint.png'], remove_text=False, style='mpl20',
                  tol=0 if platform.machine() == 'x86_64' else 0.01)
def test_numpy_facade(quantity_converter):
    """Plot `Quantity` data and check that every converter hook was used."""
    # use former defaults to match existing baseline image
    plt.rcParams['axes.formatter.limits'] = (-7, 7)

    # Route all Quantity instances through the mocked converter.
    munits.registry[Quantity] = quantity_converter

    hours = Quantity(np.linspace(0, 5), 'hours')
    miles = Quantity(np.linspace(0, 30), 'miles')

    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0.15)  # Make space for label
    ax.plot(hours, miles, 'tab:blue')
    # Reference lines given in units different from the plotted data.
    ax.axhline(Quantity(26400, 'feet'), color='tab:red')
    ax.axvline(Quantity(120, 'minutes'), color='tab:green')

    assert quantity_converter.convert.called
    assert quantity_converter.axisinfo.called
    assert quantity_converter.default_units.called
# Tests gh-8908
@image_comparison(['plot_masked_units.png'], remove_text=True, style='mpl20',
                  tol=0 if platform.machine() == 'x86_64' else 0.01)
def test_plot_masked_units():
    """Plot a unit-wrapped masked array (regression test for gh-8908)."""
    data = np.linspace(-5, 5)
    # Mask out the open interval (-2, 2).
    data_masked = np.ma.array(data, mask=(data > -2) & (data < 2))
    data_masked_units = Quantity(data_masked, 'meters')

    fig, ax = plt.subplots()
    # NOTE(review): the plot call was missing in this copy, leaving the
    # masked quantity unused; restored so the test draws the data.
    ax.plot(data_masked_units)
def test_empty_set_limits_with_units(quantity_converter):
    """Set unit-carrying axis limits on an Axes that has no data."""
    # Register the class
    munits.registry[Quantity] = quantity_converter

    _, ax = plt.subplots()

    x_low, x_high = Quantity(-1, 'meters'), Quantity(6, 'meters')
    ax.set_xlim(x_low, x_high)

    y_low, y_high = Quantity(-1, 'hours'), Quantity(16, 'hours')
    ax.set_ylim(y_low, y_high)
# NOTE(review): the first decorator line, the ``x`` operands and the start of
# the ``bar`` call were truncated in this copy.  The baseline image name and
# the ``units.km`` operands are reconstructed -- confirm against the baseline
# image directory and the jpl_units module.
@image_comparison(['jpl_bar_units.png'],
                  savefig_kwarg={'dpi': 120}, style='mpl20')
def test_jpl_bar_units():
    """Draw vertical bars whose positions, heights and bottom carry units."""
    import matplotlib.testing.jpl_units as units
    # Make the JPL unit classes known to matplotlib's unit registry.
    units.register()

    day = units.Duration("ET", 24.0 * 60.0 * 60.0)
    x = [0*units.km, 1*units.km, 2*units.km]
    w = [1*day, 2*day, 3*day]
    b = units.Epoch("ET", dt=datetime(2009, 4, 25))
    fig, ax = plt.subplots()
    ax.bar(x, w, bottom=b)
    # One day of padding below, just over one day above the tallest bar.
    ax.set_ylim([b-1*day, b+w[-1]+(1.001)*day])
# NOTE(review): the first decorator line and the ``x`` operands were truncated
# in this copy.  The baseline image name and the ``units.km`` operands are
# reconstructed -- confirm against the baseline image directory and the
# jpl_units module.
@image_comparison(['jpl_barh_units.png'],
                  savefig_kwarg={'dpi': 120}, style='mpl20')
def test_jpl_barh_units():
    """Draw horizontal bars whose positions, widths and left edge carry units."""
    import matplotlib.testing.jpl_units as units
    # Make the JPL unit classes known to matplotlib's unit registry.
    units.register()

    day = units.Duration("ET", 24.0 * 60.0 * 60.0)
    x = [0*units.km, 1*units.km, 2*units.km]
    w = [1*day, 2*day, 3*day]
    b = units.Epoch("ET", dt=datetime(2009, 4, 25))
    fig, ax = plt.subplots()
    ax.barh(x, w, left=b)
    # One day of padding on the left, just over one day past the longest bar.
    ax.set_xlim([b-1*day, b+w[-1]+(1.001)*day])
def test_empty_arrays():
    """Scatter-plotting an empty array that has a datetime dtype must work."""
    no_times = np.array([], dtype='datetime64[ns]')
    no_values = np.array([])
    plt.scatter(no_times, no_values)
def test_scatter_element0_masked():
    """Scatter datetime x-values against y-data whose first entry is NaN."""
    times = np.arange('2005-02', '2005-03', dtype='datetime64[D]')

    values = np.arange(len(times), dtype=float)
    values[0] = np.nan  # invalidate the very first point

    _, ax = plt.subplots()
    ax.scatter(times, values)
# NOTE(review): the decorator was missing in this copy; ``fig_test`` and
# ``fig_ref`` are the arguments ``check_figures_equal`` injects.
@check_figures_equal(extensions=["png"])
def test_subclass(fig_test, fig_ref):
    """A `datetime` subclass must plot the same as a plain `datetime`."""
    class subdate(datetime):
        pass

    fig_test.subplots().plot(subdate(2000, 1, 1), 0, "o")
    fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o")