fixed type hinting for seeds.py and grid_filters.py

This commit is contained in:
Philip Eisenlohr 2021-11-02 14:28:54 -04:00
parent 8636c5dad4
commit ccfe276ae1
2 changed files with 34 additions and 26 deletions

View File

@ -12,9 +12,11 @@ the following operations are required for tensorial data:
"""
from scipy import spatial as _spatial
from typing import Sequence, Tuple, Union
import numpy as _np
def _ks(size,cells,first_order=False):
def _ks(size: _np.ndarray, cells: Union[_np.ndarray,Sequence[int]], first_order: bool = False) -> _np.ndarray:
"""
Get wave numbers operator.
@ -41,7 +43,7 @@ def _ks(size,cells,first_order=False):
return _np.stack(_np.meshgrid(k_sk,k_sj,k_si,indexing = 'ij'), axis=-1)
def curl(size,f):
def curl(size: _np.ndarray, f: _np.ndarray) -> _np.ndarray:
u"""
Calculate curl of a vector or tensor field in Fourier space.
@ -72,7 +74,7 @@ def curl(size,f):
return _np.fft.irfftn(curl_,axes=(0,1,2),s=f.shape[:3])
def divergence(size,f):
def divergence(size: _np.ndarray, f: _np.ndarray) -> _np.ndarray:
u"""
Calculate divergence of a vector or tensor field in Fourier space.
@ -99,7 +101,7 @@ def divergence(size,f):
return _np.fft.irfftn(div_,axes=(0,1,2),s=f.shape[:3])
def gradient(size,f):
def gradient(size: _np.ndarray, f: _np.ndarray) -> _np.ndarray:
u"""
Calculate gradient of a scalar or vector fieldin Fourier space.
@ -126,7 +128,9 @@ def gradient(size,f):
return _np.fft.irfftn(grad_,axes=(0,1,2),s=f.shape[:3])
def coordinates0_point(cells,size,origin=_np.zeros(3)):
def coordinates0_point(cells: Union[ _np.ndarray,Sequence[int]],
size: _np.ndarray,
origin: _np.ndarray = _np.zeros(3)) -> _np.ndarray:
"""
Cell center positions (undeformed).
@ -145,8 +149,8 @@ def coordinates0_point(cells,size,origin=_np.zeros(3)):
Undeformed cell center coordinates.
"""
start = origin + size/cells*.5
end = origin + size - size/cells*.5
start = origin + size/_np.array(cells)*.5
end = origin + size - size/_np.array(cells)*.5
return _np.stack(_np.meshgrid(_np.linspace(start[0],end[0],cells[0]),
_np.linspace(start[1],end[1],cells[1]),
@ -154,7 +158,7 @@ def coordinates0_point(cells,size,origin=_np.zeros(3)):
axis = -1)
def displacement_fluct_point(size,F):
def displacement_fluct_point(size: _np.ndarray, F: _np.ndarray) -> _np.ndarray:
"""
Cell center displacement field from fluctuation part of the deformation gradient field.
@ -186,7 +190,7 @@ def displacement_fluct_point(size,F):
return _np.fft.irfftn(displacement,axes=(0,1,2),s=F.shape[:3])
def displacement_avg_point(size,F):
def displacement_avg_point(size: _np.ndarray, F: _np.ndarray) -> _np.ndarray:
"""
Cell center displacement field from average part of the deformation gradient field.
@ -207,7 +211,7 @@ def displacement_avg_point(size,F):
return _np.einsum('ml,ijkl->ijkm',F_avg - _np.eye(3),coordinates0_point(F.shape[:3],size))
def displacement_point(size,F):
def displacement_point(size: _np.ndarray, F: _np.ndarray) -> _np.ndarray:
"""
Cell center displacement field from deformation gradient field.
@ -227,7 +231,7 @@ def displacement_point(size,F):
return displacement_avg_point(size,F) + displacement_fluct_point(size,F)
def coordinates_point(size,F,origin=_np.zeros(3)):
def coordinates_point(size: _np.ndarray, F: _np.ndarray, origin: _np.ndarray = _np.zeros(3)) -> _np.ndarray:
"""
Cell center positions.
@ -249,7 +253,8 @@ def coordinates_point(size,F,origin=_np.zeros(3)):
return coordinates0_point(F.shape[:3],size,origin) + displacement_point(size,F)
def cellsSizeOrigin_coordinates0_point(coordinates0,ordered=True):
def cellsSizeOrigin_coordinates0_point(coordinates0: _np.ndarray,
ordered: bool = True) -> Tuple[_np.ndarray,_np.ndarray,_np.ndarray]:
"""
Return grid 'DNA', i.e. cells, size, and origin from 1D array of point positions.
@ -292,13 +297,15 @@ def cellsSizeOrigin_coordinates0_point(coordinates0,ordered=True):
raise ValueError('Regular cell spacing violated.')
if ordered and not _np.allclose(coordinates0.reshape(tuple(cells)+(3,),order='F'),
coordinates0_point(cells,size,origin),atol=atol):
coordinates0_point(list(cells),size,origin),atol=atol):
raise ValueError('Input data is not ordered (x fast, z slow).')
return (cells,size,origin)
def coordinates0_node(cells,size,origin=_np.zeros(3)):
def coordinates0_node(cells: Union[_np.ndarray,Sequence[int]],
size: _np.ndarray,
origin: _np.ndarray = _np.zeros(3)) -> _np.ndarray:
"""
Nodal positions (undeformed).
@ -323,7 +330,7 @@ def coordinates0_node(cells,size,origin=_np.zeros(3)):
axis = -1)
def displacement_fluct_node(size,F):
def displacement_fluct_node(size: _np.ndarray, F: _np.ndarray) -> _np.ndarray:
"""
Nodal displacement field from fluctuation part of the deformation gradient field.
@ -343,7 +350,7 @@ def displacement_fluct_node(size,F):
return point_to_node(displacement_fluct_point(size,F))
def displacement_avg_node(size,F):
def displacement_avg_node(size: _np.ndarray, F: _np.ndarray) -> _np.ndarray:
"""
Nodal displacement field from average part of the deformation gradient field.
@ -364,7 +371,7 @@ def displacement_avg_node(size,F):
return _np.einsum('ml,ijkl->ijkm',F_avg - _np.eye(3),coordinates0_node(F.shape[:3],size))
def displacement_node(size,F):
def displacement_node(size: _np.ndarray, F: _np.ndarray) -> _np.ndarray:
"""
Nodal displacement field from deformation gradient field.
@ -384,7 +391,7 @@ def displacement_node(size,F):
return displacement_avg_node(size,F) + displacement_fluct_node(size,F)
def coordinates_node(size,F,origin=_np.zeros(3)):
def coordinates_node(size: _np.ndarray, F: _np.ndarray, origin: _np.ndarray = _np.zeros(3)) -> _np.ndarray:
"""
Nodal positions.
@ -406,7 +413,8 @@ def coordinates_node(size,F,origin=_np.zeros(3)):
return coordinates0_node(F.shape[:3],size,origin) + displacement_node(size,F)
def cellsSizeOrigin_coordinates0_node(coordinates0,ordered=True):
def cellsSizeOrigin_coordinates0_node(coordinates0: _np.ndarray,
ordered: bool = True) -> Tuple[_np.ndarray,_np.ndarray,_np.ndarray]:
"""
Return grid 'DNA', i.e. cells, size, and origin from 1D array of nodal positions.
@ -441,13 +449,13 @@ def cellsSizeOrigin_coordinates0_node(coordinates0,ordered=True):
raise ValueError('Regular cell spacing violated.')
if ordered and not _np.allclose(coordinates0.reshape(tuple(cells+1)+(3,),order='F'),
coordinates0_node(cells,size,origin),atol=atol):
coordinates0_node(list(cells),size,origin),atol=atol):
raise ValueError('Input data is not ordered (x fast, z slow).')
return (cells,size,origin)
def point_to_node(cell_data):
def point_to_node(cell_data: _np.ndarray) -> _np.ndarray:
"""
Interpolate periodic point data to nodal data.
@ -469,7 +477,7 @@ def point_to_node(cell_data):
return _np.pad(n,((0,1),(0,1),(0,1))+((0,0),)*len(cell_data.shape[3:]),mode='wrap')
def node_to_point(node_data):
def node_to_point(node_data: _np.ndarray) -> _np.ndarray:
"""
Interpolate periodic nodal data to point data.
@ -491,7 +499,7 @@ def node_to_point(node_data):
return c[1:,1:,1:]
def coordinates0_valid(coordinates0):
def coordinates0_valid(coordinates0: _np.ndarray) -> bool:
"""
Check whether coordinates form a regular grid.
@ -513,7 +521,7 @@ def coordinates0_valid(coordinates0):
return False
def regrid(size,F,cells):
def regrid(size: _np.ndarray, F: _np.ndarray, cells: Union[_np.ndarray,Sequence[int]]) -> _np.ndarray:
"""
Return mapping from coordinates in deformed configuration to a regular grid.

View File

@ -1,7 +1,7 @@
"""Functionality for generation of seed points for Voronoi or Laguerre tessellation."""
from scipy import spatial as _spatial
from typing import Sequence
from typing import Sequence,Tuple
import numpy as _np
@ -97,7 +97,7 @@ def from_Poisson_disc(size: _np.ndarray, N_seeds: int, N_candidates: int, distan
def from_grid(grid, selection: Sequence[int] = None,
invert: bool = False, average: bool = False, periodic: bool = True) -> tuple[_np.ndarray, _np.ndarray]:
invert: bool = False, average: bool = False, periodic: bool = True) -> Tuple[_np.ndarray, _np.ndarray]:
"""
Create seeds from grid description.