704 lines
22 KiB
Python
704 lines
22 KiB
Python
"""
Streamline plotting for 2D vector fields.

"""
|
|
|
|
import numpy as np
|
|
|
|
import matplotlib
|
|
import matplotlib.cbook as cbook
|
|
import matplotlib.cm as cm
|
|
import matplotlib.colors as mcolors
|
|
import matplotlib.collections as mcollections
|
|
import matplotlib.lines as mlines
|
|
import matplotlib.patches as patches
|
|
|
|
|
|
# Public names exported by a star-import of this module.
__all__ = ['streamplot']
|
|
|
|
|
|
def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
               cmap=None, norm=None, arrowsize=1, arrowstyle='-|>',
               minlength=0.1, transform=None, zorder=None, start_points=None,
               maxlength=4.0, integration_direction='both'):
    """
    Draw streamlines of a vector flow.

    Parameters
    ----------
    axes : Axes
        The axes the streamlines and arrows are added to.
    x, y : 1D arrays (or 2D arrays with identical rows / columns)
        An evenly spaced grid.
    u, v : 2D arrays
        *x* and *y*-velocities. The number of rows and columns must match
        the length of *y* and *x*, respectively.
    density : float or (float, float)
        Controls the closeness of streamlines. When ``density = 1``, the
        domain is divided into a 30x30 grid. *density* linearly scales this
        grid. Each cell in the grid can have, at most, one traversing
        streamline. For different densities in each direction, use a tuple
        (density_x, density_y).
    linewidth : float or 2D array
        The width of the stream lines. With a 2D array the line width can be
        varied across the grid. The array must have the same shape as *u*
        and *v*. Defaults to :rc:`lines.linewidth`.
    color : color or 2D array
        The streamline color. If given an array, its values are converted to
        colors using *cmap* and *norm*. The array must have the same shape
        as *u* and *v*. Defaults to the next color of the axes' line cycle.
    cmap : `~matplotlib.colors.Colormap`
        Colormap used to plot streamlines and arrows. This is only used if
        *color* is an array.
    norm : `~matplotlib.colors.Normalize`
        Normalize object used to scale luminance data to 0, 1. If ``None``,
        stretch (min, max) to (0, 1). This is only used if *color* is an
        array.
    arrowsize : float
        Scaling factor for the arrow size.
    arrowstyle : str
        Arrow style specification.
        See `~matplotlib.patches.FancyArrowPatch`.
    minlength : float
        Minimum length of streamline in axes coordinates.
    transform : Transform, optional
        Transform applied to the lines and arrows. Defaults to
        ``axes.transData`` (data coordinates).
    zorder : int
        The zorder of the stream lines and arrows.
        Artists with lower zorder values are drawn first.
    start_points : Nx2 array
        Coordinates of starting points for the streamlines in data
        coordinates (the same coordinates as the *x* and *y* arrays).
    maxlength : float
        Maximum length of streamline in axes coordinates.
    integration_direction : {'forward', 'backward', 'both'}, default: 'both'
        Integrate the streamline in forward, backward or both directions.

    Returns
    -------
    StreamplotSet
        Container object with attributes

        - ``lines``: `.LineCollection` of streamlines

        - ``arrows``: `.PatchCollection` containing `.FancyArrowPatch`
          objects representing the arrows half-way along stream lines.

        This container will probably change in the future to allow changes
        to the colormap, alpha, etc. for both lines and arrows, but these
        changes should be backward compatible.

    Raises
    ------
    ValueError
        If *color*, *linewidth*, *u* or *v* do not match the grid shape, or
        if a start point lies outside the data boundaries.
    """
    grid = Grid(x, y)
    mask = StreamMask(density)
    dmap = DomainMap(grid, mask)

    # Fill in everything the caller left unspecified.
    zorder = mlines.Line2D.zorder if zorder is None else zorder
    # default to data coordinates
    transform = axes.transData if transform is None else transform
    color = axes._get_lines.get_next_color() if color is None else color
    linewidth = (matplotlib.rcParams['lines.linewidth']
                 if linewidth is None else linewidth)

    line_kw = {}
    arrow_kw = {'arrowstyle': arrowstyle, 'mutation_scale': 10 * arrowsize}

    cbook._check_in_list(['both', 'forward', 'backward'],
                         integration_direction=integration_direction)

    if integration_direction == 'both':
        # Each of the two half-integrations gets half of the length budget.
        maxlength /= 2.

    use_multicolor_lines = isinstance(color, np.ndarray)
    if not use_multicolor_lines:
        line_kw['color'] = color
        arrow_kw['color'] = color
    else:
        if color.shape != grid.shape:
            raise ValueError("If 'color' is given, it must match the shape of "
                             "'Grid(x, y)'")
        line_colors = []
        color = np.ma.masked_invalid(color)

    use_variable_widths = isinstance(linewidth, np.ndarray)
    if not use_variable_widths:
        line_kw['linewidth'] = linewidth
        arrow_kw['linewidth'] = linewidth
    else:
        if linewidth.shape != grid.shape:
            raise ValueError("If 'linewidth' is given, it must match the "
                             "shape of 'Grid(x, y)'")
        # Per-segment widths are accumulated while building the segments.
        line_kw['linewidth'] = []

    line_kw['zorder'] = zorder
    arrow_kw['zorder'] = zorder

    # Sanity checks.
    if u.shape != grid.shape or v.shape != grid.shape:
        raise ValueError("'u' and 'v' must match the shape of 'Grid(x, y)'")

    u = np.ma.masked_invalid(u)
    v = np.ma.masked_invalid(v)

    integrate = get_integrator(u, v, dmap, minlength, maxlength,
                               integration_direction)

    def _seed_points():
        # Lazily yield seed positions in grid coordinates.  Laziness matters
        # for the automatic seeds: the mask is re-inspected after every
        # integration, so cells filled by earlier streamlines are skipped.
        if start_points is None:
            for xm, ym in _gen_starting_points(mask.shape):
                if mask[ym, xm] == 0:
                    yield dmap.mask2grid(xm, ym)
            return

        sp2 = np.asanyarray(start_points, dtype=float).copy()

        # Check if start_points are outside the data boundaries
        x_max = grid.x_origin + grid.width
        y_max = grid.y_origin + grid.height
        for xs, ys in sp2:
            if not (grid.x_origin <= xs <= x_max and
                    grid.y_origin <= ys <= y_max):
                raise ValueError("Starting point ({}, {}) outside of data "
                                 "boundaries".format(xs, ys))

        # Convert start_points from data to array coords
        # Shift the seed points from the bottom left of the data so that
        # data2grid works properly.
        sp2[:, 0] -= grid.x_origin
        sp2[:, 1] -= grid.y_origin

        for xs, ys in sp2:
            yield dmap.data2grid(xs, ys)

    trajectories = []
    for xg, yg in _seed_points():
        t = integrate(xg, yg)
        if t is not None:
            trajectories.append(t)

    if use_multicolor_lines:
        if norm is None:
            norm = mcolors.Normalize(color.min(), color.max())
        cmap = cm.get_cmap(
            matplotlib.rcParams['image.cmap'] if cmap is None else cmap)

    streamlines = []
    arrows = []
    for t in trajectories:
        tgx = np.array(t[0])
        tgy = np.array(t[1])
        # Rescale from grid-coordinates to data-coordinates.
        tx, ty = dmap.grid2data(*np.array(t))
        tx += grid.x_origin
        ty += grid.y_origin

        # Consecutive point pairs become the individual line segments.
        points = np.transpose([tx, ty]).reshape(-1, 1, 2)
        streamlines.extend(np.hstack([points[:-1], points[1:]]))

        # Add arrows half way along each trajectory.
        s = np.cumsum(np.hypot(np.diff(tx), np.diff(ty)))
        n = np.searchsorted(s, s[-1] / 2.)
        arrow_tail = (tx[n], ty[n])
        arrow_head = (np.mean(tx[n:n + 2]), np.mean(ty[n:n + 2]))

        if use_variable_widths:
            line_widths = interpgrid(linewidth, tgx, tgy)[:-1]
            line_kw['linewidth'].extend(line_widths)
            arrow_kw['linewidth'] = line_widths[n]

        if use_multicolor_lines:
            color_values = interpgrid(color, tgx, tgy)[:-1]
            line_colors.append(color_values)
            arrow_kw['color'] = cmap(norm(color_values[n]))

        p = patches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=transform, **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)

    lc = mcollections.LineCollection(
        streamlines, transform=transform, **line_kw)
    lc.sticky_edges.x[:] = [grid.x_origin, grid.x_origin + grid.width]
    lc.sticky_edges.y[:] = [grid.y_origin, grid.y_origin + grid.height]
    if use_multicolor_lines:
        lc.set_array(np.ma.hstack(line_colors))
        lc.set_cmap(cmap)
        lc.set_norm(norm)
    axes.add_collection(lc)
    axes.autoscale_view()

    # The arrows were already added to the axes one by one above; the
    # collection only groups them for the returned container.
    ac = mcollections.PatchCollection(arrows)
    return StreamplotSet(lc, ac)
