Do not strictly require np.ndarrays for grid, size, or origin when not essential for functionality.
This commit is contained in:
parent
e0b4bc6d1e
commit
7d9a4c08ce
|
@ -40,20 +40,20 @@ class VTK:
|
||||||
"""
|
"""
|
||||||
Create VTK of type vtk.vtkRectilinearGrid.
|
Create VTK of type vtk.vtkRectilinearGrid.
|
||||||
|
|
||||||
This is the common type for results from the grid solver.
|
This is the common type for grid solver results.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
grid : numpy.ndarray of shape (3) of np.dtype = int
|
grid : iterable of int, len (3)
|
||||||
Number of cells.
|
Number of cells along each dimension.
|
||||||
size : numpy.ndarray of shape (3)
|
size : iterable of float, len (3)
|
||||||
Physical length.
|
Physical lengths along each dimension.
|
||||||
origin : numpy.ndarray of shape (3), optional
|
origin : iterable of float, len (3), optional
|
||||||
Spatial origin.
|
Spatial origin coordinates.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
vtk_data = vtk.vtkRectilinearGrid()
|
vtk_data = vtk.vtkRectilinearGrid()
|
||||||
vtk_data.SetDimensions(*(grid+1))
|
vtk_data.SetDimensions(*(np.array(grid)+1))
|
||||||
coord = [np_to_vtk(np.linspace(origin[i],origin[i]+size[i],grid[i]+1),deep=True) for i in [0,1,2]]
|
coord = [np_to_vtk(np.linspace(origin[i],origin[i]+size[i],grid[i]+1),deep=True) for i in [0,1,2]]
|
||||||
[coord[i].SetName(n) for i,n in enumerate(['x','y','z'])]
|
[coord[i].SetName(n) for i,n in enumerate(['x','y','z'])]
|
||||||
vtk_data.SetXCoordinates(coord[0])
|
vtk_data.SetXCoordinates(coord[0])
|
||||||
|
@ -68,7 +68,7 @@ class VTK:
|
||||||
"""
|
"""
|
||||||
Create VTK of type vtk.vtkUnstructuredGrid.
|
Create VTK of type vtk.vtkUnstructuredGrid.
|
||||||
|
|
||||||
This is the common type for results from FEM solvers.
|
This is the common type for FEM solver results.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -127,7 +127,7 @@ class VTK:
|
||||||
fname : str or pathlib.Path
|
fname : str or pathlib.Path
|
||||||
Filename for reading. Valid extensions are .vtr, .vtu, .vtp, and .vtk.
|
Filename for reading. Valid extensions are .vtr, .vtu, .vtp, and .vtk.
|
||||||
dataset_type : str, optional
|
dataset_type : str, optional
|
||||||
Name of the vtk.vtkDataSet subclass when opening an .vtk file. Valid types are vtkRectilinearGrid,
|
Name of the vtk.vtkDataSet subclass when opening a .vtk file. Valid types are vtkRectilinearGrid,
|
||||||
vtkUnstructuredGrid, and vtkPolyData.
|
vtkUnstructuredGrid, and vtkPolyData.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -215,12 +215,13 @@ class VTK:
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
data : numpy.ndarray
|
data : numpy.ndarray
|
||||||
Data to add. First dimension need to match either
|
Data to add. First dimension needs to match either
|
||||||
number of cells or number of points
|
number of cells or number of points.
|
||||||
label : str
|
label : str
|
||||||
Data label.
|
Data label.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
N_data = data.shape[0]
|
||||||
N_points = self.vtk_data.GetNumberOfPoints()
|
N_points = self.vtk_data.GetNumberOfPoints()
|
||||||
N_cells = self.vtk_data.GetNumberOfCells()
|
N_cells = self.vtk_data.GetNumberOfCells()
|
||||||
|
|
||||||
|
@ -228,18 +229,16 @@ class VTK:
|
||||||
if label is None:
|
if label is None:
|
||||||
raise ValueError('No label defined for numpy.ndarray')
|
raise ValueError('No label defined for numpy.ndarray')
|
||||||
|
|
||||||
if data.dtype in [np.float64, np.float128]: # avoid large files
|
d = np_to_vtk((data.astype(np.float32) if data.dtype in [np.float64, np.float128]
|
||||||
d = np_to_vtk(num_array=data.astype(np.float32).reshape(data.shape[0],-1),deep=True)
|
else data).reshape(N_data,-1),deep=True) # avoid large files
|
||||||
else:
|
|
||||||
d = np_to_vtk(num_array=data.reshape(data.shape[0],-1),deep=True)
|
|
||||||
d.SetName(label)
|
d.SetName(label)
|
||||||
|
|
||||||
if data.shape[0] == N_cells:
|
if N_data == N_cells:
|
||||||
self.vtk_data.GetCellData().AddArray(d)
|
self.vtk_data.GetCellData().AddArray(d)
|
||||||
elif data.shape[0] == N_points:
|
elif N_data == N_points:
|
||||||
self.vtk_data.GetPointData().AddArray(d)
|
self.vtk_data.GetPointData().AddArray(d)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Invalid shape {data.shape[0]}')
|
raise ValueError(f'Cell / point count ({N_cells} / {N_points}) differs from data ({N_data}).')
|
||||||
elif isinstance(data,pd.DataFrame):
|
elif isinstance(data,pd.DataFrame):
|
||||||
raise NotImplementedError('pd.DataFrame')
|
raise NotImplementedError('pd.DataFrame')
|
||||||
elif isinstance(data,Table):
|
elif isinstance(data,Table):
|
||||||
|
@ -272,7 +271,7 @@ class VTK:
|
||||||
if point_data.GetArrayName(a) == label:
|
if point_data.GetArrayName(a) == label:
|
||||||
return vtk_to_np(point_data.GetArray(a))
|
return vtk_to_np(point_data.GetArray(a))
|
||||||
|
|
||||||
raise ValueError(f'array "{label}" not found')
|
raise ValueError(f'Array "{label}" not found.')
|
||||||
|
|
||||||
|
|
||||||
def get_comments(self):
|
def get_comments(self):
|
||||||
|
|
Loading…
Reference in New Issue