added typehints for util module

This commit is contained in:
Daniel Otto de Mentock 2022-01-17 14:58:08 +01:00
parent b796bc0697
commit adf7abbda6
1 changed files with 43 additions and 32 deletions

View File

@ -8,6 +8,8 @@ import shlex
import re
import fractions
from functools import reduce
from typing import Union, Tuple, Sequence, Callable, Dict, List, Any, Literal
import pathlib
import numpy as np
import h5py
@ -50,7 +52,7 @@ _colors = {
####################################################################################################
# Functions
####################################################################################################
def srepr(arg,glue = '\n'):
def srepr(arg: Union[np.ndarray, Sequence[Any]], glue: str = '\n') -> str:
r"""
Join items with glue string.
@ -75,7 +77,7 @@ def srepr(arg,glue = '\n'):
return arg if isinstance(arg,str) else repr(arg)
def emph(what):
def emph(what: Any) -> str:
"""
Format with emphasis.
@ -92,7 +94,7 @@ def emph(what):
"""
return _colors['bold']+srepr(what)+_colors['end_color']
def deemph(what):
def deemph(what: Any) -> str:
"""
Format with deemphasis.
@ -109,7 +111,7 @@ def deemph(what):
"""
return _colors['dim']+srepr(what)+_colors['end_color']
def warn(what):
def warn(what: Any) -> str:
"""
Format for warning.
@ -126,7 +128,7 @@ def warn(what):
"""
return _colors['warning']+emph(what)+_colors['end_color']
def strikeout(what):
def strikeout(what: Any) -> str:
"""
Format as strikeout.
@ -144,7 +146,7 @@ def strikeout(what):
return _colors['crossout']+srepr(what)+_colors['end_color']
def run(cmd,wd='./',env=None,timeout=None):
def run(cmd: str, wd: str = './', env: Dict[str, Any] = None, timeout: int = None) -> Tuple[str, str]:
"""
Run a command.
@ -185,7 +187,7 @@ def run(cmd,wd='./',env=None,timeout=None):
execute = run
def natural_sort(key):
def natural_sort(key: str) -> List[Union[int, str]]:
"""
Natural sort.
@ -200,7 +202,10 @@ def natural_sort(key):
return [ convert(c) for c in re.split('([0-9]+)', key) ]
def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
def show_progress(iterable: Sequence[Any],
N_iter: int = None,
prefix: str = '',
bar_length: int = 50) -> Any:
"""
Decorate a loop with a progress bar.
@ -229,7 +234,7 @@ def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
status.update(i)
def scale_to_coprime(v):
def scale_to_coprime(v: np.ndarray) -> np.ndarray:
"""
Scale vector to co-prime (relatively prime) integers.
@ -267,13 +272,13 @@ def scale_to_coprime(v):
return m
def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
def project_equal_angle(vector: np.ndarray, direction: str = 'z', normalize: bool = True, keepdims: bool = False) -> np.ndarray:
"""
Apply equal-angle projection to vector.
Parameters
----------
vector : numpy.ndarray, shape (...,3)
vector : numpy.ndarray of shape (...,3)
Vector coordinates to be projected.
direction : str
Projection direction 'x', 'y', or 'z'.
@ -309,7 +314,10 @@ def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
def project_equal_area(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
"""
Apply equal-area projection to vector.
@ -351,15 +359,14 @@ def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def execution_stamp(class_name,function_name=None):
def execution_stamp(class_name: str, function_name: str = None) -> str:
"""Timestamp the execution of a (function within a) class."""
now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z')
_function_name = '' if function_name is None else f'.{function_name}'
return f'damask.{class_name}{_function_name} v{version} ({now})'
def hybrid_IA(dist,N,rng_seed=None):
def hybrid_IA(dist: np.ndarray, N: int, rng_seed: Union[int, np.ndarray] = None) -> np.ndarray:
"""
Hybrid integer approximation.
@ -387,7 +394,10 @@ def hybrid_IA(dist,N,rng_seed=None):
return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]]
def shapeshifter(fro,to,mode='left',keep_ones=False):
def shapeshifter(fro: Tuple[int, ...],
to: Tuple[int, ...],
mode: Literal['left','right'] = 'left',
keep_ones: bool = False) -> Tuple[int, ...]:
"""
Return dimensions that reshape 'fro' to become broadcastable to 'to'.
@ -434,21 +444,22 @@ def shapeshifter(fro,to,mode='left',keep_ones=False):
fro = (1,) if not len(fro) else fro
to = (1,) if not len(to) else to
try:
grp = re.match(beg[mode]
match = re.match((beg[mode]
+f',{sep[mode]}'.join(map(lambda x: f'{x}'
if x>1 or (keep_ones and len(fro)>1) else
'\\d+',fro))
+f',{end[mode]}',
','.join(map(str,to))+',').groups()
except AttributeError:
+f',{end[mode]}'),','.join(map(str,to))+',')
assert match
except AssertionError:
raise ValueError(f'Shapes can not be shifted {fro} --> {to}')
fill = ()
grp: Sequence[str] = match.groups()
fill: Tuple[int, ...] = ()
for g,d in zip(grp,fro+(None,)):
fill += (1,)*g.count(',')+(d,)
return fill[:-1]
def shapeblender(a,b):
def shapeblender(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
"""
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
@ -476,7 +487,7 @@ def shapeblender(a,b):
return a + b[i:]
def extend_docstring(extra_docstring):
def extend_docstring(extra_docstring: str) -> Callable:
"""
Decorator: Append to function's docstring.
@ -492,7 +503,7 @@ def extend_docstring(extra_docstring):
return _decorator
def extended_docstring(f,extra_docstring):
def extended_docstring(f: Callable, extra_docstring: str) -> Callable:
"""
Decorator: Combine another function's docstring with a given docstring.
@ -510,7 +521,7 @@ def extended_docstring(f,extra_docstring):
return _decorator
def DREAM3D_base_group(fname):
def DREAM3D_base_group(fname: Union[str, pathlib.Path]) -> str:
"""
Determine the base group of a DREAM.3D file.
@ -536,7 +547,7 @@ def DREAM3D_base_group(fname):
return base_group
def DREAM3D_cell_data_group(fname):
def DREAM3D_cell_data_group(fname: Union[str, pathlib.Path]) -> str:
"""
Determine the cell data group of a DREAM.3D file.
@ -568,7 +579,7 @@ def DREAM3D_cell_data_group(fname):
return cell_data_group
def Bravais_to_Miller(*,uvtw=None,hkil=None):
def Bravais_to_Miller(*, uvtw: np.ndarray = None, hkil: np.ndarray = None) -> np.ndarray:
"""
Transform 4 MillerBravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl).
@ -595,7 +606,7 @@ def Bravais_to_Miller(*,uvtw=None,hkil=None):
return np.einsum('il,...l',basis,axis)
def Miller_to_Bravais(*,uvw=None,hkl=None):
def Miller_to_Bravais(*, uvw: np.ndarray = None, hkl: np.ndarray = None) -> np.ndarray:
"""
Transform 3 Miller indices to 4 MillerBravais indices of crystal direction [uvtw] or plane normal (hkil).
@ -624,7 +635,7 @@ def Miller_to_Bravais(*,uvw=None,hkl=None):
return np.einsum('il,...l',basis,axis)
def dict_prune(d):
def dict_prune(d: Dict[Any, Any]) -> Dict[Any, Any]:
"""
Recursively remove empty dictionaries.
@ -650,7 +661,7 @@ def dict_prune(d):
return new
def dict_flatten(d):
def dict_flatten(d: Dict[Any, Any]) -> Dict[Any, Any]:
"""
Recursively remove keys of single-entry dictionaries.
@ -685,7 +696,7 @@ class _ProgressBar:
Works for 0-based loops, ETA is estimated by linear extrapolation.
"""
def __init__(self,total,prefix,bar_length):
def __init__(self, total: int, prefix: str, bar_length: int):
"""
Set current time as basis for ETA estimation.
@ -708,7 +719,7 @@ class _ProgressBar:
sys.stderr.write(f"{self.prefix} {''*self.bar_length} 0% ETA n/a")
sys.stderr.flush()
def update(self,iteration):
def update(self, iteration: int) -> None:
fraction = (iteration+1) / self.total
filled_length = int(self.bar_length * fraction)