|
|
|
|
|
|
class StreamplotSet:
    """
    Container for the artists created by `streamplot`.

    Parameters
    ----------
    lines
        The collection of streamline segments; stored as ``self.lines``.
    arrows
        The collection of arrow patches; stored as ``self.arrows``.
    **kwargs
        Ignored. Passing any keyword argument emits a deprecation warning.
    """

    def __init__(self, lines, arrows, **kwargs):
        if kwargs:
            # Both placeholders need the trailing ``s`` conversion; without
            # it ``%(since) a`` is parsed as a space-flagged ``%a``
            # conversion and the rendered message is garbled.
            cbook.warn_deprecated(
                "3.3",
                message="Passing arbitrary keyword arguments to StreamplotSet "
                        "is deprecated since %(since)s and will become an "
                        "error %(removal)s.")
        self.lines = lines
        self.arrows = arrows
|
|
|
|
|
|
# Coordinate definitions
# ========================
|
|
|
|
class DomainMap:
    """
    Map representing different coordinate systems.

    Coordinate definitions:

    * axes-coordinates goes from 0 to 1 in the domain.
    * data-coordinates are specified by the input x-y coordinates.
    * grid-coordinates goes from 0 to N and 0 to M for an N x M grid,
      where N and M match the shape of the input data.
    * mask-coordinates goes from 0 to N and 0 to M for an N x M mask,
      where N and M are user-specified to control the density of streamlines.

    This class also has methods for adding trajectories to the StreamMask.
    Before adding a trajectory, run `start_trajectory` to keep track of regions
    crossed by a given trajectory. Later, if you decide the trajectory is bad
    (e.g., if the trajectory is very short) just call `undo_trajectory`.
    """

    def __init__(self, grid, mask):
        """Store *grid* and *mask* and precompute the scale factors."""
        self.grid = grid
        self.mask = mask

        # Constants for conversion between grid- and mask-coordinates
        self.x_grid2mask = (mask.nx - 1) / (grid.nx - 1)
        self.y_grid2mask = (mask.ny - 1) / (grid.ny - 1)
        self.x_mask2grid = 1. / self.x_grid2mask
        self.y_mask2grid = 1. / self.y_grid2mask

        # Constants for conversion between data- and grid-coordinates
        self.x_data2grid = 1. / grid.dx
        self.y_data2grid = 1. / grid.dy

    def grid2mask(self, xi, yi):
        """Return nearest space in mask-coords from given grid-coords."""
        # Adding 0.5 before truncating rounds to the nearest mask cell.
        xm = int(xi * self.x_grid2mask + 0.5)
        ym = int(yi * self.y_grid2mask + 0.5)
        return xm, ym

    def mask2grid(self, xm, ym):
        """Scale mask-coords to grid-coords."""
        xg = xm * self.x_mask2grid
        yg = ym * self.y_mask2grid
        return xg, yg

    def data2grid(self, xd, yd):
        """Scale data-coords (relative to the grid origin) to grid-coords."""
        xg = xd * self.x_data2grid
        yg = yd * self.y_data2grid
        return xg, yg

    def grid2data(self, xg, yg):
        """Scale grid-coords back to data-coords (origin not added)."""
        xd = xg / self.x_data2grid
        yd = yg / self.y_data2grid
        return xd, yd

    def start_trajectory(self, xg, yg):
        """Begin recording a trajectory in the mask at grid point (xg, yg)."""
        self.mask._start_trajectory(*self.grid2mask(xg, yg))

    def reset_start_point(self, xg, yg):
        """Make the mask cell of (xg, yg) the mask's current position."""
        self.mask._current_xy = self.grid2mask(xg, yg)

    def update_trajectory(self, xg, yg):
        """
        Record grid point (xg, yg) in the mask.

        Raises `InvalidIndexError` if the point lies outside the grid.
        """
        if self.grid.within_grid(xg, yg):
            self.mask._update_trajectory(*self.grid2mask(xg, yg))
        else:
            raise InvalidIndexError

    def undo_trajectory(self):
        """Remove the trajectory currently being recorded from the mask."""
        self.mask._undo_trajectory()
