added fist typehints for _grid module

This commit is contained in:
Daniel Otto de Mentock 2021-12-06 14:22:52 +01:00
parent 65c4417a20
commit 5558c301fa
2 changed files with 101 additions and 69 deletions

View File

@ -3,6 +3,8 @@ import copy
import warnings
import multiprocessing as mp
from functools import partial
from typing import Union, Optional, TextIO, List, Sequence
from pathlib import Path
import numpy as np
import pandas as pd
@ -13,7 +15,7 @@ from . import VTK
from . import util
from . import grid_filters
from . import Rotation
from . import Table
class Grid:
"""
@ -25,7 +27,11 @@ class Grid:
the physical size.
"""
def __init__(self,material,size,origin=[0.0,0.0,0.0],comments=[]):
def __init__(self,
material: np.ndarray,
size,
origin = [0.0,0.0,0.0],
comments = []):
"""
New geometry definition for grid solvers.
@ -43,12 +49,12 @@ class Grid:
"""
self.material = material
self.size = size
self.origin = origin
self._size = size
self._origin = origin
self.comments = comments
def __repr__(self):
def __repr__(self) -> str:
"""Basic information on grid definition."""
mat_min = np.nanmin(self.material)
mat_max = np.nanmax(self.material)
@ -62,14 +68,14 @@ class Grid:
])
def __copy__(self):
def __copy__(self) -> "Grid":
"""Create deep copy."""
return copy.deepcopy(self)
copy = __copy__
def __eq__(self,other):
def __eq__(self, other):
"""
Test equality of other.
@ -79,6 +85,8 @@ class Grid:
Grid to compare self against.
"""
if not isinstance(other, Grid):
raise TypeError
return (np.allclose(other.size,self.size)
and np.allclose(other.origin,self.origin)
and np.all(other.cells == self.cells)
@ -86,15 +94,15 @@ class Grid:
@property
def material(self):
def material(self) -> np.ndarray:
"""Material indices."""
return self._material
@material.setter
def material(self,material):
def material(self, material: np.ndarray):
if len(material.shape) != 3:
raise ValueError(f'invalid material shape {material.shape}')
elif material.dtype not in np.sctypes['float'] + np.sctypes['int']:
elif material.dtype not in np.sctypes['float'] and material.dtype not in np.sctypes['int']:
raise TypeError(f'invalid material data type {material.dtype}')
else:
self._material = np.copy(material)
@ -105,53 +113,53 @@ class Grid:
@property
def size(self):
def size(self) -> np.ndarray:
"""Physical size of grid in meter."""
return self._size
@size.setter
def size(self,size):
def size(self, size: Union[Sequence[float], np.ndarray]):
if len(size) != 3 or any(np.array(size) < 0):
raise ValueError(f'invalid size {size}')
else:
self._size = np.array(size)
@property
def origin(self):
def origin(self) -> Union[Sequence[float], np.ndarray]:
"""Coordinates of grid origin in meter."""
return self._origin
@origin.setter
def origin(self,origin):
def origin(self, origin: np.ndarray):
if len(origin) != 3:
raise ValueError(f'invalid origin {origin}')
else:
self._origin = np.array(origin)
@property
def comments(self):
def comments(self) -> List[str]:
"""Comments, e.g. history of operations."""
return self._comments
@comments.setter
def comments(self,comments):
def comments(self, comments: Union[str, Sequence[str]]):
self._comments = [str(c) for c in comments] if isinstance(comments,list) else [str(comments)]
@property
def cells(self):
def cells(self) -> np.ndarray:
"""Number of cells in x,y,z direction."""
return np.asarray(self.material.shape)
@property
def N_materials(self):
def N_materials(self) -> int:
"""Number of (unique) material indices within grid."""
return np.unique(self.material).size
@staticmethod
def load(fname):
def load(fname: Union[str, Path]) -> "Grid":
"""
Load from VTK image data file.
@ -198,15 +206,17 @@ class Grid:
"""
warnings.warn('Support for ASCII-based geom format will be removed in DAMASK 3.1.0', DeprecationWarning,2)
try:
if isinstance(fname, (str, Path)):
f = open(fname)
except TypeError:
elif isinstance(fname, TextIO):
f = fname
else:
raise TypeError
f.seek(0)
try:
header_length,keyword = f.readline().split()[:2]
header_length = int(header_length)
header_length_,keyword = f.readline().split()[:2]
header_length = int(header_length_)
except ValueError:
header_length,keyword = (-1, 'invalid')
if not keyword.startswith('head') or header_length < 3:
@ -215,10 +225,10 @@ class Grid:
comments = []
content = f.readlines()
for i,line in enumerate(content[:header_length]):
items = line.split('#')[0].lower().strip().split()
items: List[str] = line.split('#')[0].lower().strip().split()
key = items[0] if items else ''
if key == 'grid':
cells = np.array([ int(dict(zip(items[1::2],items[2::2]))[i]) for i in ['a','b','c']])
cells = np.array([int(dict(zip(items[1::2],items[2::2]))[i]) for i in ['a','b','c']])
elif key == 'size':
size = np.array([float(dict(zip(items[1::2],items[2::2]))[i]) for i in ['x','y','z']])
elif key == 'origin':
@ -226,19 +236,19 @@ class Grid:
else:
comments.append(line.strip())
material = np.empty(cells.prod()) # initialize as flat array
material = np.empty(int(cells.prod())) # initialize as flat array
i = 0
for line in content[header_length:]:
items = line.split('#')[0].split()
if len(items) == 3:
if items[1].lower() == 'of':
items = np.ones(int(items[0]))*float(items[2])
if items[1].lower() == 'of':
material_entry = np.ones(int(items[0]))*float(items[2])
elif items[1].lower() == 'to':
items = np.linspace(int(items[0]),int(items[2]),
material_entry = np.linspace(int(items[0]),int(items[2]),
abs(int(items[2])-int(items[0]))+1,dtype=float)
else: items = list(map(float,items))
else: items = list(map(float,items))
material[i:i+len(items)] = items
else: material_entry = list(map(float, items))
else: material_entry = list(map(float, items))
material[i:i+len(material_entry)] = material_entry
i += len(items)
if i != cells.prod():
@ -251,7 +261,7 @@ class Grid:
@staticmethod
def load_Neper(fname):
def load_Neper(fname: Union[str, Path]) -> "Grid":
"""
Load from Neper VTK file.
@ -276,10 +286,10 @@ class Grid:
@staticmethod
def load_DREAM3D(fname,
feature_IDs=None,cell_data=None,
phases='Phases',Euler_angles='EulerAngles',
base_group=None):
def load_DREAM3D(fname: str,
feature_IDs: str = None, cell_data: str = None,
phases: str = 'Phases', Euler_angles: str = 'EulerAngles',
base_group: str = None) -> "Grid":
"""
Load DREAM.3D (HDF5) file.
@ -339,7 +349,7 @@ class Grid:
@staticmethod
def from_table(table,coordinates,labels):
def from_table(table: Table, coordinates: str, labels: Union[str, Sequence[str]]) -> "Grid":
"""
Create grid from ASCII table.
@ -372,11 +382,16 @@ class Grid:
@staticmethod
def _find_closest_seed(seeds, weights, point):
def _find_closest_seed(seeds: np.ndarray, weights: np.ndarray, point: np.ndarray) -> np.integer:
return np.argmin(np.sum((np.broadcast_to(point,(len(seeds),3))-seeds)**2,axis=1) - weights)
@staticmethod
def from_Laguerre_tessellation(cells,size,seeds,weights,material=None,periodic=True):
def from_Laguerre_tessellation(cells,
size,
seeds,
weights,
material = None,
periodic = True):
"""
Create grid from Laguerre tessellation.
@ -412,7 +427,6 @@ class Grid:
seeds_p = seeds
coords = grid_filters.coordinates0_point(cells,size).reshape(-1,3)
pool = mp.Pool(int(os.environ.get('OMP_NUM_THREADS',4)))
result = pool.map_async(partial(Grid._find_closest_seed,seeds_p,weights_p), coords)
pool.close()
@ -428,7 +442,11 @@ class Grid:
@staticmethod
def from_Voronoi_tessellation(cells,size,seeds,material=None,periodic=True):
def from_Voronoi_tessellation(cells: np.ndarray,
size: Union[Sequence[float], np.ndarray],
seeds: np.ndarray,
material: np.ndarray = None,
periodic: bool = True) -> "Grid":
"""
Create grid from Voronoi tessellation.
@ -509,7 +527,12 @@ class Grid:
@staticmethod
def from_minimal_surface(cells,size,surface,threshold=0.0,periods=1,materials=(0,1)):
def from_minimal_surface(cells: np.ndarray,
size: Union[Sequence[float], np.ndarray],
surface: str,
threshold: float = 0.0,
periods: int = 1,
materials: tuple = (0,1)) -> "Grid":
"""
Create grid from definition of triply periodic minimal surface.
@ -595,7 +618,7 @@ class Grid:
)
def save(self,fname,compress=True):
def save(self, fname: Union[str, Path], compress: bool = True):
"""
Save as VTK image data file.
@ -614,7 +637,7 @@ class Grid:
v.save(fname if str(fname).endswith('.vti') else str(fname)+'.vti',parallel=False,compress=compress)
def save_ASCII(self,fname):
def save_ASCII(self, fname: Union[str, TextIO]):
"""
Save as geom file.
@ -649,8 +672,14 @@ class Grid:
VTK.from_rectilinear_grid(self.cells,self.size,self.origin).show()
def add_primitive(self,dimension,center,exponent,
fill=None,R=Rotation(),inverse=False,periodic=True):
def add_primitive(self,
dimension: np.ndarray,
center: np.ndarray,
exponent: Union[np.ndarray, float],
fill: int = None,
R: Rotation = Rotation(),
inverse: bool = False,
periodic: bool = True) -> "Grid":
"""
Insert a primitive geometric object at a given position.
@ -734,7 +763,7 @@ class Grid:
)
def mirror(self,directions,reflect=False):
def mirror(self, directions: Sequence[str], reflect: bool = False) -> "Grid":
"""
Mirror grid along given directions.
@ -769,7 +798,7 @@ class Grid:
if not set(directions).issubset(valid):
raise ValueError(f'invalid direction {set(directions).difference(valid)} specified')
limits = [None,None] if reflect else [-2,0]
limits: Sequence[Optional[int]] = [None,None] if reflect else [-2,0]
mat = self.material.copy()
if 'x' in directions:
@ -786,7 +815,7 @@ class Grid:
)
def flip(self,directions):
def flip(self, directions: Sequence[str]) -> "Grid":
"""
Flip grid along given directions.
@ -815,7 +844,7 @@ class Grid:
)
def scale(self,cells,periodic=True):
def scale(self, cells: np.ndarray, periodic: bool = True) -> "Grid":
"""
Scale grid to new cells.
@ -859,7 +888,7 @@ class Grid:
)
def clean(self,stencil=3,selection=None,periodic=True):
def clean(self, stencil: int = 3, selection: Sequence[float] = None, periodic: bool = True) -> "Grid":
"""
Smooth grid by selecting most frequent material index within given stencil at each location.
@ -878,7 +907,7 @@ class Grid:
Updated grid-based geometry.
"""
def mostFrequent(arr,selection=None):
def mostFrequent(arr, selection = None):
me = arr[arr.size//2]
if selection is None or me in selection:
unique, inverse = np.unique(arr, return_inverse=True)
@ -899,7 +928,7 @@ class Grid:
)
def renumber(self):
def renumber(self) -> "Grid":
"""
Renumber sorted material indices as 0,...,N-1.
@ -918,7 +947,7 @@ class Grid:
)
def rotate(self,R,fill=None):
def rotate(self, R: Rotation, fill: Union[int, float] = None) -> "Grid":
"""
Rotate grid (pad if required).
@ -956,7 +985,7 @@ class Grid:
)
def canvas(self,cells=None,offset=None,fill=None):
def canvas(self, cells = None, offset = None, fill = None):
"""
Crop or enlarge/pad grid.
@ -1008,7 +1037,7 @@ class Grid:
)
def substitute(self,from_material,to_material):
def substitute(self, from_material: np.ndarray, to_material: np.ndarray) -> "Grid":
"""
Substitute material indices.
@ -1025,7 +1054,7 @@ class Grid:
Updated grid-based geometry.
"""
def mp(entry,mapper):
def mp(entry, mapper):
return mapper[entry] if entry in mapper else entry
mp = np.vectorize(mp)
@ -1038,7 +1067,7 @@ class Grid:
)
def sort(self):
def sort(self) -> "Grid":
"""
Sort material indices such that min(material) is located at (0,0,0).
@ -1060,7 +1089,11 @@ class Grid:
)
def vicinity_offset(self,vicinity=1,offset=None,trigger=[],periodic=True):
def vicinity_offset(self,
vicinity: int = 1,
offset: int = None,
trigger: Sequence[int] = [],
periodic: bool = True) -> "Grid":
"""
Offset material index of points in the vicinity of xxx.
@ -1088,8 +1121,7 @@ class Grid:
Updated grid-based geometry.
"""
def tainted_neighborhood(stencil,trigger):
def tainted_neighborhood(stencil, trigger):
me = stencil[stencil.shape[0]//2]
return np.any(stencil != me if len(trigger) == 0 else
np.in1d(stencil,np.array(list(set(trigger) - {me}))))
@ -1108,7 +1140,7 @@ class Grid:
)
def get_grain_boundaries(self,periodic=True,directions='xyz'):
def get_grain_boundaries(self, periodic = True, directions = 'xyz'):
"""
Create VTK unstructured grid containing grain boundaries.

View File

@ -130,9 +130,9 @@ def gradient(size: _np.ndarray, f: _np.ndarray) -> _np.ndarray:
return _np.fft.irfftn(grad_,axes=(0,1,2),s=f.shape[:3])
def coordinates0_point(cells: Union[ _np.ndarray,Sequence[int]],
size: _np.ndarray,
origin: _np.ndarray = _np.zeros(3)) -> _np.ndarray:
def coordinates0_point(cells: Union[_np.ndarray, Sequence[int]],
size: Union[_np.ndarray, Sequence[float]],
origin: Union[_np.ndarray, Sequence[float]] = _np.zeros(3)) -> _np.ndarray:
"""
Cell center positions (undeformed).
@ -305,9 +305,9 @@ def cellsSizeOrigin_coordinates0_point(coordinates0: _np.ndarray,
return (cells,size,origin)
def coordinates0_node(cells: Union[_np.ndarray,Sequence[int]],
size: _np.ndarray,
origin: _np.ndarray = _np.zeros(3)) -> _np.ndarray:
def coordinates0_node(cells: Union[_np.ndarray, Sequence[int]],
size: Union[_np.ndarray, Sequence[int]],
origin: Union[_np.ndarray, Sequence[int]] = _np.zeros(3)) -> _np.ndarray:
"""
Nodal positions (undeformed).