added fist typehints for _grid module

This commit is contained in:
Daniel Otto de Mentock 2021-12-06 14:22:52 +01:00
parent 65c4417a20
commit 5558c301fa
2 changed files with 101 additions and 69 deletions

View File

@ -3,6 +3,8 @@ import copy
import warnings import warnings
import multiprocessing as mp import multiprocessing as mp
from functools import partial from functools import partial
from typing import Union, Optional, TextIO, List, Sequence
from pathlib import Path
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -13,7 +15,7 @@ from . import VTK
from . import util from . import util
from . import grid_filters from . import grid_filters
from . import Rotation from . import Rotation
from . import Table
class Grid: class Grid:
""" """
@ -25,7 +27,11 @@ class Grid:
the physical size. the physical size.
""" """
def __init__(self,material,size,origin=[0.0,0.0,0.0],comments=[]): def __init__(self,
material: np.ndarray,
size,
origin = [0.0,0.0,0.0],
comments = []):
""" """
New geometry definition for grid solvers. New geometry definition for grid solvers.
@ -43,12 +49,12 @@ class Grid:
""" """
self.material = material self.material = material
self.size = size self._size = size
self.origin = origin self._origin = origin
self.comments = comments self.comments = comments
def __repr__(self): def __repr__(self) -> str:
"""Basic information on grid definition.""" """Basic information on grid definition."""
mat_min = np.nanmin(self.material) mat_min = np.nanmin(self.material)
mat_max = np.nanmax(self.material) mat_max = np.nanmax(self.material)
@ -62,14 +68,14 @@ class Grid:
]) ])
def __copy__(self): def __copy__(self) -> "Grid":
"""Create deep copy.""" """Create deep copy."""
return copy.deepcopy(self) return copy.deepcopy(self)
copy = __copy__ copy = __copy__
def __eq__(self,other): def __eq__(self, other):
""" """
Test equality of other. Test equality of other.
@ -79,6 +85,8 @@ class Grid:
Grid to compare self against. Grid to compare self against.
""" """
if not isinstance(other, Grid):
raise TypeError
return (np.allclose(other.size,self.size) return (np.allclose(other.size,self.size)
and np.allclose(other.origin,self.origin) and np.allclose(other.origin,self.origin)
and np.all(other.cells == self.cells) and np.all(other.cells == self.cells)
@ -86,15 +94,15 @@ class Grid:
@property @property
def material(self): def material(self) -> np.ndarray:
"""Material indices.""" """Material indices."""
return self._material return self._material
@material.setter @material.setter
def material(self,material): def material(self, material: np.ndarray):
if len(material.shape) != 3: if len(material.shape) != 3:
raise ValueError(f'invalid material shape {material.shape}') raise ValueError(f'invalid material shape {material.shape}')
elif material.dtype not in np.sctypes['float'] + np.sctypes['int']: elif material.dtype not in np.sctypes['float'] and material.dtype not in np.sctypes['int']:
raise TypeError(f'invalid material data type {material.dtype}') raise TypeError(f'invalid material data type {material.dtype}')
else: else:
self._material = np.copy(material) self._material = np.copy(material)
@ -105,53 +113,53 @@ class Grid:
@property @property
def size(self): def size(self) -> np.ndarray:
"""Physical size of grid in meter.""" """Physical size of grid in meter."""
return self._size return self._size
@size.setter @size.setter
def size(self,size): def size(self, size: Union[Sequence[float], np.ndarray]):
if len(size) != 3 or any(np.array(size) < 0): if len(size) != 3 or any(np.array(size) < 0):
raise ValueError(f'invalid size {size}') raise ValueError(f'invalid size {size}')
else: else:
self._size = np.array(size) self._size = np.array(size)
@property @property
def origin(self): def origin(self) -> Union[Sequence[float], np.ndarray]:
"""Coordinates of grid origin in meter.""" """Coordinates of grid origin in meter."""
return self._origin return self._origin
@origin.setter @origin.setter
def origin(self,origin): def origin(self, origin: np.ndarray):
if len(origin) != 3: if len(origin) != 3:
raise ValueError(f'invalid origin {origin}') raise ValueError(f'invalid origin {origin}')
else: else:
self._origin = np.array(origin) self._origin = np.array(origin)
@property @property
def comments(self): def comments(self) -> List[str]:
"""Comments, e.g. history of operations.""" """Comments, e.g. history of operations."""
return self._comments return self._comments
@comments.setter @comments.setter
def comments(self,comments): def comments(self, comments: Union[str, Sequence[str]]):
self._comments = [str(c) for c in comments] if isinstance(comments,list) else [str(comments)] self._comments = [str(c) for c in comments] if isinstance(comments,list) else [str(comments)]
@property @property
def cells(self): def cells(self) -> np.ndarray:
"""Number of cells in x,y,z direction.""" """Number of cells in x,y,z direction."""
return np.asarray(self.material.shape) return np.asarray(self.material.shape)
@property @property
def N_materials(self): def N_materials(self) -> int:
"""Number of (unique) material indices within grid.""" """Number of (unique) material indices within grid."""
return np.unique(self.material).size return np.unique(self.material).size
@staticmethod @staticmethod
def load(fname): def load(fname: Union[str, Path]) -> "Grid":
""" """
Load from VTK image data file. Load from VTK image data file.
@ -198,15 +206,17 @@ class Grid:
""" """
warnings.warn('Support for ASCII-based geom format will be removed in DAMASK 3.1.0', DeprecationWarning,2) warnings.warn('Support for ASCII-based geom format will be removed in DAMASK 3.1.0', DeprecationWarning,2)
try: if isinstance(fname, (str, Path)):
f = open(fname) f = open(fname)
except TypeError: elif isinstance(fname, TextIO):
f = fname f = fname
else:
raise TypeError
f.seek(0) f.seek(0)
try: try:
header_length,keyword = f.readline().split()[:2] header_length_,keyword = f.readline().split()[:2]
header_length = int(header_length) header_length = int(header_length_)
except ValueError: except ValueError:
header_length,keyword = (-1, 'invalid') header_length,keyword = (-1, 'invalid')
if not keyword.startswith('head') or header_length < 3: if not keyword.startswith('head') or header_length < 3:
@ -215,10 +225,10 @@ class Grid:
comments = [] comments = []
content = f.readlines() content = f.readlines()
for i,line in enumerate(content[:header_length]): for i,line in enumerate(content[:header_length]):
items = line.split('#')[0].lower().strip().split() items: List[str] = line.split('#')[0].lower().strip().split()
key = items[0] if items else '' key = items[0] if items else ''
if key == 'grid': if key == 'grid':
cells = np.array([ int(dict(zip(items[1::2],items[2::2]))[i]) for i in ['a','b','c']]) cells = np.array([int(dict(zip(items[1::2],items[2::2]))[i]) for i in ['a','b','c']])
elif key == 'size': elif key == 'size':
size = np.array([float(dict(zip(items[1::2],items[2::2]))[i]) for i in ['x','y','z']]) size = np.array([float(dict(zip(items[1::2],items[2::2]))[i]) for i in ['x','y','z']])
elif key == 'origin': elif key == 'origin':
@ -226,19 +236,19 @@ class Grid:
else: else:
comments.append(line.strip()) comments.append(line.strip())
material = np.empty(cells.prod()) # initialize as flat array material = np.empty(int(cells.prod())) # initialize as flat array
i = 0 i = 0
for line in content[header_length:]: for line in content[header_length:]:
items = line.split('#')[0].split() items = line.split('#')[0].split()
if len(items) == 3: if len(items) == 3:
if items[1].lower() == 'of': if items[1].lower() == 'of':
items = np.ones(int(items[0]))*float(items[2]) material_entry = np.ones(int(items[0]))*float(items[2])
elif items[1].lower() == 'to': elif items[1].lower() == 'to':
items = np.linspace(int(items[0]),int(items[2]), material_entry = np.linspace(int(items[0]),int(items[2]),
abs(int(items[2])-int(items[0]))+1,dtype=float) abs(int(items[2])-int(items[0]))+1,dtype=float)
else: items = list(map(float,items)) else: material_entry = list(map(float, items))
else: items = list(map(float,items)) else: material_entry = list(map(float, items))
material[i:i+len(items)] = items material[i:i+len(material_entry)] = material_entry
i += len(items) i += len(items)
if i != cells.prod(): if i != cells.prod():
@ -251,7 +261,7 @@ class Grid:
@staticmethod @staticmethod
def load_Neper(fname): def load_Neper(fname: Union[str, Path]) -> "Grid":
""" """
Load from Neper VTK file. Load from Neper VTK file.
@ -276,10 +286,10 @@ class Grid:
@staticmethod @staticmethod
def load_DREAM3D(fname, def load_DREAM3D(fname: str,
feature_IDs=None,cell_data=None, feature_IDs: str = None, cell_data: str = None,
phases='Phases',Euler_angles='EulerAngles', phases: str = 'Phases', Euler_angles: str = 'EulerAngles',
base_group=None): base_group: str = None) -> "Grid":
""" """
Load DREAM.3D (HDF5) file. Load DREAM.3D (HDF5) file.
@ -339,7 +349,7 @@ class Grid:
@staticmethod @staticmethod
def from_table(table,coordinates,labels): def from_table(table: Table, coordinates: str, labels: Union[str, Sequence[str]]) -> "Grid":
""" """
Create grid from ASCII table. Create grid from ASCII table.
@ -372,11 +382,16 @@ class Grid:
@staticmethod @staticmethod
def _find_closest_seed(seeds, weights, point): def _find_closest_seed(seeds: np.ndarray, weights: np.ndarray, point: np.ndarray) -> np.integer:
return np.argmin(np.sum((np.broadcast_to(point,(len(seeds),3))-seeds)**2,axis=1) - weights) return np.argmin(np.sum((np.broadcast_to(point,(len(seeds),3))-seeds)**2,axis=1) - weights)
@staticmethod @staticmethod
def from_Laguerre_tessellation(cells,size,seeds,weights,material=None,periodic=True): def from_Laguerre_tessellation(cells,
size,
seeds,
weights,
material = None,
periodic = True):
""" """
Create grid from Laguerre tessellation. Create grid from Laguerre tessellation.
@ -412,7 +427,6 @@ class Grid:
seeds_p = seeds seeds_p = seeds
coords = grid_filters.coordinates0_point(cells,size).reshape(-1,3) coords = grid_filters.coordinates0_point(cells,size).reshape(-1,3)
pool = mp.Pool(int(os.environ.get('OMP_NUM_THREADS',4))) pool = mp.Pool(int(os.environ.get('OMP_NUM_THREADS',4)))
result = pool.map_async(partial(Grid._find_closest_seed,seeds_p,weights_p), coords) result = pool.map_async(partial(Grid._find_closest_seed,seeds_p,weights_p), coords)
pool.close() pool.close()
@ -428,7 +442,11 @@ class Grid:
@staticmethod @staticmethod
def from_Voronoi_tessellation(cells,size,seeds,material=None,periodic=True): def from_Voronoi_tessellation(cells: np.ndarray,
size: Union[Sequence[float], np.ndarray],
seeds: np.ndarray,
material: np.ndarray = None,
periodic: bool = True) -> "Grid":
""" """
Create grid from Voronoi tessellation. Create grid from Voronoi tessellation.
@ -509,7 +527,12 @@ class Grid:
@staticmethod @staticmethod
def from_minimal_surface(cells,size,surface,threshold=0.0,periods=1,materials=(0,1)): def from_minimal_surface(cells: np.ndarray,
size: Union[Sequence[float], np.ndarray],
surface: str,
threshold: float = 0.0,
periods: int = 1,
materials: tuple = (0,1)) -> "Grid":
""" """
Create grid from definition of triply periodic minimal surface. Create grid from definition of triply periodic minimal surface.
@ -595,7 +618,7 @@ class Grid:
) )
def save(self,fname,compress=True): def save(self, fname: Union[str, Path], compress: bool = True):
""" """
Save as VTK image data file. Save as VTK image data file.
@ -614,7 +637,7 @@ class Grid:
v.save(fname if str(fname).endswith('.vti') else str(fname)+'.vti',parallel=False,compress=compress) v.save(fname if str(fname).endswith('.vti') else str(fname)+'.vti',parallel=False,compress=compress)
def save_ASCII(self,fname): def save_ASCII(self, fname: Union[str, TextIO]):
""" """
Save as geom file. Save as geom file.
@ -649,8 +672,14 @@ class Grid:
VTK.from_rectilinear_grid(self.cells,self.size,self.origin).show() VTK.from_rectilinear_grid(self.cells,self.size,self.origin).show()
def add_primitive(self,dimension,center,exponent, def add_primitive(self,
fill=None,R=Rotation(),inverse=False,periodic=True): dimension: np.ndarray,
center: np.ndarray,
exponent: Union[np.ndarray, float],
fill: int = None,
R: Rotation = Rotation(),
inverse: bool = False,
periodic: bool = True) -> "Grid":
""" """
Insert a primitive geometric object at a given position. Insert a primitive geometric object at a given position.
@ -734,7 +763,7 @@ class Grid:
) )
def mirror(self,directions,reflect=False): def mirror(self, directions: Sequence[str], reflect: bool = False) -> "Grid":
""" """
Mirror grid along given directions. Mirror grid along given directions.
@ -769,7 +798,7 @@ class Grid:
if not set(directions).issubset(valid): if not set(directions).issubset(valid):
raise ValueError(f'invalid direction {set(directions).difference(valid)} specified') raise ValueError(f'invalid direction {set(directions).difference(valid)} specified')
limits = [None,None] if reflect else [-2,0] limits: Sequence[Optional[int]] = [None,None] if reflect else [-2,0]
mat = self.material.copy() mat = self.material.copy()
if 'x' in directions: if 'x' in directions:
@ -786,7 +815,7 @@ class Grid:
) )
def flip(self,directions): def flip(self, directions: Sequence[str]) -> "Grid":
""" """
Flip grid along given directions. Flip grid along given directions.
@ -815,7 +844,7 @@ class Grid:
) )
def scale(self,cells,periodic=True): def scale(self, cells: np.ndarray, periodic: bool = True) -> "Grid":
""" """
Scale grid to new cells. Scale grid to new cells.
@ -859,7 +888,7 @@ class Grid:
) )
def clean(self,stencil=3,selection=None,periodic=True): def clean(self, stencil: int = 3, selection: Sequence[float] = None, periodic: bool = True) -> "Grid":
""" """
Smooth grid by selecting most frequent material index within given stencil at each location. Smooth grid by selecting most frequent material index within given stencil at each location.
@ -878,7 +907,7 @@ class Grid:
Updated grid-based geometry. Updated grid-based geometry.
""" """
def mostFrequent(arr,selection=None): def mostFrequent(arr, selection = None):
me = arr[arr.size//2] me = arr[arr.size//2]
if selection is None or me in selection: if selection is None or me in selection:
unique, inverse = np.unique(arr, return_inverse=True) unique, inverse = np.unique(arr, return_inverse=True)
@ -899,7 +928,7 @@ class Grid:
) )
def renumber(self): def renumber(self) -> "Grid":
""" """
Renumber sorted material indices as 0,...,N-1. Renumber sorted material indices as 0,...,N-1.
@ -918,7 +947,7 @@ class Grid:
) )
def rotate(self,R,fill=None): def rotate(self, R: Rotation, fill: Union[int, float] = None) -> "Grid":
""" """
Rotate grid (pad if required). Rotate grid (pad if required).
@ -956,7 +985,7 @@ class Grid:
) )
def canvas(self,cells=None,offset=None,fill=None): def canvas(self, cells = None, offset = None, fill = None):
""" """
Crop or enlarge/pad grid. Crop or enlarge/pad grid.
@ -1008,7 +1037,7 @@ class Grid:
) )
def substitute(self,from_material,to_material): def substitute(self, from_material: np.ndarray, to_material: np.ndarray) -> "Grid":
""" """
Substitute material indices. Substitute material indices.
@ -1025,7 +1054,7 @@ class Grid:
Updated grid-based geometry. Updated grid-based geometry.
""" """
def mp(entry,mapper): def mp(entry, mapper):
return mapper[entry] if entry in mapper else entry return mapper[entry] if entry in mapper else entry
mp = np.vectorize(mp) mp = np.vectorize(mp)
@ -1038,7 +1067,7 @@ class Grid:
) )
def sort(self): def sort(self) -> "Grid":
""" """
Sort material indices such that min(material) is located at (0,0,0). Sort material indices such that min(material) is located at (0,0,0).
@ -1060,7 +1089,11 @@ class Grid:
) )
def vicinity_offset(self,vicinity=1,offset=None,trigger=[],periodic=True): def vicinity_offset(self,
vicinity: int = 1,
offset: int = None,
trigger: Sequence[int] = [],
periodic: bool = True) -> "Grid":
""" """
Offset material index of points in the vicinity of xxx. Offset material index of points in the vicinity of xxx.
@ -1088,8 +1121,7 @@ class Grid:
Updated grid-based geometry. Updated grid-based geometry.
""" """
def tainted_neighborhood(stencil,trigger): def tainted_neighborhood(stencil, trigger):
me = stencil[stencil.shape[0]//2] me = stencil[stencil.shape[0]//2]
return np.any(stencil != me if len(trigger) == 0 else return np.any(stencil != me if len(trigger) == 0 else
np.in1d(stencil,np.array(list(set(trigger) - {me})))) np.in1d(stencil,np.array(list(set(trigger) - {me}))))
@ -1108,7 +1140,7 @@ class Grid:
) )
def get_grain_boundaries(self,periodic=True,directions='xyz'): def get_grain_boundaries(self, periodic = True, directions = 'xyz'):
""" """
Create VTK unstructured grid containing grain boundaries. Create VTK unstructured grid containing grain boundaries.