|
|
|
|
|
|
class Grid:
    """
    Grid of data.

    Built from the *x* and *y* coordinate arrays, which may be 1D, or 2D
    with identical rows (*x*) / identical columns (*y*).  The instance
    exposes the point counts (``nx``, ``ny``), spacings (``dx``, ``dy``),
    origin (``x_origin``, ``y_origin``) and extent (``width``, ``height``).

    Raises ValueError if an array has more than two dimensions, if a 2D
    array is not a repeated 1D axis, or if the values are not equally
    spaced.
    """

    def __init__(self, x, y):
        # Collapse a 2D x (e.g. from meshgrid) to its first row.
        if x.ndim == 2:
            first_row = x[0, :]
            if not np.allclose(first_row, x):
                raise ValueError("The rows of 'x' must be equal")
            x = first_row
        elif x.ndim != 1:
            raise ValueError("'x' can have at maximum 2 dimensions")

        # Collapse a 2D y to its first column.
        if y.ndim == 2:
            first_col = y[:, 0]
            if not np.allclose(first_col, y.T):
                raise ValueError("The columns of 'y' must be equal")
            y = first_col
        elif y.ndim != 1:
            raise ValueError("'y' can have at maximum 2 dimensions")

        self.nx, self.ny = len(x), len(y)

        self.dx = x[1] - x[0]
        self.dy = y[1] - y[0]

        self.x_origin, self.y_origin = x[0], y[0]

        self.width = x[-1] - x[0]
        self.height = y[-1] - y[0]

        # Every step along an axis must equal the mean step.
        x_evenly_spaced = np.allclose(np.diff(x), self.width / (self.nx - 1))
        if not x_evenly_spaced:
            raise ValueError("'x' values must be equally spaced")
        y_evenly_spaced = np.allclose(np.diff(y), self.height / (self.ny - 1))
        if not y_evenly_spaced:
            raise ValueError("'y' values must be equally spaced")

    @property
    def shape(self):
        """Grid shape as ``(ny, nx)``, i.e. (rows, columns)."""
        return self.ny, self.nx

    def within_grid(self, xi, yi):
        """Return True if point is a valid index of grid."""
        # Note that xi/yi can be floats; so, for example, we can't simply check
        # `xi < self.nx` since *xi* can be `self.nx - 1 < xi < self.nx`
        x_inside = 0 <= xi <= self.nx - 1
        return x_inside and 0 <= yi <= self.ny - 1
