added typehints for vtk module

This commit is contained in:
Daniel Otto de Mentock 2022-01-17 15:00:25 +01:00
parent adf7abbda6
commit 7b158ba108
1 changed files with 17 additions and 14 deletions

View File

@ -2,6 +2,7 @@ import os
import warnings import warnings
import multiprocessing as mp import multiprocessing as mp
from pathlib import Path from pathlib import Path
from typing import Union, Optional, Literal, List
import numpy as np import numpy as np
import vtk import vtk
@ -20,7 +21,7 @@ class VTK:
High-level interface to VTK. High-level interface to VTK.
""" """
def __init__(self,vtk_data): def __init__(self, vtk_data: vtk.vtkImageData):
""" """
New spatial visualization. New spatial visualization.
@ -36,7 +37,8 @@ class VTK:
@staticmethod @staticmethod
def from_image_data(cells,size,origin=np.zeros(3)): #ITERABLES PROPER
def from_image_data(cells: np.ndarray, size: np.ndarray, origin: Optional[np.ndarray] = np.zeros(3)) -> "VTK":
""" """
Create VTK of type vtk.vtkImageData. Create VTK of type vtk.vtkImageData.
@ -66,7 +68,7 @@ class VTK:
@staticmethod @staticmethod
def from_rectilinear_grid(grid,size,origin=np.zeros(3)): def from_rectilinear_grid(grid: np.ndarray, size: np.ndarray, origin: np.ndarray = np.zeros(3)) -> "VTK":
""" """
Create VTK of type vtk.vtkRectilinearGrid. Create VTK of type vtk.vtkRectilinearGrid.
@ -98,7 +100,7 @@ class VTK:
@staticmethod @staticmethod
def from_unstructured_grid(nodes,connectivity,cell_type): def from_unstructured_grid(nodes: np.ndarray, connectivity: np.ndarray, cell_type: str) -> "VTK":
""" """
Create VTK of type vtk.vtkUnstructuredGrid. Create VTK of type vtk.vtkUnstructuredGrid.
@ -138,7 +140,7 @@ class VTK:
@staticmethod @staticmethod
def from_poly_data(points): def from_poly_data(points: np.ndarray) -> "VTK":
""" """
Create VTK of type vtk.polyData. Create VTK of type vtk.polyData.
@ -172,7 +174,8 @@ class VTK:
@staticmethod @staticmethod
def load(fname,dataset_type=None): def load(fname: Union[str, Path],
dataset_type: Literal['vtkImageData', 'vtkRectilinearGrid', 'vtkUnstructuredGrid', 'vtkPolyData'] = None) -> "VTK":
""" """
Load from VTK file. Load from VTK file.
@ -234,7 +237,7 @@ class VTK:
def _write(writer): def _write(writer):
"""Wrapper for parallel writing.""" """Wrapper for parallel writing."""
writer.Write() writer.Write()
def save(self,fname,parallel=True,compress=True): def save(self, fname: Union[str, Path], parallel: bool = True, compress: bool = True):
""" """
Save as VTK file. Save as VTK file.
@ -280,7 +283,7 @@ class VTK:
# Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data # Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data
# Needs support for damask.Table # Needs support for damask.Table
def add(self,data,label=None): def add(self, data: Union[np.ndarray, np.ma.MaskedArray], label: str = None):
""" """
Add data to either cells or points. Add data to either cells or points.
@ -327,7 +330,7 @@ class VTK:
raise TypeError raise TypeError
def get(self,label): def get(self, label: str) -> np.ndarray:
""" """
Get either cell or point data. Get either cell or point data.
@ -369,7 +372,7 @@ class VTK:
raise ValueError(f'Array "{label}" not found.') raise ValueError(f'Array "{label}" not found.')
def get_comments(self): def get_comments(self) -> List[str]:
"""Return the comments.""" """Return the comments."""
fielddata = self.vtk_data.GetFieldData() fielddata = self.vtk_data.GetFieldData()
for a in range(fielddata.GetNumberOfArrays()): for a in range(fielddata.GetNumberOfArrays()):
@ -379,7 +382,7 @@ class VTK:
return [] return []
def set_comments(self,comments): def set_comments(self, comments: Union[str, List[str]]):
""" """
Set comments. Set comments.
@ -396,7 +399,7 @@ class VTK:
self.vtk_data.GetFieldData().AddArray(s) self.vtk_data.GetFieldData().AddArray(s)
def add_comments(self,comments): def add_comments(self, comments: Union[str, List[str]]):
""" """
Add comments. Add comments.
@ -409,7 +412,7 @@ class VTK:
self.set_comments(self.get_comments() + ([comments] if isinstance(comments,str) else comments)) self.set_comments(self.get_comments() + ([comments] if isinstance(comments,str) else comments))
def __repr__(self): def __repr__(self) -> str:
"""ASCII representation of the VTK data.""" """ASCII representation of the VTK data."""
writer = vtk.vtkDataSetWriter() writer = vtk.vtkDataSetWriter()
writer.SetHeader(f'# {util.execution_stamp("VTK")}') writer.SetHeader(f'# {util.execution_stamp("VTK")}')