diff --git a/python/damask/_vtk.py b/python/damask/_vtk.py index dbbfb1b10..561726153 100644 --- a/python/damask/_vtk.py +++ b/python/damask/_vtk.py @@ -2,6 +2,7 @@ import os import warnings import multiprocessing as mp from pathlib import Path +from typing import Union, Optional, Literal, List import numpy as np import vtk @@ -20,7 +21,7 @@ class VTK: High-level interface to VTK. """ - def __init__(self,vtk_data): + def __init__(self, vtk_data: vtk.vtkImageData): """ New spatial visualization. @@ -36,7 +37,8 @@ class VTK: @staticmethod - def from_image_data(cells,size,origin=np.zeros(3)): + #ITERABLES PROPER + def from_image_data(cells: np.ndarray, size: np.ndarray, origin: Optional[np.ndarray] = np.zeros(3)) -> "VTK": """ Create VTK of type vtk.vtkImageData. @@ -66,7 +68,7 @@ class VTK: @staticmethod - def from_rectilinear_grid(grid,size,origin=np.zeros(3)): + def from_rectilinear_grid(grid: np.ndarray, size: np.ndarray, origin: np.ndarray = np.zeros(3)) -> "VTK": """ Create VTK of type vtk.vtkRectilinearGrid. @@ -98,7 +100,7 @@ class VTK: @staticmethod - def from_unstructured_grid(nodes,connectivity,cell_type): + def from_unstructured_grid(nodes: np.ndarray, connectivity: np.ndarray, cell_type: str) -> "VTK": """ Create VTK of type vtk.vtkUnstructuredGrid. @@ -138,7 +140,7 @@ class VTK: @staticmethod - def from_poly_data(points): + def from_poly_data(points: np.ndarray) -> "VTK": """ Create VTK of type vtk.polyData. @@ -172,7 +174,8 @@ class VTK: @staticmethod - def load(fname,dataset_type=None): + def load(fname: Union[str, Path], + dataset_type: Literal['vtkImageData', 'vtkRectilinearGrid', 'vtkUnstructuredGrid', 'vtkPolyData'] = None) -> "VTK": """ Load from VTK file. @@ -189,7 +192,7 @@ class VTK: VTK-based geometry from file. """ - if not os.path.isfile(fname): # vtk has a strange error handling + if not os.path.isfile(fname): # vtk has a strange error handling raise FileNotFoundError(f'No such file: {fname}') ext = Path(fname).suffix if ext == '.vtk' or dataset_type is not None: @@ -234,7 +237,7 @@ class VTK: def _write(writer): """Wrapper for parallel writing.""" writer.Write() - def save(self,fname,parallel=True,compress=True): + def save(self, fname: Union[str, Path], parallel: bool = True, compress: bool = True): """ Save as VTK file. @@ -280,7 +283,7 @@ class VTK: # Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data # Needs support for damask.Table - def add(self,data,label=None): + def add(self, data: Union[np.ndarray, np.ma.MaskedArray], label: str = None): """ Add data to either cells or points. @@ -327,7 +330,7 @@ class VTK: raise TypeError - def get(self,label): + def get(self, label: str) -> np.ndarray: """ Get either cell or point data. @@ -369,7 +372,7 @@ class VTK: raise ValueError(f'Array "{label}" not found.') - def get_comments(self): + def get_comments(self) -> List[str]: """Return the comments.""" fielddata = self.vtk_data.GetFieldData() for a in range(fielddata.GetNumberOfArrays()): @@ -379,7 +382,7 @@ class VTK: return [] - def set_comments(self,comments): + def set_comments(self, comments: Union[str, List[str]]): """ Set comments. @@ -396,7 +399,7 @@ class VTK: self.vtk_data.GetFieldData().AddArray(s) - def add_comments(self,comments): + def add_comments(self, comments: Union[str, List[str]]): """ Add comments. @@ -409,7 +412,7 @@ class VTK: self.set_comments(self.get_comments() + ([comments] if isinstance(comments,str) else comments)) - def __repr__(self): + def __repr__(self) -> str: """ASCII representation of the VTK data.""" writer = vtk.vtkDataSetWriter() writer.SetHeader(f'# {util.execution_stamp("VTK")}')