|
|
|
|
|
|
class StreamMask:
    """
    Mask to keep track of discrete regions crossed by streamlines.

    The resolution of this grid determines the approximate spacing between
    trajectories. Streamlines are only allowed to pass through zeroed cells:
    When a streamline enters a cell, that cell is set to 1, and no new
    streamlines are allowed to enter.
    """

    def __init__(self, density):
        """
        Create a zeroed mask with ``30 * density`` cells per direction.

        *density* is a scalar or an ``(x, y)`` pair; anything that cannot be
        broadcast to length 2, or that is negative, raises ValueError.
        """
        try:
            cells = (30 * np.broadcast_to(density, 2)).astype(int)
            self.nx, self.ny = cells
        except ValueError as err:
            raise ValueError(
                "'density' must be a scalar or be of length 2") from err
        if min(self.nx, self.ny) < 0:
            raise ValueError("'density' must be positive")
        self._mask = np.zeros((self.ny, self.nx))
        self.shape = self._mask.shape

        # Mask cell most recently visited by the trajectory being recorded.
        self._current_xy = None

    def __getitem__(self, args):
        """Index straight into the underlying mask array."""
        return self._mask[args]

    def _start_trajectory(self, xm, ym):
        """Start recording streamline trajectory"""
        self._traj = []
        self._update_trajectory(xm, ym)

    def _undo_trajectory(self):
        """Remove current trajectory from mask"""
        for cell in self._traj:
            self._mask[cell] = 0

    def _update_trajectory(self, xm, ym):
        """
        Update current trajectory position in mask.

        If the new position has already been filled, raise `InvalidIndexError`.
        """
        here = (xm, ym)
        if self._current_xy == here:
            # Still inside the cell we already own; nothing to record.
            return
        if self[ym, xm] != 0:
            raise InvalidIndexError
        self._traj.append((ym, xm))
        self._mask[ym, xm] = 1
        self._current_xy = here
