From 9aa68e83a20702f80d141f0be2205ad5dcc102e2 Mon Sep 17 00:00:00 2001 From: Martin Diehl Date: Tue, 28 Nov 2023 11:50:12 +0100 Subject: [PATCH] new delete function for VTK follows 'damask.Table' and makes GeomGrid.load_SPPARKS easier --- python/damask/_vtk.py | 39 +++++++++++++++++++++++++++++++++-- python/tests/test_GeomGrid.py | 6 +++--- python/tests/test_VTK.py | 7 +++++++ 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/python/damask/_vtk.py b/python/damask/_vtk.py index f8a58d73b..920e27781 100644 --- a/python/damask/_vtk.py +++ b/python/damask/_vtk.py @@ -457,7 +457,7 @@ class VTK: data: Union[None, np.ndarray, np.ma.MaskedArray] = None, info: Optional[str] = None, *, - table: Optional['Table'] = None): + table: Optional['Table'] = None) -> 'VTK': """ Add new or replace existing point or cell data. @@ -534,7 +534,6 @@ class VTK: else: raise TypeError - return dup @@ -581,6 +580,42 @@ class VTK: raise KeyError(f'array "{label}" not found') + def delete(self, + label: str) -> 'VTK': + """ + Delete either cell or point data. + + Cell data takes precedence over point data, i.e. this + function assumes that labels are unique among cell and + point data. + + Parameters + ---------- + label : str + Data label. + + Returns + ------- + updated : damask.VTK + Updated VTK-based geometry. + + """ + dup = self.copy() + cell_data = dup.vtk_data.GetCellData() + for a in range(cell_data.GetNumberOfArrays()): + if cell_data.GetArrayName(a) == label: + dup.vtk_data.GetCellData().RemoveArray(label) + return dup + + point_data = self.vtk_data.GetPointData() + for a in range(point_data.GetNumberOfArrays()): + if point_data.GetArrayName(a) == label: + dup.vtk_data.GetPointData().RemoveArray(label) + return dup + + raise KeyError(f'array "{label}" not found') + + def show(self, label: Optional[str] = None, colormap: Union[Colormap, str] = 'cividis'): diff --git a/python/tests/test_GeomGrid.py b/python/tests/test_GeomGrid.py index 76f30701f..b91c73e5c 100644 --- a/python/tests/test_GeomGrid.py +++ b/python/tests/test_GeomGrid.py @@ -94,9 +94,9 @@ class TestGeomGrid: def test_save_load_SPPARKS(self,res_path,tmp_path): v = VTK.load(res_path/'SPPARKS_dump.vti') - v.set('material',v.get('spins')).save(tmp_path/'SPPARKS_dump.vti',parallel=False) - assert np.all(GeomGrid.load_SPPARKS(res_path/'SPPARKS_dump.vti').material == \ - GeomGrid.load(tmp_path/'SPPARKS_dump.vti').material) + v.set('material',v.get('spins')).delete('spins').save(tmp_path/'SPPARKS_dump.vti',parallel=False) + assert GeomGrid.load_SPPARKS(res_path/'SPPARKS_dump.vti') == \ + GeomGrid.load(tmp_path/'SPPARKS_dump.vti') def test_invalid_origin(self,default): with pytest.raises(ValueError): diff --git a/python/tests/test_VTK.py b/python/tests/test_VTK.py index 0368d1d16..e3cb64ca5 100644 --- a/python/tests/test_VTK.py +++ b/python/tests/test_VTK.py @@ -199,6 +199,13 @@ class TestVTK: mask_manual = default.set('D',np.where(masked.mask,masked.fill_value,masked)) assert mask_manual == mask_auto + @pytest.mark.parametrize('mode',['cells','points']) + def test_delete(self,default,mode): + data = np.random.rand(default.N_cells if mode == 'cells' else default.N_points).astype(np.float32) + v = default.set('D',data) + assert np.all(data == v.get('D')) + v = v.delete('D') + assert v == default @pytest.mark.parametrize('data_type,shape',[(float,(3,)), (float,(3,3)),