Merge branch 'typehints_vtk_util' into 'development'

02 Typehints vtk util

See merge request damask/DAMASK!502
This commit is contained in:
Philip Eisenlohr 2022-01-23 09:31:47 +00:00
commit 160eb1c600
4 changed files with 140 additions and 110 deletions

View File

@ -4,7 +4,7 @@ import warnings
import multiprocessing as mp
from functools import partial
import typing
from typing import Union, Optional, TextIO, List, Sequence
from typing import Union, Optional, TextIO, List, Sequence, Literal
from pathlib import Path
import numpy as np
@ -70,7 +70,7 @@ class Grid:
])
def __copy__(self) -> "Grid":
def __copy__(self) -> 'Grid':
"""Create deep copy."""
return copy.deepcopy(self)
@ -161,7 +161,7 @@ class Grid:
@staticmethod
def load(fname: Union[str, Path]) -> "Grid":
def load(fname: Union[str, Path]) -> 'Grid':
"""
Load from VTK image data file.
@ -190,7 +190,7 @@ class Grid:
@typing. no_type_check
@staticmethod
def load_ASCII(fname)-> "Grid":
def load_ASCII(fname)-> 'Grid':
"""
Load from geom file.
@ -264,7 +264,7 @@ class Grid:
@staticmethod
def load_Neper(fname: Union[str, Path]) -> "Grid":
def load_Neper(fname: Union[str, Path]) -> 'Grid':
"""
Load from Neper VTK file.
@ -279,7 +279,7 @@ class Grid:
Grid-based geometry from file.
"""
v = VTK.load(fname,'vtkImageData')
v = VTK.load(fname,'ImageData')
cells = np.array(v.vtk_data.GetDimensions())-1
bbox = np.array(v.vtk_data.GetBounds()).reshape(3,2).T
@ -292,7 +292,7 @@ class Grid:
def load_DREAM3D(fname: Union[str, Path],
feature_IDs: str = None, cell_data: str = None,
phases: str = 'Phases', Euler_angles: str = 'EulerAngles',
base_group: str = None) -> "Grid":
base_group: str = None) -> 'Grid':
"""
Load DREAM.3D (HDF5) file.
@ -354,7 +354,7 @@ class Grid:
@staticmethod
def from_table(table: Table,
coordinates: str,
labels: Union[str, Sequence[str]]) -> "Grid":
labels: Union[str, Sequence[str]]) -> 'Grid':
"""
Create grid from ASCII table.
@ -422,6 +422,7 @@ class Grid:
Grid-based geometry from tessellation.
"""
weights_p: FloatSequence
if periodic:
weights_p = np.tile(weights,27) # Laguerre weights (1,2,3,1,2,3,...,1,2,3)
seeds_p = np.vstack((seeds -np.array([size[0],0.,0.]),seeds, seeds +np.array([size[0],0.,0.])))
@ -452,7 +453,7 @@ class Grid:
size: FloatSequence,
seeds: np.ndarray,
material: IntSequence = None,
periodic: bool = True) -> "Grid":
periodic: bool = True) -> 'Grid':
"""
Create grid from Voronoi tessellation.
@ -538,7 +539,7 @@ class Grid:
surface: str,
threshold: float = 0.0,
periods: int = 1,
materials: IntSequence = (0,1)) -> "Grid":
materials: IntSequence = (0,1)) -> 'Grid':
"""
Create grid from definition of triply periodic minimal surface.
@ -684,7 +685,7 @@ class Grid:
fill: int = None,
R: Rotation = Rotation(),
inverse: bool = False,
periodic: bool = True) -> "Grid":
periodic: bool = True) -> 'Grid':
"""
Insert a primitive geometric object at a given position.
@ -769,7 +770,7 @@ class Grid:
)
def mirror(self, directions: Sequence[str], reflect: bool = False) -> "Grid":
def mirror(self, directions: Sequence[str], reflect: bool = False) -> 'Grid':
"""
Mirror grid along given directions.
@ -821,7 +822,7 @@ class Grid:
)
def flip(self, directions: Sequence[str]) -> "Grid":
def flip(self, directions: Union[Literal['x', 'y', 'z'], Sequence[Literal['x', 'y', 'z']]]) -> 'Grid':
"""
Flip grid along given directions.
@ -851,7 +852,7 @@ class Grid:
)
def scale(self, cells: IntSequence, periodic: bool = True) -> "Grid":
def scale(self, cells: IntSequence, periodic: bool = True) -> 'Grid':
"""
Scale grid to new cells.
@ -898,7 +899,7 @@ class Grid:
def clean(self,
stencil: int = 3,
selection: IntSequence = None,
periodic: bool = True) -> "Grid":
periodic: bool = True) -> 'Grid':
"""
Smooth grid by selecting most frequent material index within given stencil at each location.
@ -938,7 +939,7 @@ class Grid:
)
def renumber(self) -> "Grid":
def renumber(self) -> 'Grid':
"""
Renumber sorted material indices as 0,...,N-1.
@ -957,7 +958,7 @@ class Grid:
)
def rotate(self, R: Rotation, fill: int = None) -> "Grid":
def rotate(self, R: Rotation, fill: int = None) -> 'Grid':
"""
Rotate grid (pad if required).
@ -997,7 +998,7 @@ class Grid:
def canvas(self,
cells: IntSequence = None,
offset: IntSequence = None,
fill: int = None) -> "Grid":
fill: int = None) -> 'Grid':
"""
Crop or enlarge/pad grid.
@ -1048,7 +1049,7 @@ class Grid:
)
def substitute(self, from_material: IntSequence, to_material: IntSequence) -> "Grid":
def substitute(self, from_material: IntSequence, to_material: IntSequence) -> 'Grid':
"""
Substitute material indices.
@ -1076,7 +1077,7 @@ class Grid:
)
def sort(self) -> "Grid":
def sort(self) -> 'Grid':
"""
Sort material indices such that min(material) is located at (0,0,0).
@ -1102,7 +1103,7 @@ class Grid:
vicinity: int = 1,
offset: int = None,
trigger: IntSequence = [],
periodic: bool = True) -> "Grid":
periodic: bool = True) -> 'Grid':
"""
Offset material index of points in the vicinity of xxx.