|
|
|
|
|
|
class InvalidIndexError(Exception):
    """Raised when a trajectory point is off the grid or in a filled cell."""
|
|
|
|
|
|
class TerminateTrajectory(Exception):
    """Raised to stop integrating a streamline (zero speed or masked data)."""
|
|
|
|
|
|
# Integrator definitions
# =======================
|
|
|
|
def get_integrator(u, v, dmap, minlength, maxlength, integration_direction):
    """
    Build the streamline integrator for the velocity field (*u*, *v*).

    The velocities are rescaled to grid coordinates once, up front.  The
    returned ``integrate(x0, y0)`` callable takes a seed point in grid
    coordinates and returns the trajectory as ``(x_traj, y_traj)`` lists,
    or None if the seed cell is unusable or the trajectory is not longer
    than *minlength*.  *maxlength* bounds each single integration run and
    *integration_direction* ('both', 'forward' or 'backward') selects
    which runs are performed.
    """
    # rescale velocity onto grid-coordinates for integrations.
    u, v = dmap.data2grid(u, v)

    # speed (path length) will be in axes-coordinates
    x_cells = dmap.grid.nx - 1
    y_cells = dmap.grid.ny - 1
    speed = np.ma.sqrt((u / x_cells) ** 2 + (v / y_cells) ** 2)

    def forward_time(xi, yi):
        # Direction at (xi, yi), scaled so that a unit step of the
        # integration variable is a unit of axes-coordinate path length.
        if not dmap.grid.within_grid(xi, yi):
            raise OutOfBounds
        ds_dt = interpgrid(speed, xi, yi)
        if ds_dt == 0:
            raise TerminateTrajectory()
        dt_ds = 1. / ds_dt
        ui = interpgrid(u, xi, yi)
        vi = interpgrid(v, xi, yi)
        return ui * dt_ds, vi * dt_ds

    def backward_time(xi, yi):
        # Same field, traversed against the flow.
        dxi, dyi = forward_time(xi, yi)
        return -dxi, -dyi

    def integrate(x0, y0):
        """
        Return x, y grid-coordinates of trajectory based on starting point.

        Integrate both forward and backward in time from starting point in
        grid coordinates.

        Integration is terminated when a trajectory reaches a domain boundary
        or when it crosses into an already occupied cell in the StreamMask. The
        resulting trajectory is None if it is shorter than `minlength`.
        """
        try:
            dmap.start_trajectory(x0, y0)
        except InvalidIndexError:
            return None

        length = 0.
        xs, ys = [], []

        if integration_direction in ('both', 'backward'):
            s, xt, yt = _integrate_rk12(x0, y0, dmap, backward_time, maxlength)
            length += s
            # The backward run is stored reversed so that it ends at the seed.
            xs += xt[::-1]
            ys += yt[::-1]

        if integration_direction in ('both', 'forward'):
            dmap.reset_start_point(x0, y0)
            s, xt, yt = _integrate_rk12(x0, y0, dmap, forward_time, maxlength)
            if xs:
                # The seed point is already present from the backward run.
                xt = xt[1:]
                yt = yt[1:]
            length += s
            xs += xt
            ys += yt

        if length > minlength:
            return xs, ys

        # reject short trajectories
        dmap.undo_trajectory()
        return None

    return integrate
