added typehints for util module
This commit is contained in:
parent
b796bc0697
commit
adf7abbda6
|
@ -8,6 +8,8 @@ import shlex
|
|||
import re
|
||||
import fractions
|
||||
from functools import reduce
|
||||
from typing import Union, Tuple, Sequence, Callable, Dict, List, Any, Literal
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
import h5py
|
||||
|
@ -50,7 +52,7 @@ _colors = {
|
|||
####################################################################################################
|
||||
# Functions
|
||||
####################################################################################################
|
||||
def srepr(arg,glue = '\n'):
|
||||
def srepr(arg: Union[np.ndarray, Sequence[Any]], glue: str = '\n') -> str:
|
||||
r"""
|
||||
Join items with glue string.
|
||||
|
||||
|
@ -75,7 +77,7 @@ def srepr(arg,glue = '\n'):
|
|||
return arg if isinstance(arg,str) else repr(arg)
|
||||
|
||||
|
||||
def emph(what):
|
||||
def emph(what: Any) -> str:
|
||||
"""
|
||||
Format with emphasis.
|
||||
|
||||
|
@ -92,7 +94,7 @@ def emph(what):
|
|||
"""
|
||||
return _colors['bold']+srepr(what)+_colors['end_color']
|
||||
|
||||
def deemph(what):
|
||||
def deemph(what: Any) -> str:
|
||||
"""
|
||||
Format with deemphasis.
|
||||
|
||||
|
@ -109,7 +111,7 @@ def deemph(what):
|
|||
"""
|
||||
return _colors['dim']+srepr(what)+_colors['end_color']
|
||||
|
||||
def warn(what):
|
||||
def warn(what: Any) -> str:
|
||||
"""
|
||||
Format for warning.
|
||||
|
||||
|
@ -126,7 +128,7 @@ def warn(what):
|
|||
"""
|
||||
return _colors['warning']+emph(what)+_colors['end_color']
|
||||
|
||||
def strikeout(what):
|
||||
def strikeout(what: Any) -> str:
|
||||
"""
|
||||
Format as strikeout.
|
||||
|
||||
|
@ -144,7 +146,7 @@ def strikeout(what):
|
|||
return _colors['crossout']+srepr(what)+_colors['end_color']
|
||||
|
||||
|
||||
def run(cmd,wd='./',env=None,timeout=None):
|
||||
def run(cmd: str, wd: str = './', env: Dict[str, Any] = None, timeout: int = None) -> Tuple[str, str]:
|
||||
"""
|
||||
Run a command.
|
||||
|
||||
|
@ -185,7 +187,7 @@ def run(cmd,wd='./',env=None,timeout=None):
|
|||
execute = run
|
||||
|
||||
|
||||
def natural_sort(key):
|
||||
def natural_sort(key: str) -> List[Union[int, str]]:
|
||||
"""
|
||||
Natural sort.
|
||||
|
||||
|
@ -200,7 +202,10 @@ def natural_sort(key):
|
|||
return [ convert(c) for c in re.split('([0-9]+)', key) ]
|
||||
|
||||
|
||||
def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
|
||||
def show_progress(iterable: Sequence[Any],
|
||||
N_iter: int = None,
|
||||
prefix: str = '',
|
||||
bar_length: int = 50) -> Any:
|
||||
"""
|
||||
Decorate a loop with a progress bar.
|
||||
|
||||
|
@ -229,7 +234,7 @@ def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
|
|||
status.update(i)
|
||||
|
||||
|
||||
def scale_to_coprime(v):
|
||||
def scale_to_coprime(v: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Scale vector to co-prime (relatively prime) integers.
|
||||
|
||||
|
@ -267,13 +272,13 @@ def scale_to_coprime(v):
|
|||
return m
|
||||
|
||||
|
||||
def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
|
||||
def project_equal_angle(vector: np.ndarray, direction: str = 'z', normalize: bool = True, keepdims: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Apply equal-angle projection to vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vector : numpy.ndarray, shape (...,3)
|
||||
vector : numpy.ndarray of shape (...,3)
|
||||
Vector coordinates to be projected.
|
||||
direction : str
|
||||
Projection direction 'x', 'y', or 'z'.
|
||||
|
@ -309,7 +314,10 @@ def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
|
|||
return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
|
||||
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
|
||||
|
||||
def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
|
||||
def project_equal_area(vector: np.ndarray,
|
||||
direction: Literal['x', 'y', 'z'] = 'z',
|
||||
normalize: bool = True,
|
||||
keepdims: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Apply equal-area projection to vector.
|
||||
|
||||
|
@ -351,15 +359,14 @@ def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
|
|||
return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
|
||||
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
|
||||
|
||||
|
||||
def execution_stamp(class_name,function_name=None):
|
||||
def execution_stamp(class_name: str, function_name: str = None) -> str:
|
||||
"""Timestamp the execution of a (function within a) class."""
|
||||
now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z')
|
||||
_function_name = '' if function_name is None else f'.{function_name}'
|
||||
return f'damask.{class_name}{_function_name} v{version} ({now})'
|
||||
|
||||
|
||||
def hybrid_IA(dist,N,rng_seed=None):
|
||||
def hybrid_IA(dist: np.ndarray, N: int, rng_seed: Union[int, np.ndarray] = None) -> np.ndarray:
|
||||
"""
|
||||
Hybrid integer approximation.
|
||||
|
||||
|
@ -387,7 +394,10 @@ def hybrid_IA(dist,N,rng_seed=None):
|
|||
return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]]
|
||||
|
||||
|
||||
def shapeshifter(fro,to,mode='left',keep_ones=False):
|
||||
def shapeshifter(fro: Tuple[int, ...],
|
||||
to: Tuple[int, ...],
|
||||
mode: Literal['left','right'] = 'left',
|
||||
keep_ones: bool = False) -> Tuple[int, ...]:
|
||||
"""
|
||||
Return dimensions that reshape 'fro' to become broadcastable to 'to'.
|
||||
|
||||
|
@ -434,21 +444,22 @@ def shapeshifter(fro,to,mode='left',keep_ones=False):
|
|||
fro = (1,) if not len(fro) else fro
|
||||
to = (1,) if not len(to) else to
|
||||
try:
|
||||
grp = re.match(beg[mode]
|
||||
match = re.match((beg[mode]
|
||||
+f',{sep[mode]}'.join(map(lambda x: f'{x}'
|
||||
if x>1 or (keep_ones and len(fro)>1) else
|
||||
'\\d+',fro))
|
||||
+f',{end[mode]}',
|
||||
','.join(map(str,to))+',').groups()
|
||||
except AttributeError:
|
||||
+f',{end[mode]}'),','.join(map(str,to))+',')
|
||||
assert match
|
||||
except AssertionError:
|
||||
raise ValueError(f'Shapes can not be shifted {fro} --> {to}')
|
||||
fill = ()
|
||||
grp: Sequence[str] = match.groups()
|
||||
fill: Tuple[int, ...] = ()
|
||||
for g,d in zip(grp,fro+(None,)):
|
||||
fill += (1,)*g.count(',')+(d,)
|
||||
return fill[:-1]
|
||||
|
||||
|
||||
def shapeblender(a,b):
|
||||
def shapeblender(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
"""
|
||||
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
|
||||
|
||||
|
@ -476,7 +487,7 @@ def shapeblender(a,b):
|
|||
return a + b[i:]
|
||||
|
||||
|
||||
def extend_docstring(extra_docstring):
|
||||
def extend_docstring(extra_docstring: str) -> Callable:
|
||||
"""
|
||||
Decorator: Append to function's docstring.
|
||||
|
||||
|
@ -492,7 +503,7 @@ def extend_docstring(extra_docstring):
|
|||
return _decorator
|
||||
|
||||
|
||||
def extended_docstring(f,extra_docstring):
|
||||
def extended_docstring(f: Callable, extra_docstring: str) -> Callable:
|
||||
"""
|
||||
Decorator: Combine another function's docstring with a given docstring.
|
||||
|
||||
|
@ -510,7 +521,7 @@ def extended_docstring(f,extra_docstring):
|
|||
return _decorator
|
||||
|
||||
|
||||
def DREAM3D_base_group(fname):
|
||||
def DREAM3D_base_group(fname: Union[str, pathlib.Path]) -> str:
|
||||
"""
|
||||
Determine the base group of a DREAM.3D file.
|
||||
|
||||
|
@ -536,7 +547,7 @@ def DREAM3D_base_group(fname):
|
|||
|
||||
return base_group
|
||||
|
||||
def DREAM3D_cell_data_group(fname):
|
||||
def DREAM3D_cell_data_group(fname: Union[str, pathlib.Path]) -> str:
|
||||
"""
|
||||
Determine the cell data group of a DREAM.3D file.
|
||||
|
||||
|
@ -568,7 +579,7 @@ def DREAM3D_cell_data_group(fname):
|
|||
return cell_data_group
|
||||
|
||||
|
||||
def Bravais_to_Miller(*,uvtw=None,hkil=None):
|
||||
def Bravais_to_Miller(*, uvtw: np.ndarray = None, hkil: np.ndarray = None) -> np.ndarray:
|
||||
"""
|
||||
Transform 4 Miller–Bravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl).
|
||||
|
||||
|
@ -595,7 +606,7 @@ def Bravais_to_Miller(*,uvtw=None,hkil=None):
|
|||
return np.einsum('il,...l',basis,axis)
|
||||
|
||||
|
||||
def Miller_to_Bravais(*,uvw=None,hkl=None):
|
||||
def Miller_to_Bravais(*, uvw: np.ndarray = None, hkl: np.ndarray = None) -> np.ndarray:
|
||||
"""
|
||||
Transform 3 Miller indices to 4 Miller–Bravais indices of crystal direction [uvtw] or plane normal (hkil).
|
||||
|
||||
|
@ -624,7 +635,7 @@ def Miller_to_Bravais(*,uvw=None,hkl=None):
|
|||
return np.einsum('il,...l',basis,axis)
|
||||
|
||||
|
||||
def dict_prune(d):
|
||||
def dict_prune(d: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
"""
|
||||
Recursively remove empty dictionaries.
|
||||
|
||||
|
@ -650,7 +661,7 @@ def dict_prune(d):
|
|||
return new
|
||||
|
||||
|
||||
def dict_flatten(d):
|
||||
def dict_flatten(d: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
"""
|
||||
Recursively remove keys of single-entry dictionaries.
|
||||
|
||||
|
@ -685,7 +696,7 @@ class _ProgressBar:
|
|||
Works for 0-based loops, ETA is estimated by linear extrapolation.
|
||||
"""
|
||||
|
||||
def __init__(self,total,prefix,bar_length):
|
||||
def __init__(self, total: int, prefix: str, bar_length: int):
|
||||
"""
|
||||
Set current time as basis for ETA estimation.
|
||||
|
||||
|
@ -708,7 +719,7 @@ class _ProgressBar:
|
|||
sys.stderr.write(f"{self.prefix} {'░'*self.bar_length} 0% ETA n/a")
|
||||
sys.stderr.flush()
|
||||
|
||||
def update(self,iteration):
|
||||
def update(self, iteration: int) -> None:
|
||||
|
||||
fraction = (iteration+1) / self.total
|
||||
filled_length = int(self.bar_length * fraction)
|
||||
|
|
Loading…
Reference in New Issue