new delete function for VTK

follows 'damask.Table' and makes GeomGrid.load_SPPARKS easier
This commit is contained in:
Martin Diehl 2023-11-28 11:50:12 +01:00
parent 9710ec11f0
commit 9aa68e83a2
3 changed files with 47 additions and 5 deletions

View File

@ -457,7 +457,7 @@ class VTK:
data: Union[None, np.ndarray, np.ma.MaskedArray] = None, data: Union[None, np.ndarray, np.ma.MaskedArray] = None,
info: Optional[str] = None, info: Optional[str] = None,
*, *,
table: Optional['Table'] = None): table: Optional['Table'] = None) -> 'VTK':
""" """
Add new or replace existing point or cell data. Add new or replace existing point or cell data.
@ -534,7 +534,6 @@ class VTK:
else: else:
raise TypeError raise TypeError
return dup return dup
@ -581,6 +580,42 @@ class VTK:
raise KeyError(f'array "{label}" not found') raise KeyError(f'array "{label}" not found')
def delete(self,
label: str) -> 'VTK':
"""
Delete either cell or point data.
Cell data takes precedence over point data, i.e. this
function assumes that labels are unique among cell and
point data.
Parameters
----------
label : str
Data label.
Returns
-------
updated : damask.VTK
Updated VTK-based geometry.
"""
dup = self.copy()
cell_data = dup.vtk_data.GetCellData()
for a in range(cell_data.GetNumberOfArrays()):
if cell_data.GetArrayName(a) == label:
dup.vtk_data.GetCellData().RemoveArray(label)
return dup
point_data = self.vtk_data.GetPointData()
for a in range(point_data.GetNumberOfArrays()):
if point_data.GetArrayName(a) == label:
dup.vtk_data.GetPointData().RemoveArray(label)
return dup
raise KeyError(f'array "{label}" not found')
def show(self, def show(self,
label: Optional[str] = None, label: Optional[str] = None,
colormap: Union[Colormap, str] = 'cividis'): colormap: Union[Colormap, str] = 'cividis'):

View File

@ -94,9 +94,9 @@ class TestGeomGrid:
def test_save_load_SPPARKS(self,res_path,tmp_path): def test_save_load_SPPARKS(self,res_path,tmp_path):
v = VTK.load(res_path/'SPPARKS_dump.vti') v = VTK.load(res_path/'SPPARKS_dump.vti')
v.set('material',v.get('spins')).save(tmp_path/'SPPARKS_dump.vti',parallel=False) v.set('material',v.get('spins')).delete('spins').save(tmp_path/'SPPARKS_dump.vti',parallel=False)
assert np.all(GeomGrid.load_SPPARKS(res_path/'SPPARKS_dump.vti').material == \ assert GeomGrid.load_SPPARKS(res_path/'SPPARKS_dump.vti') == \
GeomGrid.load(tmp_path/'SPPARKS_dump.vti').material) GeomGrid.load(tmp_path/'SPPARKS_dump.vti')
def test_invalid_origin(self,default): def test_invalid_origin(self,default):
with pytest.raises(ValueError): with pytest.raises(ValueError):

View File

@ -199,6 +199,13 @@ class TestVTK:
mask_manual = default.set('D',np.where(masked.mask,masked.fill_value,masked)) mask_manual = default.set('D',np.where(masked.mask,masked.fill_value,masked))
assert mask_manual == mask_auto assert mask_manual == mask_auto
@pytest.mark.parametrize('mode',['cells','points'])
def test_delete(self,default,mode):
data = np.random.rand(default.N_cells if mode == 'cells' else default.N_points).astype(np.float32)
v = default.set('D',data)
assert np.all(data == v.get('D'))
v = v.delete('D')
assert v == default
@pytest.mark.parametrize('data_type,shape',[(float,(3,)), @pytest.mark.parametrize('data_type,shape',[(float,(3,)),
(float,(3,3)), (float,(3,3)),