|
|
|
|
|
|
class OutOfBounds(IndexError):
    """Raised when the integrator evaluates a point outside the grid."""
|
|
|
|
|
|
def _integrate_rk12(x0, y0, dmap, f, maxlength):
    """
    2nd-order Runge-Kutta algorithm with adaptive step size.

    This method is also referred to as the improved Euler's method, or Heun's
    method. This method is favored over higher-order methods because:

    1. To get decent looking trajectories and to sample every mask cell
       on the trajectory we need a small timestep, so a lower order
       solver doesn't hurt us unless the data is *very* high resolution.
       In fact, for cases where the user inputs
       data smaller or of similar grid size to the mask grid, the higher
       order corrections are negligible because of the very fast linear
       interpolation used in `interpgrid`.

    2. For high resolution input data (i.e. beyond the mask
       resolution), we must reduce the timestep. Therefore, an adaptive
       timestep is more suited to the problem as this would be very hard
       to judge automatically otherwise.

    Parameters
    ----------
    x0, y0 : number
        Starting point in grid coordinates.
    dmap : DomainMap
        Supplies the grid (bounds, shape) and the mask resolution, and
        records every accepted point via ``update_trajectory``.
    f : callable
        ``f(xi, yi)`` returns the step direction ``(dx, dy)`` in grid
        coordinates. It may raise `OutOfBounds` or `TerminateTrajectory`.
    maxlength : number
        Integration stops before the accumulated step length exceeds this.

    Returns
    -------
    stotal : number
        Accumulated length of the accepted steps.
    xf_traj, yf_traj : list
        Grid coordinates of the trajectory points.
    """
    # This error is below that needed to match the RK4 integrator. It
    # is set for visual reasons -- too low and corners start
    # appearing ugly and jagged. Can be tuned.
    maxerror = 0.003

    # This limit is important (for all integrators) to avoid the
    # trajectory skipping some mask cells.
    maxds = min(1. / dmap.mask.nx, 1. / dmap.mask.ny, 0.1)

    # Grid.shape is (ny, nx) -- rows first.  Unpacking it in the other
    # order would normalise the x error by the y extent (and vice versa)
    # on non-square grids.  The shape is constant, so read it only once.
    ny, nx = dmap.grid.shape

    ds = maxds
    stotal = 0
    xi = x0
    yi = y0
    xf_traj = []
    yf_traj = []

    while True:
        try:
            # NOTE: after a rejected step (error too large) the unchanged
            # point is appended again on the next pass.
            if dmap.grid.within_grid(xi, yi):
                xf_traj.append(xi)
                yf_traj.append(yi)
            else:
                raise OutOfBounds

            # Compute the two intermediate gradients.
            # f should raise OutOfBounds if the locations given are
            # outside the grid.
            k1x, k1y = f(xi, yi)
            k2x, k2y = f(xi + ds * k1x, yi + ds * k1y)

        except OutOfBounds:
            # Out of the domain during this step.
            # Take an Euler step to the boundary to improve neatness
            # unless the trajectory is currently empty.
            if xf_traj:
                ds, xf_traj, yf_traj = _euler_step(xf_traj, yf_traj,
                                                   dmap, f)
                stotal += ds
            break
        except TerminateTrajectory:
            break

        # First-order (Euler) and second-order (Heun) displacement estimates.
        dx1 = ds * k1x
        dy1 = ds * k1y
        dx2 = ds * 0.5 * (k1x + k2x)
        dy2 = ds * 0.5 * (k1y + k2y)

        # Error is normalized to the axes coordinates
        error = np.hypot((dx2 - dx1) / (nx - 1), (dy2 - dy1) / (ny - 1))

        # Only save step if within error tolerance
        if error < maxerror:
            xi += dx2
            yi += dy2
            try:
                dmap.update_trajectory(xi, yi)
            except InvalidIndexError:
                break
            if stotal + ds > maxlength:
                break
            stotal += ds

        # recalculate stepsize based on step error
        if error == 0:
            ds = maxds
        else:
            ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5)

    return stotal, xf_traj, yf_traj
