correcting types

Not really sure if srepr and friends take really 'Any'. They take
everything that can be casted (piecewise) to a string. So keep it open
at the moment and leave out a typehint
This commit is contained in:
Martin Diehl 2022-01-21 23:50:16 +01:00
parent 76ccd4aaaa
commit 7e9ce682e7
1 changed files with 54 additions and 52 deletions

View File

@ -9,13 +9,13 @@ import re
import fractions import fractions
from functools import reduce from functools import reduce
from typing import Union, Tuple, Sequence, Callable, Dict, List, Any, Literal, Optional from typing import Union, Tuple, Sequence, Callable, Dict, List, Any, Literal, Optional
import pathlib from pathlib import Path
import numpy as np import numpy as np
import h5py import h5py
from ._typehints import IntSequence
from . import version from . import version
from ._typehints import FloatSequence, IntSequence
# limit visibility # limit visibility
__all__=[ __all__=[
@ -53,16 +53,16 @@ _colors = {
#################################################################################################### ####################################################################################################
# Functions # Functions
#################################################################################################### ####################################################################################################
def srepr(arg: Union[np.ndarray, Sequence[Any]], glue: str = '\n') -> str: def srepr(msg, glue: str = '\n') -> str:
r""" r"""
Join items with glue string. Join items with glue string.
Parameters Parameters
---------- ----------
arg : iterable msg : object with __repr__ or sequence of objects with __repr__
Items to join. Items to join.
glue : str, optional glue : str, optional
Glue used for joining operation. Defaults to \n. Glue used for joining operation. Defaults to '\n'.
Returns Returns
------- -------
@ -70,21 +70,21 @@ def srepr(arg: Union[np.ndarray, Sequence[Any]], glue: str = '\n') -> str:
String representation of the joined items. String representation of the joined items.
""" """
if (not hasattr(arg, 'strip') and if (not hasattr(msg, 'strip') and
(hasattr(arg, '__getitem__') or (hasattr(msg, '__getitem__') or
hasattr(arg, '__iter__'))): hasattr(msg, '__iter__'))):
return glue.join(str(x) for x in arg) return glue.join(str(x) for x in msg)
else: else:
return arg if isinstance(arg,str) else repr(arg) return msg if isinstance(msg,str) else repr(msg)
def emph(what: Any) -> str: def emph(msg) -> str:
""" """
Format with emphasis. Format with emphasis.
Parameters Parameters
---------- ----------
what : object with __repr__ or iterable of objects with __repr__. msg : object with __repr__ or sequence of objects with __repr__
Message to format. Message to format.
Returns Returns
@ -93,15 +93,15 @@ def emph(what: Any) -> str:
Formatted string representation of the joined items. Formatted string representation of the joined items.
""" """
return _colors['bold']+srepr(what)+_colors['end_color'] return _colors['bold']+srepr(msg)+_colors['end_color']
def deemph(what: Any) -> str: def deemph(msg) -> str:
""" """
Format with deemphasis. Format with deemphasis.
Parameters Parameters
---------- ----------
what : object with __repr__ or iterable of objects with __repr__. msg : object with __repr__ or sequence of objects with __repr__
Message to format. Message to format.
Returns Returns
@ -110,15 +110,15 @@ def deemph(what: Any) -> str:
Formatted string representation of the joined items. Formatted string representation of the joined items.
""" """
return _colors['dim']+srepr(what)+_colors['end_color'] return _colors['dim']+srepr(msg)+_colors['end_color']
def warn(what: Any) -> str: def warn(msg) -> str:
""" """
Format for warning. Format for warning.
Parameters Parameters
---------- ----------
what : object with __repr__ or iterable of objects with __repr__. msg : object with __repr__ or sequence of objects with __repr__
Message to format. Message to format.
Returns Returns
@ -127,15 +127,15 @@ def warn(what: Any) -> str:
Formatted string representation of the joined items. Formatted string representation of the joined items.
""" """
return _colors['warning']+emph(what)+_colors['end_color'] return _colors['warning']+emph(msg)+_colors['end_color']
def strikeout(what: Any) -> str: def strikeout(msg) -> str:
""" """
Format as strikeout. Format as strikeout.
Parameters Parameters
---------- ----------
what : object with __repr__ or iterable of objects with __repr__. msg : object with __repr__ or iterable of objects with __repr__
Message to format. Message to format.
Returns Returns
@ -144,10 +144,10 @@ def strikeout(what: Any) -> str:
Formatted string representation of the joined items. Formatted string representation of the joined items.
""" """
return _colors['crossout']+srepr(what)+_colors['end_color'] return _colors['crossout']+srepr(msg)+_colors['end_color']
def run(cmd: str, wd: str = './', env: Dict[str, Any] = None, timeout: int = None) -> Tuple[str, str]: def run(cmd: str, wd: str = './', env: Dict[str, str] = None, timeout: int = None) -> Tuple[str, str]:
""" """
Run a command. Run a command.
@ -156,7 +156,7 @@ def run(cmd: str, wd: str = './', env: Dict[str, Any] = None, timeout: int = Non
cmd : str cmd : str
Command to be executed. Command to be executed.
wd : str, optional wd : str, optional
Working directory of process. Defaults to ./ . Working directory of process. Defaults to './'.
env : dict, optional env : dict, optional
Environment for execution. Environment for execution.
timeout : integer, optional timeout : integer, optional
@ -235,18 +235,18 @@ def show_progress(iterable: Sequence[Any],
status.update(i) status.update(i)
def scale_to_coprime(v: np.ndarray) -> np.ndarray: def scale_to_coprime(v: FloatSequence) -> np.ndarray:
""" """
Scale vector to co-prime (relatively prime) integers. Scale vector to co-prime (relatively prime) integers.
Parameters Parameters
---------- ----------
v : numpy.ndarray of shape (:) v : sequence of float, len (:)
Vector to scale. Vector to scale.
Returns Returns
------- -------
m : numpy.ndarray of shape (:) m : numpy.ndarray, shape (:)
Vector scaled to co-prime numbers. Vector scaled to co-prime numbers.
""" """
@ -263,27 +263,30 @@ def scale_to_coprime(v: np.ndarray) -> np.ndarray:
except AttributeError: except AttributeError:
return a * b // np.gcd(a, b) return a * b // np.gcd(a, b)
m = (np.array(v) * reduce(lcm, map(lambda x: int(get_square_denominator(x)),v)) ** 0.5).astype(int) v_ = np.array(v)
m = (v_ * reduce(lcm, map(lambda x: int(get_square_denominator(x)),v_))**0.5).astype(int)
m = m//reduce(np.gcd,m) m = m//reduce(np.gcd,m)
with np.errstate(invalid='ignore'): with np.errstate(invalid='ignore'):
if not np.allclose(np.ma.masked_invalid(v/m),v[np.argmax(abs(v))]/m[np.argmax(abs(v))]): if not np.allclose(np.ma.masked_invalid(v_/m),v_[np.argmax(abs(v_))]/m[np.argmax(abs(v_))]):
raise ValueError(f'Invalid result {m} for input {v}. Insufficient precision?') raise ValueError(f'Invalid result {m} for input {v_}. Insufficient precision?')
return m return m
def project_equal_angle(vector: np.ndarray, direction: str = 'z', normalize: bool = True, keepdims: bool = False) -> np.ndarray: def project_equal_angle(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
""" """
Apply equal-angle projection to vector. Apply equal-angle projection to vector.
Parameters Parameters
---------- ----------
vector : numpy.ndarray of shape (...,3) vector : numpy.ndarray, shape (...,3)
Vector coordinates to be projected. Vector coordinates to be projected.
direction : str direction : {'x', 'y', 'z'}
Projection direction 'x', 'y', or 'z'. Projection direction. Defaults to 'z'.
Defaults to 'z'.
normalize : bool normalize : bool
Ensure unit length of input vector. Defaults to True. Ensure unit length of input vector. Defaults to True.
keepdims : bool keepdims : bool
@ -326,9 +329,8 @@ def project_equal_area(vector: np.ndarray,
---------- ----------
vector : numpy.ndarray, shape (...,3) vector : numpy.ndarray, shape (...,3)
Vector coordinates to be projected. Vector coordinates to be projected.
direction : str direction : {'x', 'y', 'z'}
Projection direction 'x', 'y', or 'z'. Projection direction. Defaults to 'z'.
Defaults to 'z'.
normalize : bool normalize : bool
Ensure unit length of input vector. Defaults to True. Ensure unit length of input vector. Defaults to True.
keepdims : bool keepdims : bool
@ -367,7 +369,7 @@ def execution_stamp(class_name: str, function_name: str = None) -> str:
return f'damask.{class_name}{_function_name} v{version} ({now})' return f'damask.{class_name}{_function_name} v{version} ({now})'
def hybrid_IA(dist: np.ndarray, N: int, rng_seed: Union[int, IntSequence] = None) -> np.ndarray: def hybrid_IA(dist: np.ndarray, N: int, rng_seed = None) -> np.ndarray:
""" """
Hybrid integer approximation. Hybrid integer approximation.
@ -409,9 +411,9 @@ def shapeshifter(fro: Tuple[int, ...],
to : tuple to : tuple
Target shape of array after broadcasting. Target shape of array after broadcasting.
len(to) cannot be less than len(fro). len(to) cannot be less than len(fro).
mode : str, optional mode : {'left', 'right'}, optional
Indicates whether new axes are preferably added to Indicates whether new axes are preferably added to
either 'left' or 'right' of the original shape. either left or right of the original shape.
Defaults to 'left'. Defaults to 'left'.
keep_ones : bool, optional keep_ones : bool, optional
Treat '1' in fro as literal value instead of dimensional placeholder. Treat '1' in fro as literal value instead of dimensional placeholder.
@ -445,15 +447,15 @@ def shapeshifter(fro: Tuple[int, ...],
fro = (1,) if not len(fro) else fro fro = (1,) if not len(fro) else fro
to = (1,) if not len(to) else to to = (1,) if not len(to) else to
try: try:
match = re.match((beg[mode] match = re.match(beg[mode]
+f',{sep[mode]}'.join(map(lambda x: f'{x}' +f',{sep[mode]}'.join(map(lambda x: f'{x}'
if x>1 or (keep_ones and len(fro)>1) else if x>1 or (keep_ones and len(fro)>1) else
'\\d+',fro)) '\\d+',fro))
+f',{end[mode]}'),','.join(map(str,to))+',') +f',{end[mode]}',','.join(map(str,to))+',')
assert match assert match
grp = match.groups()
except AssertionError: except AssertionError:
raise ValueError(f'Shapes can not be shifted {fro} --> {to}') raise ValueError(f'Shapes can not be shifted {fro} --> {to}')
grp: Sequence[str] = match.groups()
fill: Tuple[Optional[int], ...] = () fill: Tuple[Optional[int], ...] = ()
for g,d in zip(grp,fro+(None,)): for g,d in zip(grp,fro+(None,)):
fill += (1,)*g.count(',')+(d,) fill += (1,)*g.count(',')+(d,)
@ -522,7 +524,7 @@ def extended_docstring(f: Callable, extra_docstring: str) -> Callable:
return _decorator return _decorator
def DREAM3D_base_group(fname: Union[str, pathlib.Path]) -> str: def DREAM3D_base_group(fname: Union[str, Path]) -> str:
""" """
Determine the base group of a DREAM.3D file. Determine the base group of a DREAM.3D file.
@ -548,7 +550,7 @@ def DREAM3D_base_group(fname: Union[str, pathlib.Path]) -> str:
return base_group return base_group
def DREAM3D_cell_data_group(fname: Union[str, pathlib.Path]) -> str: def DREAM3D_cell_data_group(fname: Union[str, Path]) -> str:
""" """
Determine the cell data group of a DREAM.3D file. Determine the cell data group of a DREAM.3D file.
@ -586,12 +588,12 @@ def Bravais_to_Miller(*, uvtw: np.ndarray = None, hkil: np.ndarray = None) -> np
Parameters Parameters
---------- ----------
uvtw|hkil : numpy.ndarray of shape (...,4) uvtw|hkil : numpy.ndarray, shape (...,4)
MillerBravais indices of crystallographic direction [uvtw] or plane normal (hkil). MillerBravais indices of crystallographic direction [uvtw] or plane normal (hkil).
Returns Returns
------- -------
uvw|hkl : numpy.ndarray of shape (...,3) uvw|hkl : numpy.ndarray, shape (...,3)
Miller indices of [uvw] direction or (hkl) plane normal. Miller indices of [uvw] direction or (hkl) plane normal.
""" """
@ -613,12 +615,12 @@ def Miller_to_Bravais(*, uvw: np.ndarray = None, hkl: np.ndarray = None) -> np.n
Parameters Parameters
---------- ----------
uvw|hkl : numpy.ndarray of shape (...,3) uvw|hkl : numpy.ndarray, shape (...,3)
Miller indices of crystallographic direction [uvw] or plane normal (hkl). Miller indices of crystallographic direction [uvw] or plane normal (hkl).
Returns Returns
------- -------
uvtw|hkil : numpy.ndarray of shape (...,4) uvtw|hkil : numpy.ndarray, shape (...,4)
MillerBravais indices of [uvtw] direction or (hkil) plane normal. MillerBravais indices of [uvtw] direction or (hkil) plane normal.
""" """
@ -636,7 +638,7 @@ def Miller_to_Bravais(*, uvw: np.ndarray = None, hkl: np.ndarray = None) -> np.n
return np.einsum('il,...l',basis,axis) return np.einsum('il,...l',basis,axis)
def dict_prune(d: Dict[Any, Any]) -> Dict[Any, Any]: def dict_prune(d: Dict) -> Dict:
""" """
Recursively remove empty dictionaries. Recursively remove empty dictionaries.
@ -662,7 +664,7 @@ def dict_prune(d: Dict[Any, Any]) -> Dict[Any, Any]:
return new return new
def dict_flatten(d: Dict[Any, Any]) -> Dict[Any, Any]: def dict_flatten(d: Dict) -> Dict:
""" """
Recursively remove keys of single-entry dictionaries. Recursively remove keys of single-entry dictionaries.