View File

@ -130,9 +130,9 @@ def gradient(size: _np.ndarray, f: _np.ndarray) -> _np.ndarray:
return _np.fft.irfftn(grad_,axes=(0,1,2),s=f.shape[:3]) return _np.fft.irfftn(grad_,axes=(0,1,2),s=f.shape[:3])
def coordinates0_point(cells: Union[ _np.ndarray,Sequence[int]], def coordinates0_point(cells: Union[_np.ndarray, Sequence[int]],
size: _np.ndarray, size: Union[_np.ndarray, Sequence[float]],
origin: _np.ndarray = _np.zeros(3)) -> _np.ndarray: origin: Union[_np.ndarray, Sequence[float]] = _np.zeros(3)) -> _np.ndarray:
""" """
Cell center positions (undeformed). Cell center positions (undeformed).
@ -305,9 +305,9 @@ def cellsSizeOrigin_coordinates0_point(coordinates0: _np.ndarray,
return (cells,size,origin) return (cells,size,origin)
def coordinates0_node(cells: Union[_np.ndarray,Sequence[int]], def coordinates0_node(cells: Union[_np.ndarray, Sequence[int]],
size: _np.ndarray, size: Union[_np.ndarray, Sequence[int]],
origin: _np.ndarray = _np.zeros(3)) -> _np.ndarray: origin: Union[_np.ndarray, Sequence[int]] = _np.zeros(3)) -> _np.ndarray:
""" """
Nodal positions (undeformed). Nodal positions (undeformed).