View File

@ -2,6 +2,7 @@ import os
import warnings
import multiprocessing as mp
from pathlib import Path
from typing import Union, Literal, List
import numpy as np
import vtk
@ -9,6 +10,7 @@ from vtk.util.numpy_support import numpy_to_vtk as np_to_vtk
from vtk.util.numpy_support import numpy_to_vtkIdTypeArray as np_to_vtkIdTypeArray
from vtk.util.numpy_support import vtk_to_numpy as vtk_to_np
from ._typehints import FloatSequence, IntSequence
from . import util
from . import Table
@ -20,7 +22,7 @@ class VTK:
High-level interface to VTK.
"""
def __init__(self,vtk_data):
def __init__(self, vtk_data: vtk.vtkDataSet):
"""
New spatial visualization.
@ -36,7 +38,7 @@ class VTK:
@staticmethod
def from_image_data(cells,size,origin=np.zeros(3)):
def from_image_data(cells: IntSequence, size: FloatSequence, origin: FloatSequence = np.zeros(3)) -> 'VTK':
"""
Create VTK of type vtk.vtkImageData.
@ -60,13 +62,13 @@ class VTK:
vtk_data = vtk.vtkImageData()
vtk_data.SetDimensions(*(np.array(cells)+1))
vtk_data.SetOrigin(*(np.array(origin)))
vtk_data.SetSpacing(*(size/cells))
vtk_data.SetSpacing(*(np.array(size)/np.array(cells)))
return VTK(vtk_data)
@staticmethod
def from_rectilinear_grid(grid,size,origin=np.zeros(3)):
def from_rectilinear_grid(grid: np.ndarray, size: FloatSequence, origin: FloatSequence = np.zeros(3)) -> 'VTK':
"""
Create VTK of type vtk.vtkRectilinearGrid.
@ -98,7 +100,7 @@ class VTK:
@staticmethod
def from_unstructured_grid(nodes,connectivity,cell_type):
def from_unstructured_grid(nodes: np.ndarray, connectivity: np.ndarray, cell_type: str) -> 'VTK':
"""
Create VTK of type vtk.vtkUnstructuredGrid.
@ -138,7 +140,7 @@ class VTK:
@staticmethod
def from_poly_data(points):
def from_poly_data(points: np.ndarray) -> 'VTK':
"""
Create VTK of type vtk.polyData.
@ -172,15 +174,17 @@ class VTK:
@staticmethod
def load(fname,dataset_type=None):
def load(fname: Union[str, Path],
dataset_type: Literal['ImageData', 'UnstructuredGrid', 'PolyData'] = None) -> 'VTK':
"""
Load from VTK file.
Parameters
----------
fname : str or pathlib.Path
Filename for reading. Valid extensions are .vti, .vtr, .vtu, .vtp, and .vtk.
dataset_type : {'vtkImageData', ''vtkRectilinearGrid', 'vtkUnstructuredGrid', 'vtkPolyData'}, optional
Filename for reading.
Valid extensions are .vti, .vtr, .vtu, .vtp, and .vtk.
dataset_type : {'ImageData', 'UnstructuredGrid', 'PolyData'}, optional
Name of the vtk.vtkDataSet subclass when opening a .vtk file.
Returns
@ -234,7 +238,7 @@ class VTK:
def _write(writer):
"""Wrapper for parallel writing."""
writer.Write()
def save(self,fname,parallel=True,compress=True):
def save(self, fname: Union[str, Path], parallel: bool = True, compress: bool = True):
"""
Save as VTK file.
@ -280,7 +284,7 @@ class VTK:
# Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data
# Needs support for damask.Table
def add(self,data,label=None):
def add(self, data: Union[np.ndarray, np.ma.MaskedArray], label: str = None):
"""
Add data to either cells or points.
@ -327,7 +331,7 @@ class VTK:
raise TypeError
def get(self,label):
def get(self, label: str) -> np.ndarray:
"""
Get either cell or point data.
@ -369,7 +373,7 @@ class VTK:
raise ValueError(f'Array "{label}" not found.')
def get_comments(self):
def get_comments(self) -> List[str]:
"""Return the comments."""
fielddata = self.vtk_data.GetFieldData()
for a in range(fielddata.GetNumberOfArrays()):
@ -379,7 +383,7 @@ class VTK:
return []
def set_comments(self,comments):
def set_comments(self, comments: Union[str, List[str]]):
"""
Set comments.
@ -396,7 +400,7 @@ class VTK:
self.vtk_data.GetFieldData().AddArray(s)
def add_comments(self,comments):
def add_comments(self, comments: Union[str, List[str]]):
"""
Add comments.
@ -409,7 +413,7 @@ class VTK:
self.set_comments(self.get_comments() + ([comments] if isinstance(comments,str) else comments))
def __repr__(self):
def __repr__(self) -> str:
"""ASCII representation of the VTK data."""
writer = vtk.vtkDataSetWriter()
writer.SetHeader(f'# {util.execution_stamp("VTK")}')
@ -419,7 +423,7 @@ class VTK:
return writer.GetOutputString()
def show(self) -> None:
def show(self):
"""
Render.

View File

@ -79,7 +79,7 @@ def from_Poisson_disc(size: _FloatSequence, N_seeds: int, N_candidates: int, dis
s = 1
i = 0
progress = _util._ProgressBar(N_seeds+1,'',50)
progress = _util.ProgressBar(N_seeds+1,'',50)
while s < N_seeds:
i += 1
candidates = rng.random((N_candidates,3))*_np.broadcast_to(size,(N_candidates,3))

View File

@ -7,12 +7,16 @@ import subprocess
import shlex
import re
import fractions
import collections.abc as abc
from functools import reduce
from typing import Union, Tuple, Iterable, Callable, Dict, List, Any, Literal, Optional
from pathlib import Path
import numpy as np
import h5py
from . import version
from ._typehints import FloatSequence
# limit visibility
__all__=[
@ -50,16 +54,16 @@ _colors = {
####################################################################################################
# Functions
####################################################################################################
def srepr(arg,glue = '\n'):
def srepr(msg, glue: str = '\n') -> str:
r"""
Join items with glue string.
Parameters
----------
arg : iterable
msg : object with __repr__ or sequence of objects with __repr__
Items to join.
glue : str, optional
Glue used for joining operation. Defaults to \n.
Glue used for joining operation. Defaults to '\n'.
Returns
-------
@ -67,21 +71,21 @@ def srepr(arg,glue = '\n'):
String representation of the joined items.
"""
if (not hasattr(arg, 'strip') and
(hasattr(arg, '__getitem__') or
hasattr(arg, '__iter__'))):
return glue.join(str(x) for x in arg)
if (not hasattr(msg, 'strip') and
(hasattr(msg, '__getitem__') or
hasattr(msg, '__iter__'))):
return glue.join(str(x) for x in msg)
else:
return arg if isinstance(arg,str) else repr(arg)
return msg if isinstance(msg,str) else repr(msg)
def emph(what):
def emph(msg) -> str:
"""
Format with emphasis.
Parameters
----------
what : object with __repr__ or iterable of objects with __repr__.
msg : object with __repr__ or sequence of objects with __repr__
Message to format.
Returns
@ -90,15 +94,15 @@ def emph(what):
Formatted string representation of the joined items.
"""
return _colors['bold']+srepr(what)+_colors['end_color']
return _colors['bold']+srepr(msg)+_colors['end_color']
def deemph(what):
def deemph(msg) -> str:
"""
Format with deemphasis.
Parameters
----------
what : object with __repr__ or iterable of objects with __repr__.
msg : object with __repr__ or sequence of objects with __repr__
Message to format.
Returns
@ -107,15 +111,15 @@ def deemph(what):
Formatted string representation of the joined items.
"""
return _colors['dim']+srepr(what)+_colors['end_color']
return _colors['dim']+srepr(msg)+_colors['end_color']
def warn(what):
def warn(msg) -> str:
"""
Format for warning.
Parameters
----------
what : object with __repr__ or iterable of objects with __repr__.
msg : object with __repr__ or sequence of objects with __repr__
Message to format.
Returns
@ -124,15 +128,15 @@ def warn(what):
Formatted string representation of the joined items.
"""
return _colors['warning']+emph(what)+_colors['end_color']
return _colors['warning']+emph(msg)+_colors['end_color']
def strikeout(what):
def strikeout(msg) -> str:
"""
Format as strikeout.
Parameters
----------
what : object with __repr__ or iterable of objects with __repr__.
msg : object with __repr__ or iterable of objects with __repr__
Message to format.
Returns
@ -141,10 +145,10 @@ def strikeout(what):
Formatted string representation of the joined items.
"""
return _colors['crossout']+srepr(what)+_colors['end_color']
return _colors['crossout']+srepr(msg)+_colors['end_color']
def run(cmd,wd='./',env=None,timeout=None):
def run(cmd: str, wd: str = './', env: Dict[str, str] = None, timeout: int = None) -> Tuple[str, str]:
"""
Run a command.
@ -153,7 +157,7 @@ def run(cmd,wd='./',env=None,timeout=None):
cmd : str
Command to be executed.
wd : str, optional
Working directory of process. Defaults to ./ .
Working directory of process. Defaults to './'.
env : dict, optional
Environment for execution.
timeout : integer, optional
@ -185,7 +189,7 @@ def run(cmd,wd='./',env=None,timeout=None):
execute = run
def natural_sort(key):
def natural_sort(key: str) -> List[Union[int, str]]:
"""
Natural sort.
@ -200,7 +204,10 @@ def natural_sort(key):
return [ convert(c) for c in re.split('([0-9]+)', key) ]
def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
def show_progress(iterable: Iterable,
N_iter: int = None,
prefix: str = '',
bar_length: int = 50) -> Any:
"""
Decorate a loop with a progress bar.
@ -208,39 +215,49 @@ def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
Parameters
----------
iterable : iterable or function with yield statement
Iterable (or function with yield statement) to be decorated.
iterable : iterable
Iterable to be decorated.
N_iter : int, optional
Total number of iterations. Required unless obtainable as len(iterable).
Total number of iterations. Required if iterable is not a sequence.
prefix : str, optional
Prefix string.
bar_length : int, optional
Length of progress bar in characters. Defaults to 50.
"""
if N_iter in [0,1] or (hasattr(iterable,'__len__') and len(iterable) <= 1):
if isinstance(iterable,abc.Sequence):
if N_iter is None:
N = len(iterable)
else:
raise ValueError('N_iter given for sequence')
else:
if N_iter is None:
raise ValueError('N_iter not given')
else:
N = N_iter
if N <= 1:
for item in iterable:
yield item
else:
status = _ProgressBar(N_iter if N_iter is not None else len(iterable),prefix,bar_length)
status = ProgressBar(N,prefix,bar_length)
for i,item in enumerate(iterable):
yield item
status.update(i)
def scale_to_coprime(v):
def scale_to_coprime(v: FloatSequence) -> np.ndarray:
"""
Scale vector to co-prime (relatively prime) integers.
Parameters
----------
v : numpy.ndarray of shape (:)
v : sequence of float, len (:)
Vector to scale.
Returns
-------
m : numpy.ndarray of shape (:)
m : numpy.ndarray, shape (:)
Vector scaled to co-prime numbers.
"""
@ -257,17 +274,21 @@ def scale_to_coprime(v):
except AttributeError:
return a * b // np.gcd(a, b)
m = (np.array(v) * reduce(lcm, map(lambda x: int(get_square_denominator(x)),v)) ** 0.5).astype(int)
v_ = np.array(v)
m = (v_ * reduce(lcm, map(lambda x: int(get_square_denominator(x)),v_))**0.5).astype(int)
m = m//reduce(np.gcd,m)
with np.errstate(invalid='ignore'):
if not np.allclose(np.ma.masked_invalid(v/m),v[np.argmax(abs(v))]/m[np.argmax(abs(v))]):
raise ValueError(f'Invalid result {m} for input {v}. Insufficient precision?')
if not np.allclose(np.ma.masked_invalid(v_/m),v_[np.argmax(abs(v_))]/m[np.argmax(abs(v_))]):
raise ValueError(f'Invalid result {m} for input {v_}. Insufficient precision?')
return m
def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
def project_equal_angle(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
"""
Apply equal-angle projection to vector.
@ -275,9 +296,8 @@ def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
----------
vector : numpy.ndarray, shape (...,3)
Vector coordinates to be projected.
direction : str
Projection direction 'x', 'y', or 'z'.
Defaults to 'z'.
direction : {'x', 'y', 'z'}
Projection direction. Defaults to 'z'.
normalize : bool
Ensure unit length of input vector. Defaults to True.
keepdims : bool
@ -309,7 +329,10 @@ def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
def project_equal_area(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
"""
Apply equal-area projection to vector.
@ -317,9 +340,8 @@ def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
----------
vector : numpy.ndarray, shape (...,3)
Vector coordinates to be projected.
direction : str
Projection direction 'x', 'y', or 'z'.
Defaults to 'z'.
direction : {'x', 'y', 'z'}
Projection direction. Defaults to 'z'.
normalize : bool
Ensure unit length of input vector. Defaults to True.
keepdims : bool
@ -351,15 +373,14 @@ def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def execution_stamp(class_name,function_name=None):
def execution_stamp(class_name: str, function_name: str = None) -> str:
"""Timestamp the execution of a (function within a) class."""
now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z')
_function_name = '' if function_name is None else f'.{function_name}'
return f'damask.{class_name}{_function_name} v{version} ({now})'
def hybrid_IA(dist,N,rng_seed=None):
def hybrid_IA(dist: np.ndarray, N: int, rng_seed = None) -> np.ndarray:
"""
Hybrid integer approximation.
@ -387,7 +408,10 @@ def hybrid_IA(dist,N,rng_seed=None):
return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]]
def shapeshifter(fro,to,mode='left',keep_ones=False):
def shapeshifter(fro: Tuple[int, ...],
to: Tuple[int, ...],
mode: Literal['left','right'] = 'left',
keep_ones: bool = False) -> Tuple[Optional[int], ...]:
"""
Return dimensions that reshape 'fro' to become broadcastable to 'to'.
@ -398,9 +422,9 @@ def shapeshifter(fro,to,mode='left',keep_ones=False):
to : tuple
Target shape of array after broadcasting.
len(to) cannot be less than len(fro).
mode : str, optional
mode : {'left', 'right'}, optional
Indicates whether new axes are preferably added to
either 'left' or 'right' of the original shape.
either left or right of the original shape.
Defaults to 'left'.
keep_ones : bool, optional
Treat '1' in fro as literal value instead of dimensional placeholder.
@ -434,21 +458,22 @@ def shapeshifter(fro,to,mode='left',keep_ones=False):
fro = (1,) if not len(fro) else fro
to = (1,) if not len(to) else to
try:
grp = re.match(beg[mode]
match = re.match(beg[mode]
+f',{sep[mode]}'.join(map(lambda x: f'{x}'
if x>1 or (keep_ones and len(fro)>1) else
'\\d+',fro))
+f',{end[mode]}',
','.join(map(str,to))+',').groups()
except AttributeError:
+f',{end[mode]}',','.join(map(str,to))+',')
assert match
grp = match.groups()
except AssertionError:
raise ValueError(f'Shapes can not be shifted {fro} --> {to}')
fill = ()
fill: Tuple[Optional[int], ...] = ()
for g,d in zip(grp,fro+(None,)):
fill += (1,)*g.count(',')+(d,)
return fill[:-1]
def shapeblender(a,b):
def shapeblender(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
"""
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
@ -476,7 +501,7 @@ def shapeblender(a,b):
return a + b[i:]
def extend_docstring(extra_docstring):
def extend_docstring(extra_docstring: str) -> Callable:
"""
Decorator: Append to function's docstring.
@ -492,7 +517,7 @@ def extend_docstring(extra_docstring):
return _decorator
def extended_docstring(f,extra_docstring):
def extended_docstring(f: Callable, extra_docstring: str) -> Callable:
"""
Decorator: Combine another function's docstring with a given docstring.
@ -510,7 +535,7 @@ def extended_docstring(f,extra_docstring):
return _decorator
def DREAM3D_base_group(fname):
def DREAM3D_base_group(fname: Union[str, Path]) -> str:
"""
Determine the base group of a DREAM.3D file.
@ -536,7 +561,7 @@ def DREAM3D_base_group(fname):
return base_group
def DREAM3D_cell_data_group(fname):
def DREAM3D_cell_data_group(fname: Union[str, Path]) -> str:
"""
Determine the cell data group of a DREAM.3D file.
@ -568,18 +593,18 @@ def DREAM3D_cell_data_group(fname):
return cell_data_group
def Bravais_to_Miller(*,uvtw=None,hkil=None):
def Bravais_to_Miller(*, uvtw: np.ndarray = None, hkil: np.ndarray = None) -> np.ndarray:
"""
Transform 4 MillerBravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl).
Parameters
----------
uvtw|hkil : numpy.ndarray of shape (...,4)
uvtw|hkil : numpy.ndarray, shape (...,4)
MillerBravais indices of crystallographic direction [uvtw] or plane normal (hkil).
Returns
-------
uvw|hkl : numpy.ndarray of shape (...,3)
uvw|hkl : numpy.ndarray, shape (...,3)
Miller indices of [uvw] direction or (hkl) plane normal.
"""
@ -595,18 +620,18 @@ def Bravais_to_Miller(*,uvtw=None,hkil=None):
return np.einsum('il,...l',basis,axis)
def Miller_to_Bravais(*,uvw=None,hkl=None):
def Miller_to_Bravais(*, uvw: np.ndarray = None, hkl: np.ndarray = None) -> np.ndarray:
"""
Transform 3 Miller indices to 4 MillerBravais indices of crystal direction [uvtw] or plane normal (hkil).
Parameters
----------
uvw|hkl : numpy.ndarray of shape (...,3)
uvw|hkl : numpy.ndarray, shape (...,3)
Miller indices of crystallographic direction [uvw] or plane normal (hkl).
Returns
-------
uvtw|hkil : numpy.ndarray of shape (...,4)
uvtw|hkil : numpy.ndarray, shape (...,4)
MillerBravais indices of [uvtw] direction or (hkil) plane normal.
"""
@ -624,7 +649,7 @@ def Miller_to_Bravais(*,uvw=None,hkl=None):
return np.einsum('il,...l',basis,axis)
def dict_prune(d):
def dict_prune(d: Dict) -> Dict:
"""
Recursively remove empty dictionaries.
@ -650,7 +675,7 @@ def dict_prune(d):
return new
def dict_flatten(d):
def dict_flatten(d: Dict) -> Dict:
"""
Recursively remove keys of single-entry dictionaries.
@ -678,14 +703,14 @@ def dict_flatten(d):
####################################################################################################
# Classes
####################################################################################################
class _ProgressBar:
class ProgressBar:
"""
Report progress of an interation as a status bar.
Works for 0-based loops, ETA is estimated by linear extrapolation.
"""
def __init__(self,total,prefix,bar_length):
def __init__(self, total: int, prefix: str, bar_length: int):
"""
Set current time as basis for ETA estimation.
@ -708,7 +733,7 @@ class _ProgressBar:
sys.stderr.write(f"{self.prefix} {''*self.bar_length} 0% ETA n/a")
sys.stderr.flush()
def update(self,iteration):
def update(self, iteration: int) -> None:
fraction = (iteration+1) / self.total
filled_length = int(self.bar_length * fraction)