added typehints for util module

This commit is contained in:
Daniel Otto de Mentock 2022-01-17 14:58:08 +01:00
parent b796bc0697
commit adf7abbda6
1 changed files with 43 additions and 32 deletions

View File

@ -8,6 +8,8 @@ import shlex
import re import re
import fractions import fractions
from functools import reduce from functools import reduce
from typing import Union, Tuple, Sequence, Callable, Dict, List, Any, Literal
import pathlib
import numpy as np import numpy as np
import h5py import h5py
@ -50,7 +52,7 @@ _colors = {
#################################################################################################### ####################################################################################################
# Functions # Functions
#################################################################################################### ####################################################################################################
def srepr(arg,glue = '\n'): def srepr(arg: Union[np.ndarray, Sequence[Any]], glue: str = '\n') -> str:
r""" r"""
Join items with glue string. Join items with glue string.
@ -75,7 +77,7 @@ def srepr(arg,glue = '\n'):
return arg if isinstance(arg,str) else repr(arg) return arg if isinstance(arg,str) else repr(arg)
def emph(what): def emph(what: Any) -> str:
""" """
Format with emphasis. Format with emphasis.
@ -92,7 +94,7 @@ def emph(what):
""" """
return _colors['bold']+srepr(what)+_colors['end_color'] return _colors['bold']+srepr(what)+_colors['end_color']
def deemph(what): def deemph(what: Any) -> str:
""" """
Format with deemphasis. Format with deemphasis.
@ -109,7 +111,7 @@ def deemph(what):
""" """
return _colors['dim']+srepr(what)+_colors['end_color'] return _colors['dim']+srepr(what)+_colors['end_color']
def warn(what): def warn(what: Any) -> str:
""" """
Format for warning. Format for warning.
@ -126,7 +128,7 @@ def warn(what):
""" """
return _colors['warning']+emph(what)+_colors['end_color'] return _colors['warning']+emph(what)+_colors['end_color']
def strikeout(what): def strikeout(what: Any) -> str:
""" """
Format as strikeout. Format as strikeout.
@ -144,7 +146,7 @@ def strikeout(what):
return _colors['crossout']+srepr(what)+_colors['end_color'] return _colors['crossout']+srepr(what)+_colors['end_color']
def run(cmd,wd='./',env=None,timeout=None): def run(cmd: str, wd: str = './', env: Dict[str, Any] = None, timeout: int = None) -> Tuple[str, str]:
""" """
Run a command. Run a command.
@ -185,7 +187,7 @@ def run(cmd,wd='./',env=None,timeout=None):
execute = run execute = run
def natural_sort(key): def natural_sort(key: str) -> List[Union[int, str]]:
""" """
Natural sort. Natural sort.
@ -200,7 +202,10 @@ def natural_sort(key):
return [ convert(c) for c in re.split('([0-9]+)', key) ] return [ convert(c) for c in re.split('([0-9]+)', key) ]
def show_progress(iterable,N_iter=None,prefix='',bar_length=50): def show_progress(iterable: Sequence[Any],
N_iter: int = None,
prefix: str = '',
bar_length: int = 50) -> Any:
""" """
Decorate a loop with a progress bar. Decorate a loop with a progress bar.
@ -229,7 +234,7 @@ def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
status.update(i) status.update(i)
def scale_to_coprime(v): def scale_to_coprime(v: np.ndarray) -> np.ndarray:
""" """
Scale vector to co-prime (relatively prime) integers. Scale vector to co-prime (relatively prime) integers.
@ -267,13 +272,13 @@ def scale_to_coprime(v):
return m return m
def project_equal_angle(vector,direction='z',normalize=True,keepdims=False): def project_equal_angle(vector: np.ndarray, direction: str = 'z', normalize: bool = True, keepdims: bool = False) -> np.ndarray:
""" """
Apply equal-angle projection to vector. Apply equal-angle projection to vector.
Parameters Parameters
---------- ----------
vector : numpy.ndarray, shape (...,3) vector : numpy.ndarray of shape (...,3)
Vector coordinates to be projected. Vector coordinates to be projected.
direction : str direction : str
Projection direction 'x', 'y', or 'z'. Projection direction 'x', 'y', or 'z'.
@ -309,7 +314,10 @@ def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]), return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2] -shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def project_equal_area(vector,direction='z',normalize=True,keepdims=False): def project_equal_area(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
""" """
Apply equal-area projection to vector. Apply equal-area projection to vector.
@ -351,15 +359,14 @@ def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]), return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2] -shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def execution_stamp(class_name: str, function_name: str = None) -> str:
def execution_stamp(class_name,function_name=None):
"""Timestamp the execution of a (function within a) class.""" """Timestamp the execution of a (function within a) class."""
now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z') now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z')
_function_name = '' if function_name is None else f'.{function_name}' _function_name = '' if function_name is None else f'.{function_name}'
return f'damask.{class_name}{_function_name} v{version} ({now})' return f'damask.{class_name}{_function_name} v{version} ({now})'
def hybrid_IA(dist,N,rng_seed=None): def hybrid_IA(dist: np.ndarray, N: int, rng_seed: Union[int, np.ndarray] = None) -> np.ndarray:
""" """
Hybrid integer approximation. Hybrid integer approximation.
@ -387,7 +394,10 @@ def hybrid_IA(dist,N,rng_seed=None):
return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]] return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]]
def shapeshifter(fro,to,mode='left',keep_ones=False): def shapeshifter(fro: Tuple[int, ...],
to: Tuple[int, ...],
mode: Literal['left','right'] = 'left',
keep_ones: bool = False) -> Tuple[int, ...]:
""" """
Return dimensions that reshape 'fro' to become broadcastable to 'to'. Return dimensions that reshape 'fro' to become broadcastable to 'to'.
@ -434,21 +444,22 @@ def shapeshifter(fro,to,mode='left',keep_ones=False):
fro = (1,) if not len(fro) else fro fro = (1,) if not len(fro) else fro
to = (1,) if not len(to) else to to = (1,) if not len(to) else to
try: try:
grp = re.match(beg[mode] match = re.match((beg[mode]
+f',{sep[mode]}'.join(map(lambda x: f'{x}' +f',{sep[mode]}'.join(map(lambda x: f'{x}'
if x>1 or (keep_ones and len(fro)>1) else if x>1 or (keep_ones and len(fro)>1) else
'\\d+',fro)) '\\d+',fro))
+f',{end[mode]}', +f',{end[mode]}'),','.join(map(str,to))+',')
','.join(map(str,to))+',').groups() assert match
except AttributeError: except AssertionError:
raise ValueError(f'Shapes can not be shifted {fro} --> {to}') raise ValueError(f'Shapes can not be shifted {fro} --> {to}')
fill = () grp: Sequence[str] = match.groups()
fill: Tuple[int, ...] = ()
for g,d in zip(grp,fro+(None,)): for g,d in zip(grp,fro+(None,)):
fill += (1,)*g.count(',')+(d,) fill += (1,)*g.count(',')+(d,)
return fill[:-1] return fill[:-1]
def shapeblender(a,b): def shapeblender(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
""" """
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'. Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
@ -476,7 +487,7 @@ def shapeblender(a,b):
return a + b[i:] return a + b[i:]
def extend_docstring(extra_docstring): def extend_docstring(extra_docstring: str) -> Callable:
""" """
Decorator: Append to function's docstring. Decorator: Append to function's docstring.
@ -492,7 +503,7 @@ def extend_docstring(extra_docstring):
return _decorator return _decorator
def extended_docstring(f,extra_docstring): def extended_docstring(f: Callable, extra_docstring: str) -> Callable:
""" """
Decorator: Combine another function's docstring with a given docstring. Decorator: Combine another function's docstring with a given docstring.
@ -510,7 +521,7 @@ def extended_docstring(f,extra_docstring):
return _decorator return _decorator
def DREAM3D_base_group(fname): def DREAM3D_base_group(fname: Union[str, pathlib.Path]) -> str:
""" """
Determine the base group of a DREAM.3D file. Determine the base group of a DREAM.3D file.
@ -536,7 +547,7 @@ def DREAM3D_base_group(fname):
return base_group return base_group
def DREAM3D_cell_data_group(fname): def DREAM3D_cell_data_group(fname: Union[str, pathlib.Path]) -> str:
""" """
Determine the cell data group of a DREAM.3D file. Determine the cell data group of a DREAM.3D file.
@ -568,7 +579,7 @@ def DREAM3D_cell_data_group(fname):
return cell_data_group return cell_data_group
def Bravais_to_Miller(*,uvtw=None,hkil=None): def Bravais_to_Miller(*, uvtw: np.ndarray = None, hkil: np.ndarray = None) -> np.ndarray:
""" """
Transform 4 MillerBravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl). Transform 4 MillerBravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl).
@ -595,7 +606,7 @@ def Bravais_to_Miller(*,uvtw=None,hkil=None):
return np.einsum('il,...l',basis,axis) return np.einsum('il,...l',basis,axis)
def Miller_to_Bravais(*,uvw=None,hkl=None): def Miller_to_Bravais(*, uvw: np.ndarray = None, hkl: np.ndarray = None) -> np.ndarray:
""" """
Transform 3 Miller indices to 4 MillerBravais indices of crystal direction [uvtw] or plane normal (hkil). Transform 3 Miller indices to 4 MillerBravais indices of crystal direction [uvtw] or plane normal (hkil).
@ -624,7 +635,7 @@ def Miller_to_Bravais(*,uvw=None,hkl=None):
return np.einsum('il,...l',basis,axis) return np.einsum('il,...l',basis,axis)
def dict_prune(d): def dict_prune(d: Dict[Any, Any]) -> Dict[Any, Any]:
""" """
Recursively remove empty dictionaries. Recursively remove empty dictionaries.
@ -650,7 +661,7 @@ def dict_prune(d):
return new return new
def dict_flatten(d): def dict_flatten(d: Dict[Any, Any]) -> Dict[Any, Any]:
""" """
Recursively remove keys of single-entry dictionaries. Recursively remove keys of single-entry dictionaries.
@ -685,7 +696,7 @@ class _ProgressBar:
Works for 0-based loops, ETA is estimated by linear extrapolation. Works for 0-based loops, ETA is estimated by linear extrapolation.
""" """
def __init__(self,total,prefix,bar_length): def __init__(self, total: int, prefix: str, bar_length: int):
""" """
Set current time as basis for ETA estimation. Set current time as basis for ETA estimation.
@ -708,7 +719,7 @@ class _ProgressBar:
sys.stderr.write(f"{self.prefix} {''*self.bar_length} 0% ETA n/a") sys.stderr.write(f"{self.prefix} {''*self.bar_length} 0% ETA n/a")
sys.stderr.flush() sys.stderr.flush()
def update(self,iteration): def update(self, iteration: int) -> None:
fraction = (iteration+1) / self.total fraction = (iteration+1) / self.total
filled_length = int(self.bar_length * fraction) filled_length = int(self.bar_length * fraction)