Merge branch 'typehints_vtk_util' into 'development'

02 Typehints vtk util

See merge request damask/DAMASK!502
This commit is contained in:
Philip Eisenlohr 2022-01-23 09:31:47 +00:00
commit 160eb1c600
4 changed files with 140 additions and 110 deletions

View File

@ -4,7 +4,7 @@ import warnings
import multiprocessing as mp import multiprocessing as mp
from functools import partial from functools import partial
import typing import typing
from typing import Union, Optional, TextIO, List, Sequence from typing import Union, Optional, TextIO, List, Sequence, Literal
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@ -70,7 +70,7 @@ class Grid:
]) ])
def __copy__(self) -> "Grid": def __copy__(self) -> 'Grid':
"""Create deep copy.""" """Create deep copy."""
return copy.deepcopy(self) return copy.deepcopy(self)
@ -161,7 +161,7 @@ class Grid:
@staticmethod @staticmethod
def load(fname: Union[str, Path]) -> "Grid": def load(fname: Union[str, Path]) -> 'Grid':
""" """
Load from VTK image data file. Load from VTK image data file.
@ -190,7 +190,7 @@ class Grid:
@typing. no_type_check @typing. no_type_check
@staticmethod @staticmethod
def load_ASCII(fname)-> "Grid": def load_ASCII(fname)-> 'Grid':
""" """
Load from geom file. Load from geom file.
@ -264,7 +264,7 @@ class Grid:
@staticmethod @staticmethod
def load_Neper(fname: Union[str, Path]) -> "Grid": def load_Neper(fname: Union[str, Path]) -> 'Grid':
""" """
Load from Neper VTK file. Load from Neper VTK file.
@ -279,7 +279,7 @@ class Grid:
Grid-based geometry from file. Grid-based geometry from file.
""" """
v = VTK.load(fname,'vtkImageData') v = VTK.load(fname,'ImageData')
cells = np.array(v.vtk_data.GetDimensions())-1 cells = np.array(v.vtk_data.GetDimensions())-1
bbox = np.array(v.vtk_data.GetBounds()).reshape(3,2).T bbox = np.array(v.vtk_data.GetBounds()).reshape(3,2).T
@ -292,7 +292,7 @@ class Grid:
def load_DREAM3D(fname: Union[str, Path], def load_DREAM3D(fname: Union[str, Path],
feature_IDs: str = None, cell_data: str = None, feature_IDs: str = None, cell_data: str = None,
phases: str = 'Phases', Euler_angles: str = 'EulerAngles', phases: str = 'Phases', Euler_angles: str = 'EulerAngles',
base_group: str = None) -> "Grid": base_group: str = None) -> 'Grid':
""" """
Load DREAM.3D (HDF5) file. Load DREAM.3D (HDF5) file.
@ -354,7 +354,7 @@ class Grid:
@staticmethod @staticmethod
def from_table(table: Table, def from_table(table: Table,
coordinates: str, coordinates: str,
labels: Union[str, Sequence[str]]) -> "Grid": labels: Union[str, Sequence[str]]) -> 'Grid':
""" """
Create grid from ASCII table. Create grid from ASCII table.
@ -422,6 +422,7 @@ class Grid:
Grid-based geometry from tessellation. Grid-based geometry from tessellation.
""" """
weights_p: FloatSequence
if periodic: if periodic:
weights_p = np.tile(weights,27) # Laguerre weights (1,2,3,1,2,3,...,1,2,3) weights_p = np.tile(weights,27) # Laguerre weights (1,2,3,1,2,3,...,1,2,3)
seeds_p = np.vstack((seeds -np.array([size[0],0.,0.]),seeds, seeds +np.array([size[0],0.,0.]))) seeds_p = np.vstack((seeds -np.array([size[0],0.,0.]),seeds, seeds +np.array([size[0],0.,0.])))
@ -452,7 +453,7 @@ class Grid:
size: FloatSequence, size: FloatSequence,
seeds: np.ndarray, seeds: np.ndarray,
material: IntSequence = None, material: IntSequence = None,
periodic: bool = True) -> "Grid": periodic: bool = True) -> 'Grid':
""" """
Create grid from Voronoi tessellation. Create grid from Voronoi tessellation.
@ -538,7 +539,7 @@ class Grid:
surface: str, surface: str,
threshold: float = 0.0, threshold: float = 0.0,
periods: int = 1, periods: int = 1,
materials: IntSequence = (0,1)) -> "Grid": materials: IntSequence = (0,1)) -> 'Grid':
""" """
Create grid from definition of triply periodic minimal surface. Create grid from definition of triply periodic minimal surface.
@ -684,7 +685,7 @@ class Grid:
fill: int = None, fill: int = None,
R: Rotation = Rotation(), R: Rotation = Rotation(),
inverse: bool = False, inverse: bool = False,
periodic: bool = True) -> "Grid": periodic: bool = True) -> 'Grid':
""" """
Insert a primitive geometric object at a given position. Insert a primitive geometric object at a given position.
@ -769,7 +770,7 @@ class Grid:
) )
def mirror(self, directions: Sequence[str], reflect: bool = False) -> "Grid": def mirror(self, directions: Sequence[str], reflect: bool = False) -> 'Grid':
""" """
Mirror grid along given directions. Mirror grid along given directions.
@ -821,7 +822,7 @@ class Grid:
) )
def flip(self, directions: Sequence[str]) -> "Grid": def flip(self, directions: Union[Literal['x', 'y', 'z'], Sequence[Literal['x', 'y', 'z']]]) -> 'Grid':
""" """
Flip grid along given directions. Flip grid along given directions.
@ -851,7 +852,7 @@ class Grid:
) )
def scale(self, cells: IntSequence, periodic: bool = True) -> "Grid": def scale(self, cells: IntSequence, periodic: bool = True) -> 'Grid':
""" """
Scale grid to new cells. Scale grid to new cells.
@ -898,7 +899,7 @@ class Grid:
def clean(self, def clean(self,
stencil: int = 3, stencil: int = 3,
selection: IntSequence = None, selection: IntSequence = None,
periodic: bool = True) -> "Grid": periodic: bool = True) -> 'Grid':
""" """
Smooth grid by selecting most frequent material index within given stencil at each location. Smooth grid by selecting most frequent material index within given stencil at each location.
@ -938,7 +939,7 @@ class Grid:
) )
def renumber(self) -> "Grid": def renumber(self) -> 'Grid':
""" """
Renumber sorted material indices as 0,...,N-1. Renumber sorted material indices as 0,...,N-1.
@ -957,7 +958,7 @@ class Grid:
) )
def rotate(self, R: Rotation, fill: int = None) -> "Grid": def rotate(self, R: Rotation, fill: int = None) -> 'Grid':
""" """
Rotate grid (pad if required). Rotate grid (pad if required).
@ -997,7 +998,7 @@ class Grid:
def canvas(self, def canvas(self,
cells: IntSequence = None, cells: IntSequence = None,
offset: IntSequence = None, offset: IntSequence = None,
fill: int = None) -> "Grid": fill: int = None) -> 'Grid':
""" """
Crop or enlarge/pad grid. Crop or enlarge/pad grid.
@ -1048,7 +1049,7 @@ class Grid:
) )
def substitute(self, from_material: IntSequence, to_material: IntSequence) -> "Grid": def substitute(self, from_material: IntSequence, to_material: IntSequence) -> 'Grid':
""" """
Substitute material indices. Substitute material indices.
@ -1076,7 +1077,7 @@ class Grid:
) )
def sort(self) -> "Grid": def sort(self) -> 'Grid':
""" """
Sort material indices such that min(material) is located at (0,0,0). Sort material indices such that min(material) is located at (0,0,0).
@ -1102,7 +1103,7 @@ class Grid:
vicinity: int = 1, vicinity: int = 1,
offset: int = None, offset: int = None,
trigger: IntSequence = [], trigger: IntSequence = [],
periodic: bool = True) -> "Grid": periodic: bool = True) -> 'Grid':
""" """
Offset material index of points in the vicinity of xxx. Offset material index of points in the vicinity of xxx.

