Do not strictly require np.ndarrays for grid, size, or origin when not essential for functionality.

This commit is contained in:
Philip Eisenlohr 2020-09-10 00:59:40 +02:00
parent e0b4bc6d1e
commit 7d9a4c08ce
1 changed files with 20 additions and 21 deletions

View File

@ -40,20 +40,20 @@ class VTK:
"""
Create VTK of type vtk.vtkRectilinearGrid.
This is the common type for results from the grid solver.
This is the common type for grid solver results.
Parameters
----------
grid : numpy.ndarray of shape (3) of np.dtype = int
Number of cells.
size : numpy.ndarray of shape (3)
Physical length.
origin : numpy.ndarray of shape (3), optional
Spatial origin.
grid : iterable of int, len (3)
Number of cells along each dimension.
size : iterable of float, len (3)
Physical lengths along each dimension.
origin : iterable of float, len (3), optional
Spatial origin coordinates.
"""
vtk_data = vtk.vtkRectilinearGrid()
vtk_data.SetDimensions(*(grid+1))
vtk_data.SetDimensions(*(np.array(grid)+1))
coord = [np_to_vtk(np.linspace(origin[i],origin[i]+size[i],grid[i]+1),deep=True) for i in [0,1,2]]
[coord[i].SetName(n) for i,n in enumerate(['x','y','z'])]
vtk_data.SetXCoordinates(coord[0])
@ -68,7 +68,7 @@ class VTK:
"""
Create VTK of type vtk.vtkUnstructuredGrid.
This is the common type for results from FEM solvers.
This is the common type for FEM solver results.
Parameters
----------
@ -127,7 +127,7 @@ class VTK:
fname : str or pathlib.Path
Filename for reading. Valid extensions are .vtr, .vtu, .vtp, and .vtk.
dataset_type : str, optional
Name of the vtk.vtkDataSet subclass when opening an .vtk file. Valid types are vtkRectilinearGrid,
Name of the vtk.vtkDataSet subclass when opening a .vtk file. Valid types are vtkRectilinearGrid,
vtkUnstructuredGrid, and vtkPolyData.
"""
@ -215,12 +215,13 @@ class VTK:
Parameters
----------
data : numpy.ndarray
Data to add. First dimension need to match either
number of cells or number of points
Data to add. First dimension needs to match either
number of cells or number of points.
label : str
Data label.
"""
N_data = data.shape[0]
N_points = self.vtk_data.GetNumberOfPoints()
N_cells = self.vtk_data.GetNumberOfCells()
@ -228,18 +229,16 @@ class VTK:
if label is None:
raise ValueError('No label defined for numpy.ndarray')
if data.dtype in [np.float64, np.float128]: # avoid large files
d = np_to_vtk(num_array=data.astype(np.float32).reshape(data.shape[0],-1),deep=True)
else:
d = np_to_vtk(num_array=data.reshape(data.shape[0],-1),deep=True)
d = np_to_vtk((data.astype(np.float32) if data.dtype in [np.float64, np.float128]
else data).reshape(N_data,-1),deep=True) # avoid large files
d.SetName(label)
if data.shape[0] == N_cells:
if N_data == N_cells:
self.vtk_data.GetCellData().AddArray(d)
elif data.shape[0] == N_points:
elif N_data == N_points:
self.vtk_data.GetPointData().AddArray(d)
else:
raise ValueError(f'Invalid shape {data.shape[0]}')
raise ValueError(f'Cell / point count ({N_cells} / {N_points}) differs from data ({N_data}).')
elif isinstance(data,pd.DataFrame):
raise NotImplementedError('pd.DataFrame')
elif isinstance(data,Table):
@ -272,7 +271,7 @@ class VTK:
if point_data.GetArrayName(a) == label:
return vtk_to_np(point_data.GetArray(a))
raise ValueError(f'array "{label}" not found')
raise ValueError(f'Array "{label}" not found.')
def get_comments(self):