diff --git a/python/damask/util.py b/python/damask/util.py index 0581302db..a28a9e2eb 100644 --- a/python/damask/util.py +++ b/python/damask/util.py @@ -8,6 +8,8 @@ import shlex import re import fractions from functools import reduce +from typing import Union, Tuple, Sequence, Callable, Dict, List, Any, Literal +import pathlib import numpy as np import h5py @@ -50,7 +52,7 @@ _colors = { #################################################################################################### # Functions #################################################################################################### -def srepr(arg,glue = '\n'): +def srepr(arg: Union[np.ndarray, Sequence[Any]], glue: str = '\n') -> str: r""" Join items with glue string. @@ -75,7 +77,7 @@ def srepr(arg,glue = '\n'): return arg if isinstance(arg,str) else repr(arg) -def emph(what): +def emph(what: Any) -> str: """ Format with emphasis. @@ -92,7 +94,7 @@ def emph(what): """ return _colors['bold']+srepr(what)+_colors['end_color'] -def deemph(what): +def deemph(what: Any) -> str: """ Format with deemphasis. @@ -109,7 +111,7 @@ def deemph(what): """ return _colors['dim']+srepr(what)+_colors['end_color'] -def warn(what): +def warn(what: Any) -> str: """ Format for warning. @@ -126,7 +128,7 @@ def warn(what): """ return _colors['warning']+emph(what)+_colors['end_color'] -def strikeout(what): +def strikeout(what: Any) -> str: """ Format as strikeout. @@ -144,7 +146,7 @@ def strikeout(what): return _colors['crossout']+srepr(what)+_colors['end_color'] -def run(cmd,wd='./',env=None,timeout=None): +def run(cmd: str, wd: str = './', env: Dict[str, Any] = None, timeout: int = None) -> Tuple[str, str]: """ Run a command. @@ -185,7 +187,7 @@ def run(cmd,wd='./',env=None,timeout=None): execute = run -def natural_sort(key): +def natural_sort(key: str) -> List[Union[int, str]]: """ Natural sort. @@ -200,7 +202,10 @@ def natural_sort(key): return [ convert(c) for c in re.split('([0-9]+)', key) ] -def show_progress(iterable,N_iter=None,prefix='',bar_length=50): +def show_progress(iterable: Sequence[Any], + N_iter: int = None, + prefix: str = '', + bar_length: int = 50) -> Any: """ Decorate a loop with a progress bar. @@ -229,7 +234,7 @@ def show_progress(iterable,N_iter=None,prefix='',bar_length=50): status.update(i) -def scale_to_coprime(v): +def scale_to_coprime(v: np.ndarray) -> np.ndarray: """ Scale vector to co-prime (relatively prime) integers. @@ -267,13 +272,13 @@ def scale_to_coprime(v): return m -def project_equal_angle(vector,direction='z',normalize=True,keepdims=False): +def project_equal_angle(vector: np.ndarray, direction: str = 'z', normalize: bool = True, keepdims: bool = False) -> np.ndarray: """ Apply equal-angle projection to vector. Parameters ---------- - vector : numpy.ndarray, shape (...,3) + vector : numpy.ndarray of shape (...,3) Vector coordinates to be projected. direction : str Projection direction 'x', 'y', or 'z'. @@ -309,7 +314,10 @@ def project_equal_angle(vector,direction='z',normalize=True,keepdims=False): return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]), -shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2] -def project_equal_area(vector,direction='z',normalize=True,keepdims=False): +def project_equal_area(vector: np.ndarray, + direction: Literal['x', 'y', 'z'] = 'z', + normalize: bool = True, + keepdims: bool = False) -> np.ndarray: """ Apply equal-area projection to vector. @@ -351,15 +359,14 @@ def project_equal_area(vector,direction='z',normalize=True,keepdims=False): return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]), -shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2] - -def execution_stamp(class_name,function_name=None): +def execution_stamp(class_name: str, function_name: str = None) -> str: """Timestamp the execution of a (function within a) class.""" now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z') _function_name = '' if function_name is None else f'.{function_name}' return f'damask.{class_name}{_function_name} v{version} ({now})' -def hybrid_IA(dist,N,rng_seed=None): +def hybrid_IA(dist: np.ndarray, N: int, rng_seed: Union[int, np.ndarray] = None) -> np.ndarray: """ Hybrid integer approximation. @@ -387,7 +394,10 @@ def hybrid_IA(dist,N,rng_seed=None): return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]] -def shapeshifter(fro,to,mode='left',keep_ones=False): +def shapeshifter(fro: Tuple[int, ...], + to: Tuple[int, ...], + mode: Literal['left','right'] = 'left', + keep_ones: bool = False) -> Tuple[int, ...]: """ Return dimensions that reshape 'fro' to become broadcastable to 'to'. @@ -434,21 +444,22 @@ def shapeshifter(fro,to,mode='left',keep_ones=False): fro = (1,) if not len(fro) else fro to = (1,) if not len(to) else to try: - grp = re.match(beg[mode] + match = re.match((beg[mode] +f',{sep[mode]}'.join(map(lambda x: f'{x}' if x>1 or (keep_ones and len(fro)>1) else '\\d+',fro)) - +f',{end[mode]}', - ','.join(map(str,to))+',').groups() - except AttributeError: + +f',{end[mode]}'),','.join(map(str,to))+',') + assert match + except AssertionError: raise ValueError(f'Shapes can not be shifted {fro} --> {to}') - fill = () + grp: Sequence[str] = match.groups() + fill: Tuple[int, ...] = () for g,d in zip(grp,fro+(None,)): fill += (1,)*g.count(',')+(d,) return fill[:-1] -def shapeblender(a,b): +def shapeblender(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]: """ Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'. @@ -476,7 +487,7 @@ def shapeblender(a,b): return a + b[i:] -def extend_docstring(extra_docstring): +def extend_docstring(extra_docstring: str) -> Callable: """ Decorator: Append to function's docstring. @@ -492,7 +503,7 @@ def extend_docstring(extra_docstring): return _decorator -def extended_docstring(f,extra_docstring): +def extended_docstring(f: Callable, extra_docstring: str) -> Callable: """ Decorator: Combine another function's docstring with a given docstring. @@ -510,7 +521,7 @@ def extended_docstring(f,extra_docstring): return _decorator -def DREAM3D_base_group(fname): +def DREAM3D_base_group(fname: Union[str, pathlib.Path]) -> str: """ Determine the base group of a DREAM.3D file. @@ -536,7 +547,7 @@ def DREAM3D_base_group(fname): return base_group -def DREAM3D_cell_data_group(fname): +def DREAM3D_cell_data_group(fname: Union[str, pathlib.Path]) -> str: """ Determine the cell data group of a DREAM.3D file. @@ -568,7 +579,7 @@ def DREAM3D_cell_data_group(fname): return cell_data_group -def Bravais_to_Miller(*,uvtw=None,hkil=None): +def Bravais_to_Miller(*, uvtw: np.ndarray = None, hkil: np.ndarray = None) -> np.ndarray: """ Transform 4 Miller–Bravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl). @@ -595,7 +606,7 @@ def Bravais_to_Miller(*,uvtw=None,hkil=None): return np.einsum('il,...l',basis,axis) -def Miller_to_Bravais(*,uvw=None,hkl=None): +def Miller_to_Bravais(*, uvw: np.ndarray = None, hkl: np.ndarray = None) -> np.ndarray: """ Transform 3 Miller indices to 4 Miller–Bravais indices of crystal direction [uvtw] or plane normal (hkil). @@ -624,7 +635,7 @@ def Miller_to_Bravais(*,uvw=None,hkl=None): return np.einsum('il,...l',basis,axis) -def dict_prune(d): +def dict_prune(d: Dict[Any, Any]) -> Dict[Any, Any]: """ Recursively remove empty dictionaries. @@ -650,7 +661,7 @@ def dict_prune(d): return new -def dict_flatten(d): +def dict_flatten(d: Dict[Any, Any]) -> Dict[Any, Any]: """ Recursively remove keys of single-entry dictionaries. @@ -685,7 +696,7 @@ class _ProgressBar: Works for 0-based loops, ETA is estimated by linear extrapolation. """ - def __init__(self,total,prefix,bar_length): + def __init__(self, total: int, prefix: str, bar_length: int): """ Set current time as basis for ETA estimation. @@ -708,7 +719,7 @@ class _ProgressBar: sys.stderr.write(f"{self.prefix} {'░'*self.bar_length} 0% ETA n/a") sys.stderr.flush() - def update(self,iteration): + def update(self, iteration: int) -> None: fraction = (iteration+1) / self.total filled_length = int(self.bar_length * fraction)