added typehints for vtk module

This commit is contained in:
Daniel Otto de Mentock 2022-01-17 15:00:25 +01:00
parent adf7abbda6
commit 7b158ba108
1 changed files with 17 additions and 14 deletions

View File

@ -2,6 +2,7 @@ import os
import warnings
import multiprocessing as mp
from pathlib import Path
from typing import Union, Optional, Literal, List
import numpy as np
import vtk
@ -20,7 +21,7 @@ class VTK:
High-level interface to VTK.
"""
def __init__(self,vtk_data):
def __init__(self, vtk_data: vtk.vtkImageData):
"""
New spatial visualization.
@ -36,7 +37,8 @@ class VTK:
@staticmethod
def from_image_data(cells,size,origin=np.zeros(3)):
#ITERABLES PROPER
def from_image_data(cells: np.ndarray, size: np.ndarray, origin: Optional[np.ndarray] = np.zeros(3)) -> "VTK":
"""
Create VTK of type vtk.vtkImageData.
@ -66,7 +68,7 @@ class VTK:
@staticmethod
def from_rectilinear_grid(grid,size,origin=np.zeros(3)):
def from_rectilinear_grid(grid: np.ndarray, size: np.ndarray, origin: np.ndarray = np.zeros(3)) -> "VTK":
"""
Create VTK of type vtk.vtkRectilinearGrid.
@ -98,7 +100,7 @@ class VTK:
@staticmethod
def from_unstructured_grid(nodes,connectivity,cell_type):
def from_unstructured_grid(nodes: np.ndarray, connectivity: np.ndarray, cell_type: str) -> "VTK":
"""
Create VTK of type vtk.vtkUnstructuredGrid.
@ -138,7 +140,7 @@ class VTK:
@staticmethod
def from_poly_data(points):
def from_poly_data(points: np.ndarray) -> "VTK":
"""
Create VTK of type vtk.polyData.
@ -172,7 +174,8 @@ class VTK:
@staticmethod
def load(fname,dataset_type=None):
def load(fname: Union[str, Path],
dataset_type: Literal['vtkImageData', 'vtkRectilinearGrid', 'vtkUnstructuredGrid', 'vtkPolyData'] = None) -> "VTK":
"""
Load from VTK file.
@ -189,7 +192,7 @@ class VTK:
VTK-based geometry from file.
"""
if not os.path.isfile(fname): # vtk has a strange error handling
if not os.path.isfile(fname): # vtk has a strange error handling
raise FileNotFoundError(f'No such file: {fname}')
ext = Path(fname).suffix
if ext == '.vtk' or dataset_type is not None:
@ -234,7 +237,7 @@ class VTK:
def _write(writer):
"""Wrapper for parallel writing."""
writer.Write()
def save(self,fname,parallel=True,compress=True):
def save(self, fname: Union[str, Path], parallel: bool = True, compress: bool = True):
"""
Save as VTK file.
@ -280,7 +283,7 @@ class VTK:
# Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data
# Needs support for damask.Table
def add(self,data,label=None):
def add(self, data: Union[np.ndarray, np.ma.MaskedArray], label: str = None):
"""
Add data to either cells or points.
@ -327,7 +330,7 @@ class VTK:
raise TypeError
def get(self,label):
def get(self, label: str) -> np.ndarray:
"""
Get either cell or point data.
@ -369,7 +372,7 @@ class VTK:
raise ValueError(f'Array "{label}" not found.')
def get_comments(self):
def get_comments(self) -> List[str]:
"""Return the comments."""
fielddata = self.vtk_data.GetFieldData()
for a in range(fielddata.GetNumberOfArrays()):
@ -379,7 +382,7 @@ class VTK:
return []
def set_comments(self,comments):
def set_comments(self, comments: Union[str, List[str]]):
"""
Set comments.
@ -396,7 +399,7 @@ class VTK:
self.vtk_data.GetFieldData().AddArray(s)
def add_comments(self,comments):
def add_comments(self, comments: Union[str, List[str]]):
"""
Add comments.
@ -409,7 +412,7 @@ class VTK:
self.set_comments(self.get_comments() + ([comments] if isinstance(comments,str) else comments))
def __repr__(self):
def __repr__(self) -> str:
"""ASCII representation of the VTK data."""
writer = vtk.vtkDataSetWriter()
writer.SetHeader(f'# {util.execution_stamp("VTK")}')