util imports need prefix instead of __all__ definition to prevent namespace pollution

This commit is contained in:
d.mentock 2022-06-10 12:00:54 +02:00
parent b3b14e9104
commit afbafd1d98
1 changed files with 110 additions and 128 deletions

View File

@ -1,43 +1,25 @@
"""Miscellaneous helper functionality."""
import sys
import datetime
import os
import subprocess
import shlex
import re
import signal
import fractions
from collections import abc
from functools import reduce, partial
from typing import Callable, Union, Iterable, Sequence, Dict, List, Tuple, Literal, Any, Collection, TextIO
from pathlib import Path
import sys as _sys
import datetime as _datetime
import os as _os
import subprocess as _subprocess
import shlex as _shlex
import re as _re
import signal as _signal
import fractions as _fractions
from collections import abc as _abc
from functools import reduce as _reduce, partial as _partial
from typing import Callable as _Callable, Union as _Union, Iterable as _Iterable, Sequence as _Sequence, Dict as _Dict, \
List as _List, Tuple as _Tuple, Literal as _Literal, Any as _Any, Collection as _Collection, TextIO as _TextIO
from pathlib import Path as _Path
import numpy as np
import h5py
import numpy as _np
import h5py as _h5py
from . import version
from ._typehints import FloatSequence, NumpyRngSeed, IntCollection, FileHandle
# limit visibility
__all__=[
'srepr',
'emph', 'deemph', 'warn', 'strikeout',
'run',
'open_text',
'natural_sort',
'show_progress',
'scale_to_coprime',
'project_equal_angle', 'project_equal_area',
'hybrid_IA',
'execution_stamp',
'shapeshifter', 'shapeblender',
'extend_docstring', 'extended_docstring',
'Bravais_to_Miller', 'Miller_to_Bravais',
'DREAM3D_base_group', 'DREAM3D_cell_data_group',
'dict_prune', 'dict_flatten',
'tail_repack',
]
from . import version as _version
from ._typehints import FloatSequence as _FloatSequence, NumpyRngSeed as _NumpyRngSeed, IntCollection as _IntCollection, \
FileHandle as _FileHandle
# https://svn.blender.org/svnroot/bf-blender/trunk/blender/build_files/scons/tools/bcolors.py
# https://stackoverflow.com/questions/287871
@ -154,8 +136,8 @@ def strikeout(msg) -> str:
def run(cmd: str,
wd: str = './',
env: Dict[str, str] = None,
timeout: int = None) -> Tuple[str, str]:
env: _Dict[str, str] = None,
timeout: int = None) -> _Tuple[str, str]:
"""
Run a command.
@ -178,26 +160,26 @@ def run(cmd: str,
"""
def pass_signal(sig,_,proc,default):
proc.send_signal(sig)
signal.signal(sig,default)
signal.raise_signal(sig)
_signal.signal(sig,default)
_signal.raise_signal(sig)
signals = [signal.SIGINT,signal.SIGTERM]
signals = [_signal.SIGINT,_signal.SIGTERM]
print(f"running '{cmd}' in '{wd}'")
process = subprocess.Popen(shlex.split(cmd),
stdout = subprocess.PIPE,
stderr = subprocess.PIPE,
env = os.environ if env is None else env,
process = _subprocess.Popen(_shlex.split(cmd),
stdout = _subprocess.PIPE,
stderr = _subprocess.PIPE,
env = _os.environ if env is None else env,
cwd = wd,
encoding = 'utf-8')
# ensure that process is terminated (https://stackoverflow.com/questions/22916783)
sig_states = [signal.signal(sig,partial(pass_signal,proc=process,default=signal.getsignal(sig))) for sig in signals]
sig_states = [_signal.signal(sig,_partial(pass_signal,proc=process,default=_signal.getsignal(sig))) for sig in signals]
try:
stdout,stderr = process.communicate(timeout=timeout)
finally:
for sig,state in zip(signals,sig_states):
signal.signal(sig,state)
_signal.signal(sig,state)
if process.returncode != 0:
print(stdout)
@ -207,8 +189,8 @@ def run(cmd: str,
return stdout, stderr
def open_text(fname: FileHandle,
mode: Literal['r','w'] = 'r') -> TextIO:
def open_text(fname: _FileHandle,
mode: _Literal['r','w'] = 'r') -> _TextIO:
"""
Open a text file.
@ -224,11 +206,11 @@ def open_text(fname: FileHandle,
f : file handle
"""
return fname if not isinstance(fname, (str,Path)) else \
open(Path(fname).expanduser(),mode,newline=('\n' if mode == 'w' else None))
return fname if not isinstance(fname, (str,_Path)) else \
open(_Path(fname).expanduser(),mode,newline=('\n' if mode == 'w' else None))
def natural_sort(key: str) -> List[Union[int, str]]:
def natural_sort(key: str) -> _List[_Union[int, str]]:
"""
Natural sort.
@ -240,13 +222,13 @@ def natural_sort(key: str) -> List[Union[int, str]]:
"""
convert = lambda text: int(text) if text.isdigit() else text
return [ convert(c) for c in re.split('([0-9]+)', key) ]
return [ convert(c) for c in _re.split('([0-9]+)', key) ]
def show_progress(iterable: Iterable,
def show_progress(iterable: _Iterable,
N_iter: int = None,
prefix: str = '',
bar_length: int = 50) -> Any:
bar_length: int = 50) -> _Any:
"""
Decorate a loop with a progress bar.
@ -264,7 +246,7 @@ def show_progress(iterable: Iterable,
Length of progress bar in characters. Defaults to 50.
"""
if isinstance(iterable,abc.Sequence):
if isinstance(iterable,_abc.Sequence):
if N_iter is None:
N = len(iterable)
else:
@ -285,7 +267,7 @@ def show_progress(iterable: Iterable,
status.update(i)
def scale_to_coprime(v: FloatSequence) -> np.ndarray:
def scale_to_coprime(v: _FloatSequence) -> _np.ndarray:
"""
Scale vector to co-prime (relatively prime) integers.
@ -304,30 +286,30 @@ def scale_to_coprime(v: FloatSequence) -> np.ndarray:
def get_square_denominator(x):
"""Denominator of the square of a number."""
return fractions.Fraction(x ** 2).limit_denominator(MAX_DENOMINATOR).denominator
return _fractions.Fraction(x ** 2).limit_denominator(MAX_DENOMINATOR).denominator
def lcm(a,b):
"""Least common multiple."""
try:
return np.lcm(a,b) # numpy > 1.18
return _np.lcm(a,b) # numpy > 1.18
except AttributeError:
return a * b // np.gcd(a, b)
return a * b // _np.gcd(a, b)
v_ = np.array(v)
m = (v_ * reduce(lcm, map(lambda x: int(get_square_denominator(x)),v_))**0.5).astype(np.int64)
m = m//reduce(np.gcd,m)
v_ = _np.array(v)
m = (v_ * _reduce(lcm, map(lambda x: int(get_square_denominator(x)),v_))**0.5).astype(_np.int64)
m = m//_reduce(_np.gcd,m)
with np.errstate(invalid='ignore'):
if not np.allclose(np.ma.masked_invalid(v_/m),v_[np.argmax(abs(v_))]/m[np.argmax(abs(v_))]):
with _np.errstate(invalid='ignore'):
if not _np.allclose(_np.ma.masked_invalid(v_/m),v_[_np.argmax(abs(v_))]/m[_np.argmax(abs(v_))]):
raise ValueError(f'invalid result "{m}" for input "{v_}"')
return m
def project_equal_angle(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
def project_equal_angle(vector: _np.ndarray,
direction: _Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
keepdims: bool = False) -> _np.ndarray:
"""
Apply equal-angle projection to vector.
@ -367,15 +349,15 @@ def project_equal_angle(vector: np.ndarray,
"""
shift = 'zyx'.index(direction)
v = np.roll(vector/np.linalg.norm(vector,axis=-1,keepdims=True) if normalize else vector,
v = _np.roll(vector/_np.linalg.norm(vector,axis=-1,keepdims=True) if normalize else vector,
shift,axis=-1)
return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
return _np.roll(_np.block([v[...,:2]/(1.0+_np.abs(v[...,2:3])),_np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def project_equal_area(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
def project_equal_area(vector: _np.ndarray,
direction: _Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
keepdims: bool = False) -> _np.ndarray:
"""
Apply equal-area projection to vector.
@ -416,22 +398,22 @@ def project_equal_area(vector: np.ndarray,
"""
shift = 'zyx'.index(direction)
v = np.roll(vector/np.linalg.norm(vector,axis=-1,keepdims=True) if normalize else vector,
v = _np.roll(vector/_np.linalg.norm(vector,axis=-1,keepdims=True) if normalize else vector,
shift,axis=-1)
return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
return _np.roll(_np.block([v[...,:2]/_np.sqrt(1.0+_np.abs(v[...,2:3])),_np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def execution_stamp(class_name: str,
function_name: str = None) -> str:
"""Timestamp the execution of a (function within a) class."""
now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z')
now = _datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z')
_function_name = '' if function_name is None else f'.{function_name}'
return f'damask.{class_name}{_function_name} v{version} ({now})'
return f'damask.{class_name}{_function_name} v{_version} ({now})'
def hybrid_IA(dist: np.ndarray,
def hybrid_IA(dist: _np.ndarray,
N: int,
rng_seed: NumpyRngSeed = None) -> np.ndarray:
rng_seed: _NumpyRngSeed = None) -> _np.ndarray:
"""
Hybrid integer approximation.
@ -446,23 +428,23 @@ def hybrid_IA(dist: np.ndarray,
If None, then fresh, unpredictable entropy will be pulled from the OS.
"""
N_opt_samples,N_inv_samples = (max(np.count_nonzero(dist),N),0) # random subsampling if too little samples requested
N_opt_samples,N_inv_samples = (max(_np.count_nonzero(dist),N),0) # random subsampling if too little samples requested
scale_,scale,inc_factor = (0.0,float(N_opt_samples),1.0)
while (not np.isclose(scale, scale_)) and (N_inv_samples != N_opt_samples):
repeats = np.rint(scale*dist).astype(np.int64)
N_inv_samples = np.sum(repeats)
while (not _np.isclose(scale, scale_)) and (N_inv_samples != N_opt_samples):
repeats = _np.rint(scale*dist).astype(_np.int64)
N_inv_samples = _np.sum(repeats)
scale_,scale,inc_factor = (scale,scale+inc_factor*0.5*(scale - scale_), inc_factor*2.0) \
if N_inv_samples < N_opt_samples else \
(scale_,0.5*(scale_ + scale), 1.0)
return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]]
return _np.repeat(_np.arange(len(dist)),repeats)[_np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]]
def shapeshifter(fro: Tuple[int, ...],
to: Tuple[int, ...],
mode: Literal['left','right'] = 'left',
keep_ones: bool = False) -> Tuple[int, ...]:
def shapeshifter(fro: _Tuple[int, ...],
to: _Tuple[int, ...],
mode: _Literal['left','right'] = 'left',
keep_ones: bool = False) -> _Tuple[int, ...]:
"""
Return dimensions that reshape 'fro' to become broadcastable to 'to'.
@ -509,7 +491,7 @@ def shapeshifter(fro: Tuple[int, ...],
fro = (1,) if len(fro) == 0 else fro
to = (1,) if len(to) == 0 else to
try:
match = re.match(beg[mode]
match = _re.match(beg[mode]
+f',{sep[mode]}'.join(map(lambda x: f'{x}'
if x>1 or (keep_ones and len(fro)>1) else
'\\d+',fro))
@ -518,14 +500,14 @@ def shapeshifter(fro: Tuple[int, ...],
grp = match.groups()
except AssertionError:
raise ValueError(f'shapes cannot be shifted {fro} --> {to}')
fill: Any = ()
fill: _Any = ()
for g,d in zip(grp,fro+(None,)):
fill += (1,)*g.count(',')+(d,)
return fill[:-1]
def shapeblender(a: Tuple[int, ...],
b: Tuple[int, ...]) -> Tuple[int, ...]:
def shapeblender(a: _Tuple[int, ...],
b: _Tuple[int, ...]) -> _Tuple[int, ...]:
"""
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
@ -553,7 +535,7 @@ def shapeblender(a: Tuple[int, ...],
return a + b[i:]
def extend_docstring(extra_docstring: str) -> Callable:
def extend_docstring(extra_docstring: str) -> _Callable:
"""
Decorator: Append to function's docstring.
@ -569,8 +551,8 @@ def extend_docstring(extra_docstring: str) -> Callable:
return _decorator
def extended_docstring(f: Callable,
extra_docstring: str) -> Callable:
def extended_docstring(f: _Callable,
extra_docstring: str) -> _Callable:
"""
Decorator: Combine another function's docstring with a given docstring.
@ -588,7 +570,7 @@ def extended_docstring(f: Callable,
return _decorator
def DREAM3D_base_group(fname: Union[str, Path]) -> str:
def DREAM3D_base_group(fname: _Union[str, _Path]) -> str:
"""
Determine the base group of a DREAM.3D file.
@ -606,7 +588,7 @@ def DREAM3D_base_group(fname: Union[str, Path]) -> str:
Path to the base group.
"""
with h5py.File(Path(fname).expanduser(),'r') as f:
with _h5py.File(_Path(fname).expanduser(),'r') as f:
base_group = f.visit(lambda path: path.rsplit('/',2)[0] if '_SIMPL_GEOMETRY/SPACING' in path else None)
if base_group is None:
@ -614,7 +596,7 @@ def DREAM3D_base_group(fname: Union[str, Path]) -> str:
return base_group
def DREAM3D_cell_data_group(fname: Union[str, Path]) -> str:
def DREAM3D_cell_data_group(fname: _Union[str, _Path]) -> str:
"""
Determine the cell data group of a DREAM.3D file.
@ -634,10 +616,10 @@ def DREAM3D_cell_data_group(fname: Union[str, Path]) -> str:
"""
base_group = DREAM3D_base_group(fname)
with h5py.File(Path(fname).expanduser(),'r') as f:
with _h5py.File(_Path(fname).expanduser(),'r') as f:
cells = tuple(f['/'.join([base_group,'_SIMPL_GEOMETRY','DIMENSIONS'])][()][::-1])
cell_data_group = f[base_group].visititems(lambda path,obj: path.split('/')[0] \
if isinstance(obj,h5py._hl.dataset.Dataset) and np.shape(obj)[:-1] == cells \
if isinstance(obj,_h5py._hl.dataset.Dataset) and _np.shape(obj)[:-1] == cells \
else None)
if cell_data_group is None:
@ -647,8 +629,8 @@ def DREAM3D_cell_data_group(fname: Union[str, Path]) -> str:
def Bravais_to_Miller(*,
uvtw: np.ndarray = None,
hkil: np.ndarray = None) -> np.ndarray:
uvtw: _np.ndarray = None,
hkil: _np.ndarray = None) -> _np.ndarray:
"""
Transform 4 MillerBravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl).
@ -665,19 +647,19 @@ def Bravais_to_Miller(*,
"""
if (uvtw is not None) ^ (hkil is None):
raise KeyError('specify either "uvtw" or "hkil"')
axis,basis = (np.array(uvtw),np.array([[1,0,-1,0],
axis,basis = (_np.array(uvtw),_np.array([[1,0,-1,0],
[0,1,-1,0],
[0,0, 0,1]])) \
if hkil is None else \
(np.array(hkil),np.array([[1,0,0,0],
(_np.array(hkil),_np.array([[1,0,0,0],
[0,1,0,0],
[0,0,0,1]]))
return np.einsum('il,...l',basis,axis)
return _np.einsum('il,...l',basis,axis)
def Miller_to_Bravais(*,
uvw: np.ndarray = None,
hkl: np.ndarray = None) -> np.ndarray:
uvw: _np.ndarray = None,
hkl: _np.ndarray = None) -> _np.ndarray:
"""
Transform 3 Miller indices to 4 MillerBravais indices of crystal direction [uvtw] or plane normal (hkil).
@ -694,19 +676,19 @@ def Miller_to_Bravais(*,
"""
if (uvw is not None) ^ (hkl is None):
raise KeyError('specify either "uvw" or "hkl"')
axis,basis = (np.array(uvw),np.array([[ 2,-1, 0],
axis,basis = (_np.array(uvw),_np.array([[ 2,-1, 0],
[-1, 2, 0],
[-1,-1, 0],
[ 0, 0, 3]])/3) \
if hkl is None else \
(np.array(hkl),np.array([[ 1, 0, 0],
(_np.array(hkl),_np.array([[ 1, 0, 0],
[ 0, 1, 0],
[-1,-1, 0],
[ 0, 0, 1]]))
return np.einsum('il,...l',basis,axis)
return _np.einsum('il,...l',basis,axis)
def dict_prune(d: Dict) -> Dict:
def dict_prune(d: _Dict) -> _Dict:
"""
Recursively remove empty dictionaries.
@ -732,7 +714,7 @@ def dict_prune(d: Dict) -> Dict:
return new
def dict_flatten(d: Dict) -> Dict:
def dict_flatten(d: _Dict) -> _Dict:
"""
Recursively remove keys of single-entry dictionaries.
@ -756,8 +738,8 @@ def dict_flatten(d: Dict) -> Dict:
return new
def tail_repack(extended: Union[str, Sequence[str]],
existing: List[str] = []) -> List[str]:
def tail_repack(extended: _Union[str, _Sequence[str]],
existing: _List[str] = []) -> _List[str]:
"""
Repack tailing characters into single string if all are new.
@ -782,11 +764,11 @@ def tail_repack(extended: Union[str, Sequence[str]],
"""
return [extended] if isinstance(extended,str) else existing + \
([''.join(extended[len(existing):])] if np.prod([len(i) for i in extended[len(existing):]]) == 1 else
([''.join(extended[len(existing):])] if _np.prod([len(i) for i in extended[len(existing):]]) == 1 else
list(extended[len(existing):]))
def aslist(arg: Union[IntCollection,int,None]) -> List:
def aslist(arg: _Union[_IntCollection, int, None]) -> _List:
"""
Transform argument to list.
@ -801,7 +783,7 @@ def aslist(arg: Union[IntCollection,int,None]) -> List:
Entity transformed into list.
"""
return [] if arg is None else list(arg) if isinstance(arg,(np.ndarray,Collection)) else [arg]
return [] if arg is None else list(arg) if isinstance(arg,(_np.ndarray,_Collection)) else [arg]
####################################################################################################
@ -834,11 +816,11 @@ class ProgressBar:
self.total = total
self.prefix = prefix
self.bar_length = bar_length
self.time_start = self.time_last_update = datetime.datetime.now()
self.time_start = self.time_last_update = _datetime.datetime.now()
self.fraction_last = 0.0
sys.stderr.write(f"{self.prefix} {''*self.bar_length} 0% ETA n/a")
sys.stderr.flush()
_sys.stderr.write(f"{self.prefix} {''*self.bar_length} 0% ETA n/a")
_sys.stderr.flush()
def update(self,
iteration: int) -> None:
@ -846,17 +828,17 @@ class ProgressBar:
fraction = (iteration+1) / self.total
if (filled_length := int(self.bar_length * fraction)) > int(self.bar_length * self.fraction_last) or \
datetime.datetime.now() - self.time_last_update > datetime.timedelta(seconds=10):
self.time_last_update = datetime.datetime.now()
_datetime.datetime.now() - self.time_last_update > _datetime.timedelta(seconds=10):
self.time_last_update = _datetime.datetime.now()
bar = '' * filled_length + '' * (self.bar_length - filled_length)
remaining_time = (datetime.datetime.now() - self.time_start) \
remaining_time = (_datetime.datetime.now() - self.time_start) \
* (self.total - (iteration+1)) / (iteration+1)
remaining_time -= datetime.timedelta(microseconds=remaining_time.microseconds) # remove μs
sys.stderr.write(f'\r{self.prefix} {bar} {fraction:>4.0%} ETA {remaining_time}')
sys.stderr.flush()
remaining_time -= _datetime.timedelta(microseconds=remaining_time.microseconds) # remove μs
_sys.stderr.write(f'\r{self.prefix} {bar} {fraction:>4.0%} ETA {remaining_time}')
_sys.stderr.flush()
self.fraction_last = fraction
if iteration == self.total - 1:
sys.stderr.write('\n')
sys.stderr.flush()
_sys.stderr.write('\n')
_sys.stderr.flush()