Do not strictly require np.ndarrays for grid, size, or origin when not essential for functionality.

This commit is contained in:
Philip Eisenlohr 2020-09-10 00:59:40 +02:00
parent e0b4bc6d1e
commit 7d9a4c08ce
1 changed files with 20 additions and 21 deletions

View File

@ -40,20 +40,20 @@ class VTK:
""" """
Create VTK of type vtk.vtkRectilinearGrid. Create VTK of type vtk.vtkRectilinearGrid.
This is the common type for results from the grid solver. This is the common type for grid solver results.
Parameters Parameters
---------- ----------
grid : numpy.ndarray of shape (3) of np.dtype = int grid : iterable of int, len (3)
Number of cells. Number of cells along each dimension.
size : numpy.ndarray of shape (3) size : iterable of float, len (3)
Physical length. Physical lengths along each dimension.
origin : numpy.ndarray of shape (3), optional origin : iterable of float, len (3), optional
Spatial origin. Spatial origin coordinates.
""" """
vtk_data = vtk.vtkRectilinearGrid() vtk_data = vtk.vtkRectilinearGrid()
vtk_data.SetDimensions(*(grid+1)) vtk_data.SetDimensions(*(np.array(grid)+1))
coord = [np_to_vtk(np.linspace(origin[i],origin[i]+size[i],grid[i]+1),deep=True) for i in [0,1,2]] coord = [np_to_vtk(np.linspace(origin[i],origin[i]+size[i],grid[i]+1),deep=True) for i in [0,1,2]]
[coord[i].SetName(n) for i,n in enumerate(['x','y','z'])] [coord[i].SetName(n) for i,n in enumerate(['x','y','z'])]
vtk_data.SetXCoordinates(coord[0]) vtk_data.SetXCoordinates(coord[0])
@ -68,7 +68,7 @@ class VTK:
""" """
Create VTK of type vtk.vtkUnstructuredGrid. Create VTK of type vtk.vtkUnstructuredGrid.
This is the common type for results from FEM solvers. This is the common type for FEM solver results.
Parameters Parameters
---------- ----------
@ -127,7 +127,7 @@ class VTK:
fname : str or pathlib.Path fname : str or pathlib.Path
Filename for reading. Valid extensions are .vtr, .vtu, .vtp, and .vtk. Filename for reading. Valid extensions are .vtr, .vtu, .vtp, and .vtk.
dataset_type : str, optional dataset_type : str, optional
Name of the vtk.vtkDataSet subclass when opening an .vtk file. Valid types are vtkRectilinearGrid, Name of the vtk.vtkDataSet subclass when opening a .vtk file. Valid types are vtkRectilinearGrid,
vtkUnstructuredGrid, and vtkPolyData. vtkUnstructuredGrid, and vtkPolyData.
""" """
@ -215,12 +215,13 @@ class VTK:
Parameters Parameters
---------- ----------
data : numpy.ndarray data : numpy.ndarray
Data to add. First dimension need to match either Data to add. First dimension needs to match either
number of cells or number of points number of cells or number of points.
label : str label : str
Data label. Data label.
""" """
N_data = data.shape[0]
N_points = self.vtk_data.GetNumberOfPoints() N_points = self.vtk_data.GetNumberOfPoints()
N_cells = self.vtk_data.GetNumberOfCells() N_cells = self.vtk_data.GetNumberOfCells()
@ -228,18 +229,16 @@ class VTK:
if label is None: if label is None:
raise ValueError('No label defined for numpy.ndarray') raise ValueError('No label defined for numpy.ndarray')
if data.dtype in [np.float64, np.float128]: # avoid large files d = np_to_vtk((data.astype(np.float32) if data.dtype in [np.float64, np.float128]
d = np_to_vtk(num_array=data.astype(np.float32).reshape(data.shape[0],-1),deep=True) else data).reshape(N_data,-1),deep=True) # avoid large files
else:
d = np_to_vtk(num_array=data.reshape(data.shape[0],-1),deep=True)
d.SetName(label) d.SetName(label)
if data.shape[0] == N_cells: if N_data == N_cells:
self.vtk_data.GetCellData().AddArray(d) self.vtk_data.GetCellData().AddArray(d)
elif data.shape[0] == N_points: elif N_data == N_points:
self.vtk_data.GetPointData().AddArray(d) self.vtk_data.GetPointData().AddArray(d)
else: else:
raise ValueError(f'Invalid shape {data.shape[0]}') raise ValueError(f'Cell / point count ({N_cells} / {N_points}) differs from data ({N_data}).')
elif isinstance(data,pd.DataFrame): elif isinstance(data,pd.DataFrame):
raise NotImplementedError('pd.DataFrame') raise NotImplementedError('pd.DataFrame')
elif isinstance(data,Table): elif isinstance(data,Table):
@ -272,7 +271,7 @@ class VTK:
if point_data.GetArrayName(a) == label: if point_data.GetArrayName(a) == label:
return vtk_to_np(point_data.GetArray(a)) return vtk_to_np(point_data.GetArray(a))
raise ValueError(f'array "{label}" not found') raise ValueError(f'Array "{label}" not found.')
def get_comments(self): def get_comments(self):