|
|
|
|
|
|
def _euler_step(xf_traj, yf_traj, dmap, f):
    """
    Simple Euler integration step that extends streamline to boundary.

    Starting from the last point of the trajectory, step along the
    direction ``f`` reports there until the first grid edge is reached,
    append that point to both lists (in place) and return
    ``(ds, xf_traj, yf_traj)`` where *ds* is the step length taken.
    """
    ny, nx = dmap.grid.shape
    x_last, y_last = xf_traj[-1], yf_traj[-1]
    cx, cy = f(x_last, y_last)

    def _step_to_edge(pos, vel, upper):
        # Step size needed to reach 0 (moving down) or *upper* (moving up)
        # along one axis; infinite when there is no motion on that axis.
        if vel == 0:
            return np.inf
        if vel < 0:
            return pos / -vel
        return (upper - pos) / vel

    dsx = _step_to_edge(x_last, cx, nx - 1)
    dsy = _step_to_edge(y_last, cy, ny - 1)
    # Whichever edge is hit first ends the step.
    ds = min(dsx, dsy)

    xf_traj.append(x_last + cx * ds)
    yf_traj.append(y_last + cy * ds)
    return ds, xf_traj, yf_traj
|
|
|
|
|
|
# Utility functions
# ========================
|
|
|
|
def interpgrid(a, xi, yi):
    """
    Fast 2D, linear interpolation on an integer grid

    *a* is the 2D array to sample; *xi*, *yi* are fractional column / row
    indices, either scalars or numpy arrays.  For scalar input a masked
    result raises `TerminateTrajectory`; array input is returned as-is.
    """
    Ny, Nx = np.shape(a)
    vectorized = isinstance(xi, np.ndarray)

    if vectorized:
        x = xi.astype(int)
        y = yi.astype(int)
        # Check that xn, yn don't exceed max index
        xn = np.clip(x + 1, 0, Nx - 1)
        yn = np.clip(y + 1, 0, Ny - 1)
    else:
        x = int(xi)
        y = int(yi)
        # conditional is faster than clipping for integers
        xn = x if x == (Nx - 1) else x + 1
        yn = y if y == (Ny - 1) else y + 1

    # Values at the four corners of the cell containing (xi, yi).
    a00 = a[y, x]
    a01 = a[y, xn]
    a10 = a[yn, x]
    a11 = a[yn, xn]

    # Fractional position inside the cell.
    xt = xi - x
    yt = yi - y

    # Interpolate along x on both rows, then along y between the rows.
    a0 = a00 * (1 - xt) + a01 * xt
    a1 = a10 * (1 - xt) + a11 * xt
    ai = a0 * (1 - yt) + a1 * yt

    if not vectorized and np.ma.is_masked(ai):
        raise TerminateTrajectory

    return ai
|
|
|
|
|
|
def _gen_starting_points(shape):
    """
    Yield starting points for streamlines.

    Trying points on the boundary first gives higher quality streamlines.
    This algorithm starts with a point on the mask corner and spirals inward.
    This algorithm is inefficient, but fast compared to rest of streamplot.

    *shape* is ``(ny, nx)``; exactly ``nx * ny`` ``(x, y)`` pairs are
    yielded.
    """
    ny, nx = shape

    # Remaining bounds of the spiral, indexed by axis (0 -> x, 1 -> y).
    # The lower y bound starts at 1 because row 0 is swept by the very
    # first pass to the right.
    lower = [0, 1]
    upper = [nx - 1, ny - 1]
    pos = [0, 0]

    # Headings in visiting order: 0 = right, 1 = up, 2 = left, 3 = down.
    heading = 0
    for _ in range(nx * ny):
        yield pos[0], pos[1]

        axis = heading % 2
        if heading < 2:
            # Moving towards larger coordinates (right / up).
            pos[axis] += 1
            if pos[axis] >= upper[axis]:
                upper[axis] -= 1
                heading += 1
        else:
            # Moving towards smaller coordinates (left / down).
            pos[axis] -= 1
            if pos[axis] <= lower[axis]:
                lower[axis] += 1
                heading = (heading + 1) % 4
|