View File

@ -2,6 +2,7 @@ import os
import warnings import warnings
import multiprocessing as mp import multiprocessing as mp
from pathlib import Path from pathlib import Path
from typing import Union, Literal, List
import numpy as np import numpy as np
import vtk import vtk
@ -9,6 +10,7 @@ from vtk.util.numpy_support import numpy_to_vtk as np_to_vtk
from vtk.util.numpy_support import numpy_to_vtkIdTypeArray as np_to_vtkIdTypeArray from vtk.util.numpy_support import numpy_to_vtkIdTypeArray as np_to_vtkIdTypeArray
from vtk.util.numpy_support import vtk_to_numpy as vtk_to_np from vtk.util.numpy_support import vtk_to_numpy as vtk_to_np
from ._typehints import FloatSequence, IntSequence
from . import util from . import util
from . import Table from . import Table
@ -20,7 +22,7 @@ class VTK:
High-level interface to VTK. High-level interface to VTK.
""" """
def __init__(self,vtk_data): def __init__(self, vtk_data: vtk.vtkDataSet):
""" """
New spatial visualization. New spatial visualization.
@ -36,7 +38,7 @@ class VTK:
@staticmethod @staticmethod
def from_image_data(cells,size,origin=np.zeros(3)): def from_image_data(cells: IntSequence, size: FloatSequence, origin: FloatSequence = np.zeros(3)) -> 'VTK':
""" """
Create VTK of type vtk.vtkImageData. Create VTK of type vtk.vtkImageData.
@ -60,13 +62,13 @@ class VTK:
vtk_data = vtk.vtkImageData() vtk_data = vtk.vtkImageData()
vtk_data.SetDimensions(*(np.array(cells)+1)) vtk_data.SetDimensions(*(np.array(cells)+1))
vtk_data.SetOrigin(*(np.array(origin))) vtk_data.SetOrigin(*(np.array(origin)))
vtk_data.SetSpacing(*(size/cells)) vtk_data.SetSpacing(*(np.array(size)/np.array(cells)))
return VTK(vtk_data) return VTK(vtk_data)
@staticmethod @staticmethod
def from_rectilinear_grid(grid,size,origin=np.zeros(3)): def from_rectilinear_grid(grid: np.ndarray, size: FloatSequence, origin: FloatSequence = np.zeros(3)) -> 'VTK':
""" """
Create VTK of type vtk.vtkRectilinearGrid. Create VTK of type vtk.vtkRectilinearGrid.
@ -98,7 +100,7 @@ class VTK:
@staticmethod @staticmethod
def from_unstructured_grid(nodes,connectivity,cell_type): def from_unstructured_grid(nodes: np.ndarray, connectivity: np.ndarray, cell_type: str) -> 'VTK':
""" """
Create VTK of type vtk.vtkUnstructuredGrid. Create VTK of type vtk.vtkUnstructuredGrid.
@ -138,7 +140,7 @@ class VTK:
@staticmethod @staticmethod
def from_poly_data(points): def from_poly_data(points: np.ndarray) -> 'VTK':
""" """
Create VTK of type vtk.polyData. Create VTK of type vtk.polyData.
@ -172,15 +174,17 @@ class VTK:
@staticmethod @staticmethod
def load(fname,dataset_type=None): def load(fname: Union[str, Path],
dataset_type: Literal['ImageData', 'UnstructuredGrid', 'PolyData'] = None) -> 'VTK':
""" """
Load from VTK file. Load from VTK file.
Parameters Parameters
---------- ----------
fname : str or pathlib.Path fname : str or pathlib.Path
Filename for reading. Valid extensions are .vti, .vtr, .vtu, .vtp, and .vtk. Filename for reading.
dataset_type : {'vtkImageData', ''vtkRectilinearGrid', 'vtkUnstructuredGrid', 'vtkPolyData'}, optional Valid extensions are .vti, .vtr, .vtu, .vtp, and .vtk.
dataset_type : {'ImageData', 'UnstructuredGrid', 'PolyData'}, optional
Name of the vtk.vtkDataSet subclass when opening a .vtk file. Name of the vtk.vtkDataSet subclass when opening a .vtk file.
Returns Returns
@ -234,7 +238,7 @@ class VTK:
def _write(writer): def _write(writer):
"""Wrapper for parallel writing.""" """Wrapper for parallel writing."""
writer.Write() writer.Write()
def save(self,fname,parallel=True,compress=True): def save(self, fname: Union[str, Path], parallel: bool = True, compress: bool = True):
""" """
Save as VTK file. Save as VTK file.
@ -280,7 +284,7 @@ class VTK:
# Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data # Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data
# Needs support for damask.Table # Needs support for damask.Table
def add(self,data,label=None): def add(self, data: Union[np.ndarray, np.ma.MaskedArray], label: str = None):
""" """
Add data to either cells or points. Add data to either cells or points.
@ -327,7 +331,7 @@ class VTK:
raise TypeError raise TypeError
def get(self,label): def get(self, label: str) -> np.ndarray:
""" """
Get either cell or point data. Get either cell or point data.
@ -369,7 +373,7 @@ class VTK:
raise ValueError(f'Array "{label}" not found.') raise ValueError(f'Array "{label}" not found.')
def get_comments(self): def get_comments(self) -> List[str]:
"""Return the comments.""" """Return the comments."""
fielddata = self.vtk_data.GetFieldData() fielddata = self.vtk_data.GetFieldData()
for a in range(fielddata.GetNumberOfArrays()): for a in range(fielddata.GetNumberOfArrays()):
@ -379,7 +383,7 @@ class VTK:
return [] return []
def set_comments(self,comments): def set_comments(self, comments: Union[str, List[str]]):
""" """
Set comments. Set comments.
@ -396,7 +400,7 @@ class VTK:
self.vtk_data.GetFieldData().AddArray(s) self.vtk_data.GetFieldData().AddArray(s)
def add_comments(self,comments): def add_comments(self, comments: Union[str, List[str]]):
""" """
Add comments. Add comments.
@ -409,7 +413,7 @@ class VTK:
self.set_comments(self.get_comments() + ([comments] if isinstance(comments,str) else comments)) self.set_comments(self.get_comments() + ([comments] if isinstance(comments,str) else comments))
def __repr__(self): def __repr__(self) -> str:
"""ASCII representation of the VTK data.""" """ASCII representation of the VTK data."""
writer = vtk.vtkDataSetWriter() writer = vtk.vtkDataSetWriter()
writer.SetHeader(f'# {util.execution_stamp("VTK")}') writer.SetHeader(f'# {util.execution_stamp("VTK")}')
@ -419,7 +423,7 @@ class VTK:
return writer.GetOutputString() return writer.GetOutputString()
def show(self) -> None: def show(self):
""" """
Render. Render.

View File

@ -79,7 +79,7 @@ def from_Poisson_disc(size: _FloatSequence, N_seeds: int, N_candidates: int, dis
s = 1 s = 1
i = 0 i = 0
progress = _util._ProgressBar(N_seeds+1,'',50) progress = _util.ProgressBar(N_seeds+1,'',50)
while s < N_seeds: while s < N_seeds:
i += 1 i += 1
candidates = rng.random((N_candidates,3))*_np.broadcast_to(size,(N_candidates,3)) candidates = rng.random((N_candidates,3))*_np.broadcast_to(size,(N_candidates,3))

View File

@ -7,12 +7,16 @@ import subprocess
import shlex import shlex
import re import re
import fractions import fractions
import collections.abc as abc
from functools import reduce from functools import reduce
from typing import Union, Tuple, Iterable, Callable, Dict, List, Any, Literal, Optional
from pathlib import Path
import numpy as np import numpy as np
import h5py import h5py
from . import version from . import version
from ._typehints import FloatSequence
# limit visibility # limit visibility
__all__=[ __all__=[
@ -50,16 +54,16 @@ _colors = {
#################################################################################################### ####################################################################################################
# Functions # Functions
#################################################################################################### ####################################################################################################
def srepr(arg,glue = '\n'): def srepr(msg, glue: str = '\n') -> str:
r""" r"""
Join items with glue string. Join items with glue string.
Parameters Parameters
---------- ----------
arg : iterable msg : object with __repr__ or sequence of objects with __repr__
Items to join. Items to join.
glue : str, optional glue : str, optional
Glue used for joining operation. Defaults to \n. Glue used for joining operation. Defaults to '\n'.
Returns Returns
------- -------
@ -67,21 +71,21 @@ def srepr(arg,glue = '\n'):
String representation of the joined items. String representation of the joined items.
""" """
if (not hasattr(arg, 'strip') and if (not hasattr(msg, 'strip') and
(hasattr(arg, '__getitem__') or (hasattr(msg, '__getitem__') or
hasattr(arg, '__iter__'))): hasattr(msg, '__iter__'))):
return glue.join(str(x) for x in arg) return glue.join(str(x) for x in msg)
else: else:
return arg if isinstance(arg,str) else repr(arg) return msg if isinstance(msg,str) else repr(msg)
def emph(what): def emph(msg) -> str:
""" """
Format with emphasis. Format with emphasis.
Parameters Parameters
---------- ----------
what : object with __repr__ or iterable of objects with __repr__. msg : object with __repr__ or sequence of objects with __repr__
Message to format. Message to format.
Returns Returns
@ -90,15 +94,15 @@ def emph(what):
Formatted string representation of the joined items. Formatted string representation of the joined items.
""" """
return _colors['bold']+srepr(what)+_colors['end_color'] return _colors['bold']+srepr(msg)+_colors['end_color']
def deemph(what): def deemph(msg) -> str:
""" """
Format with deemphasis. Format with deemphasis.
Parameters Parameters
---------- ----------
what : object with __repr__ or iterable of objects with __repr__. msg : object with __repr__ or sequence of objects with __repr__
Message to format. Message to format.
Returns Returns
@ -107,15 +111,15 @@ def deemph(what):
Formatted string representation of the joined items. Formatted string representation of the joined items.
""" """
return _colors['dim']+srepr(what)+_colors['end_color'] return _colors['dim']+srepr(msg)+_colors['end_color']
def warn(what): def warn(msg) -> str:
""" """
Format for warning. Format for warning.
Parameters Parameters
---------- ----------
what : object with __repr__ or iterable of objects with __repr__. msg : object with __repr__ or sequence of objects with __repr__
Message to format. Message to format.
Returns Returns
@ -124,15 +128,15 @@ def warn(what):
Formatted string representation of the joined items. Formatted string representation of the joined items.
""" """
return _colors['warning']+emph(what)+_colors['end_color'] return _colors['warning']+emph(msg)+_colors['end_color']
def strikeout(what): def strikeout(msg) -> str:
""" """
Format as strikeout. Format as strikeout.
Parameters Parameters
---------- ----------
what : object with __repr__ or iterable of objects with __repr__. msg : object with __repr__ or iterable of objects with __repr__
Message to format. Message to format.
Returns Returns
@ -141,10 +145,10 @@ def strikeout(what):
Formatted string representation of the joined items. Formatted string representation of the joined items.
""" """
return _colors['crossout']+srepr(what)+_colors['end_color'] return _colors['crossout']+srepr(msg)+_colors['end_color']
def run(cmd,wd='./',env=None,timeout=None): def run(cmd: str, wd: str = './', env: Dict[str, str] = None, timeout: int = None) -> Tuple[str, str]:
""" """
Run a command. Run a command.
@ -153,7 +157,7 @@ def run(cmd,wd='./',env=None,timeout=None):
cmd : str cmd : str
Command to be executed. Command to be executed.
wd : str, optional wd : str, optional
Working directory of process. Defaults to ./ . Working directory of process. Defaults to './'.
env : dict, optional env : dict, optional
Environment for execution. Environment for execution.
timeout : integer, optional timeout : integer, optional
@ -185,7 +189,7 @@ def run(cmd,wd='./',env=None,timeout=None):
execute = run execute = run
def natural_sort(key): def natural_sort(key: str) -> List[Union[int, str]]:
""" """
Natural sort. Natural sort.
@ -200,7 +204,10 @@ def natural_sort(key):
return [ convert(c) for c in re.split('([0-9]+)', key) ] return [ convert(c) for c in re.split('([0-9]+)', key) ]
def show_progress(iterable,N_iter=None,prefix='',bar_length=50): def show_progress(iterable: Iterable,
N_iter: int = None,
prefix: str = '',
bar_length: int = 50) -> Any:
""" """
Decorate a loop with a progress bar. Decorate a loop with a progress bar.
@ -208,39 +215,49 @@ def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
Parameters Parameters
---------- ----------
iterable : iterable or function with yield statement iterable : iterable
Iterable (or function with yield statement) to be decorated. Iterable to be decorated.
N_iter : int, optional N_iter : int, optional
Total number of iterations. Required unless obtainable as len(iterable). Total number of iterations. Required if iterable is not a sequence.
prefix : str, optional prefix : str, optional
Prefix string. Prefix string.
bar_length : int, optional bar_length : int, optional
Length of progress bar in characters. Defaults to 50. Length of progress bar in characters. Defaults to 50.
""" """
if N_iter in [0,1] or (hasattr(iterable,'__len__') and len(iterable) <= 1): if isinstance(iterable,abc.Sequence):
if N_iter is None:
N = len(iterable)
else:
raise ValueError('N_iter given for sequence')
else:
if N_iter is None:
raise ValueError('N_iter not given')
else:
N = N_iter
if N <= 1:
for item in iterable: for item in iterable:
yield item yield item
else: else:
status = _ProgressBar(N_iter if N_iter is not None else len(iterable),prefix,bar_length) status = ProgressBar(N,prefix,bar_length)
for i,item in enumerate(iterable): for i,item in enumerate(iterable):
yield item yield item
status.update(i) status.update(i)
def scale_to_coprime(v): def scale_to_coprime(v: FloatSequence) -> np.ndarray:
""" """
Scale vector to co-prime (relatively prime) integers. Scale vector to co-prime (relatively prime) integers.
Parameters Parameters
---------- ----------
v : numpy.ndarray of shape (:) v : sequence of float, len (:)
Vector to scale. Vector to scale.
Returns Returns
------- -------
m : numpy.ndarray of shape (:) m : numpy.ndarray, shape (:)
Vector scaled to co-prime numbers. Vector scaled to co-prime numbers.
""" """
@ -257,17 +274,21 @@ def scale_to_coprime(v):
except AttributeError: except AttributeError:
return a * b // np.gcd(a, b) return a * b // np.gcd(a, b)
m = (np.array(v) * reduce(lcm, map(lambda x: int(get_square_denominator(x)),v)) ** 0.5).astype(int) v_ = np.array(v)
m = (v_ * reduce(lcm, map(lambda x: int(get_square_denominator(x)),v_))**0.5).astype(int)
m = m//reduce(np.gcd,m) m = m//reduce(np.gcd,m)
with np.errstate(invalid='ignore'): with np.errstate(invalid='ignore'):
if not np.allclose(np.ma.masked_invalid(v/m),v[np.argmax(abs(v))]/m[np.argmax(abs(v))]): if not np.allclose(np.ma.masked_invalid(v_/m),v_[np.argmax(abs(v_))]/m[np.argmax(abs(v_))]):
raise ValueError(f'Invalid result {m} for input {v}. Insufficient precision?') raise ValueError(f'Invalid result {m} for input {v_}. Insufficient precision?')
return m return m
def project_equal_angle(vector,direction='z',normalize=True,keepdims=False): def project_equal_angle(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
""" """
Apply equal-angle projection to vector. Apply equal-angle projection to vector.
@ -275,9 +296,8 @@ def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
---------- ----------
vector : numpy.ndarray, shape (...,3) vector : numpy.ndarray, shape (...,3)
Vector coordinates to be projected. Vector coordinates to be projected.
direction : str direction : {'x', 'y', 'z'}
Projection direction 'x', 'y', or 'z'. Projection direction. Defaults to 'z'.
Defaults to 'z'.
normalize : bool normalize : bool
Ensure unit length of input vector. Defaults to True. Ensure unit length of input vector. Defaults to True.
keepdims : bool keepdims : bool
@ -309,7 +329,10 @@ def project_equal_angle(vector,direction='z',normalize=True,keepdims=False):
return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]), return np.roll(np.block([v[...,:2]/(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2] -shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def project_equal_area(vector,direction='z',normalize=True,keepdims=False): def project_equal_area(vector: np.ndarray,
direction: Literal['x', 'y', 'z'] = 'z',
normalize: bool = True,
keepdims: bool = False) -> np.ndarray:
""" """
Apply equal-area projection to vector. Apply equal-area projection to vector.
@ -317,9 +340,8 @@ def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
---------- ----------
vector : numpy.ndarray, shape (...,3) vector : numpy.ndarray, shape (...,3)
Vector coordinates to be projected. Vector coordinates to be projected.
direction : str direction : {'x', 'y', 'z'}
Projection direction 'x', 'y', or 'z'. Projection direction. Defaults to 'z'.
Defaults to 'z'.
normalize : bool normalize : bool
Ensure unit length of input vector. Defaults to True. Ensure unit length of input vector. Defaults to True.
keepdims : bool keepdims : bool
@ -351,15 +373,14 @@ def project_equal_area(vector,direction='z',normalize=True,keepdims=False):
return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]), return np.roll(np.block([v[...,:2]/np.sqrt(1.0+np.abs(v[...,2:3])),np.zeros_like(v[...,2:3])]),
-shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2] -shift if keepdims else 0,axis=-1)[...,:3 if keepdims else 2]
def execution_stamp(class_name: str, function_name: str = None) -> str:
def execution_stamp(class_name,function_name=None):
"""Timestamp the execution of a (function within a) class.""" """Timestamp the execution of a (function within a) class."""
now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z') now = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z')
_function_name = '' if function_name is None else f'.{function_name}' _function_name = '' if function_name is None else f'.{function_name}'
return f'damask.{class_name}{_function_name} v{version} ({now})' return f'damask.{class_name}{_function_name} v{version} ({now})'
def hybrid_IA(dist,N,rng_seed=None): def hybrid_IA(dist: np.ndarray, N: int, rng_seed = None) -> np.ndarray:
""" """
Hybrid integer approximation. Hybrid integer approximation.
@ -387,7 +408,10 @@ def hybrid_IA(dist,N,rng_seed=None):
return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]] return np.repeat(np.arange(len(dist)),repeats)[np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]]
def shapeshifter(fro,to,mode='left',keep_ones=False): def shapeshifter(fro: Tuple[int, ...],
to: Tuple[int, ...],
mode: Literal['left','right'] = 'left',
keep_ones: bool = False) -> Tuple[Optional[int], ...]:
""" """
Return dimensions that reshape 'fro' to become broadcastable to 'to'. Return dimensions that reshape 'fro' to become broadcastable to 'to'.
@ -398,9 +422,9 @@ def shapeshifter(fro,to,mode='left',keep_ones=False):
to : tuple to : tuple
Target shape of array after broadcasting. Target shape of array after broadcasting.
len(to) cannot be less than len(fro). len(to) cannot be less than len(fro).
mode : str, optional mode : {'left', 'right'}, optional
Indicates whether new axes are preferably added to Indicates whether new axes are preferably added to
either 'left' or 'right' of the original shape. either left or right of the original shape.
Defaults to 'left'. Defaults to 'left'.
keep_ones : bool, optional keep_ones : bool, optional
Treat '1' in fro as literal value instead of dimensional placeholder. Treat '1' in fro as literal value instead of dimensional placeholder.
@ -434,21 +458,22 @@ def shapeshifter(fro,to,mode='left',keep_ones=False):
fro = (1,) if not len(fro) else fro fro = (1,) if not len(fro) else fro
to = (1,) if not len(to) else to to = (1,) if not len(to) else to
try: try:
grp = re.match(beg[mode] match = re.match(beg[mode]
+f',{sep[mode]}'.join(map(lambda x: f'{x}' +f',{sep[mode]}'.join(map(lambda x: f'{x}'
if x>1 or (keep_ones and len(fro)>1) else if x>1 or (keep_ones and len(fro)>1) else
'\\d+',fro)) '\\d+',fro))
+f',{end[mode]}', +f',{end[mode]}',','.join(map(str,to))+',')
','.join(map(str,to))+',').groups() assert match
except AttributeError: grp = match.groups()
except AssertionError:
raise ValueError(f'Shapes can not be shifted {fro} --> {to}') raise ValueError(f'Shapes can not be shifted {fro} --> {to}')
fill = () fill: Tuple[Optional[int], ...] = ()
for g,d in zip(grp,fro+(None,)): for g,d in zip(grp,fro+(None,)):
fill += (1,)*g.count(',')+(d,) fill += (1,)*g.count(',')+(d,)
return fill[:-1] return fill[:-1]
def shapeblender(a,b): def shapeblender(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
""" """
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'. Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
@ -476,7 +501,7 @@ def shapeblender(a,b):
return a + b[i:] return a + b[i:]
def extend_docstring(extra_docstring): def extend_docstring(extra_docstring: str) -> Callable:
""" """
Decorator: Append to function's docstring. Decorator: Append to function's docstring.
@ -492,7 +517,7 @@ def extend_docstring(extra_docstring):
return _decorator return _decorator
def extended_docstring(f,extra_docstring): def extended_docstring(f: Callable, extra_docstring: str) -> Callable:
""" """
Decorator: Combine another function's docstring with a given docstring. Decorator: Combine another function's docstring with a given docstring.
@ -510,7 +535,7 @@ def extended_docstring(f,extra_docstring):
return _decorator return _decorator
def DREAM3D_base_group(fname): def DREAM3D_base_group(fname: Union[str, Path]) -> str:
""" """
Determine the base group of a DREAM.3D file. Determine the base group of a DREAM.3D file.
@ -536,7 +561,7 @@ def DREAM3D_base_group(fname):
return base_group return base_group
def DREAM3D_cell_data_group(fname): def DREAM3D_cell_data_group(fname: Union[str, Path]) -> str:
""" """
Determine the cell data group of a DREAM.3D file. Determine the cell data group of a DREAM.3D file.
@ -568,18 +593,18 @@ def DREAM3D_cell_data_group(fname):
return cell_data_group return cell_data_group
def Bravais_to_Miller(*,uvtw=None,hkil=None): def Bravais_to_Miller(*, uvtw: np.ndarray = None, hkil: np.ndarray = None) -> np.ndarray:
""" """
Transform 4 MillerBravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl). Transform 4 MillerBravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl).
Parameters Parameters
---------- ----------
uvtw|hkil : numpy.ndarray of shape (...,4) uvtw|hkil : numpy.ndarray, shape (...,4)
MillerBravais indices of crystallographic direction [uvtw] or plane normal (hkil). MillerBravais indices of crystallographic direction [uvtw] or plane normal (hkil).
Returns Returns
------- -------
uvw|hkl : numpy.ndarray of shape (...,3) uvw|hkl : numpy.ndarray, shape (...,3)
Miller indices of [uvw] direction or (hkl) plane normal. Miller indices of [uvw] direction or (hkl) plane normal.
""" """
@ -595,18 +620,18 @@ def Bravais_to_Miller(*,uvtw=None,hkil=None):
return np.einsum('il,...l',basis,axis) return np.einsum('il,...l',basis,axis)
def Miller_to_Bravais(*,uvw=None,hkl=None): def Miller_to_Bravais(*, uvw: np.ndarray = None, hkl: np.ndarray = None) -> np.ndarray:
""" """
Transform 3 Miller indices to 4 MillerBravais indices of crystal direction [uvtw] or plane normal (hkil). Transform 3 Miller indices to 4 MillerBravais indices of crystal direction [uvtw] or plane normal (hkil).
Parameters Parameters
---------- ----------
uvw|hkl : numpy.ndarray of shape (...,3) uvw|hkl : numpy.ndarray, shape (...,3)
Miller indices of crystallographic direction [uvw] or plane normal (hkl). Miller indices of crystallographic direction [uvw] or plane normal (hkl).
Returns Returns
------- -------
uvtw|hkil : numpy.ndarray of shape (...,4) uvtw|hkil : numpy.ndarray, shape (...,4)
MillerBravais indices of [uvtw] direction or (hkil) plane normal. MillerBravais indices of [uvtw] direction or (hkil) plane normal.
""" """
@ -624,7 +649,7 @@ def Miller_to_Bravais(*,uvw=None,hkl=None):
return np.einsum('il,...l',basis,axis) return np.einsum('il,...l',basis,axis)
def dict_prune(d): def dict_prune(d: Dict) -> Dict:
""" """
Recursively remove empty dictionaries. Recursively remove empty dictionaries.
@ -650,7 +675,7 @@ def dict_prune(d):
return new return new
def dict_flatten(d): def dict_flatten(d: Dict) -> Dict:
""" """
Recursively remove keys of single-entry dictionaries. Recursively remove keys of single-entry dictionaries.
@ -678,14 +703,14 @@ def dict_flatten(d):
#################################################################################################### ####################################################################################################
# Classes # Classes
#################################################################################################### ####################################################################################################
class _ProgressBar: class ProgressBar:
""" """
Report progress of an interation as a status bar. Report progress of an interation as a status bar.
Works for 0-based loops, ETA is estimated by linear extrapolation. Works for 0-based loops, ETA is estimated by linear extrapolation.
""" """
def __init__(self,total,prefix,bar_length): def __init__(self, total: int, prefix: str, bar_length: int):
""" """
Set current time as basis for ETA estimation. Set current time as basis for ETA estimation.
@ -708,7 +733,7 @@ class _ProgressBar:
sys.stderr.write(f"{self.prefix} {''*self.bar_length} 0% ETA n/a") sys.stderr.write(f"{self.prefix} {''*self.bar_length} 0% ETA n/a")
sys.stderr.flush() sys.stderr.flush()
def update(self,iteration): def update(self, iteration: int) -> None:
fraction = (iteration+1) / self.total fraction = (iteration+1) / self.total
filled_length = int(self.bar_length * fraction) filled_length = int(self.bar_length * fraction)