renamed .add to .set to be consistent with Table.set

This commit is contained in:
Philip Eisenlohr 2022-05-11 18:19:10 -04:00
parent 6f07145b75
commit 648d17d381
3 changed files with 46 additions and 30 deletions

View File

@ -401,13 +401,13 @@ class VTK:
# Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data
def add(self,
def set(self,
label: str = None,
data: Union[np.ndarray, np.ma.MaskedArray] = None,
*,
table: 'Table' = None):
"""
Add data to either cells or points.
Add (or replace existing) point or cell data.
Data can either be a numpy.array, which requires a corresponding label,
or a damask.Table.
@ -417,11 +417,15 @@ class VTK:
label : str, optional
Label of data array.
data : numpy.ndarray or numpy.ma.MaskedArray, optional
Data to add. First dimension needs to match either
Data to add or replace. First array dimension needs to match either
number of cells or number of points.
table: damask.Table, optional
Data to add. Number of rows needs to match either
number of cells or number of points.
Data to add or replace. Each table label is individually considered.
Number of rows needs to match either number of cells or number of points.
Notes
-----
If the number of cells equals the number of points, the data is added to both.
"""
@ -429,7 +433,10 @@ class VTK:
label: str,
data: np.ndarray):
N_data = data.shape[0]
N_p,N_c = vtk_data.GetNumberOfPoints(),vtk_data.GetNumberOfCells()
if (N_data := data.shape[0]) not in [N_p,N_c]:
raise ValueError(f'data count mismatch ({N_data}{N_p} & {N_c})')
data_ = data.reshape(N_data,-1) \
.astype(np.single if data.dtype in [np.double,np.longdouble] else data.dtype)
@ -442,12 +449,10 @@ class VTK:
d.SetName(label)
if N_data == vtk_data.GetNumberOfPoints():
if N_data == N_p:
vtk_data.GetPointData().AddArray(d)
elif N_data == vtk_data.GetNumberOfCells():
if N_data == N_c:
vtk_data.GetCellData().AddArray(d)
else:
raise ValueError(f'data count mismatch ({N_data}{self.N_points} & {self.N_cells})')
if data is None and table is None:
raise KeyError('no data given')

View File

@ -16,6 +16,17 @@
</DataArray>
</PointData>
<CellData>
<DataArray type="Float32" Name="coordinates" NumberOfComponents="3" format="binary" RangeMin="0.7453560147132696" RangeMax="2.449489742783178">
AQAAAACAAAB4AAAAVgAAAA==eF5jYICBhv2WfY9tLfuS7Ypk3PeDaCDf7okF3/7Vq1bZrV6lZQ+k94HEgHL2QHovUM7+iUUfiG0LlQdhkH77Ipnj9iB5qFp7kBjQDiBmcADRANsaLXM=
<InformationKey name="L2_NORM_RANGE" location="vtkDataArray" length="2">
<Value index="0">
0.74535601471
</Value>
<Value index="1">
2.4494897428
</Value>
</InformationKey>
</DataArray>
</CellData>
<Points>
<DataArray type="Float64" Name="Points" NumberOfComponents="3" format="binary" RangeMin="0.7453559924999299" RangeMax="2.449489742783178">

View File

@ -147,24 +147,24 @@ class TestVTK:
with pytest.raises(KeyError):
default.get('does_not_exist')
def test_invalid_add_shape(self,default):
def test_invalid_set_shape(self,default):
with pytest.raises(ValueError):
default.add('valid',np.ones(3))
default.set('valid',np.ones(3))
def test_invalid_add_missing_label(self,default):
def test_invalid_set_missing_label(self,default):
data = np.random.randint(9,size=np.prod(np.array(default.vtk_data.GetDimensions())-1))
with pytest.raises(ValueError):
default.add(data=data)
default.set(data=data)
def test_invalid_add_type(self,default):
def test_invalid_set_type(self,default):
with pytest.raises(TypeError):
default.add(label='valid',data='invalid_type')
default.set(label='valid',data='invalid_type')
with pytest.raises(TypeError):
default.add(label='valid',table='invalid_type')
default.set(label='valid',table='invalid_type')
def test_invalid_add_dual(self,default):
def test_invalid_set_dual(self,default):
with pytest.raises(KeyError):
default.add(label='valid',data=0,table=0)
default.set(label='valid',data=0,table=0)
@pytest.mark.parametrize('data_type,shape',[(float,(3,)),
(float,(3,3)),
@ -172,31 +172,31 @@ class TestVTK:
(int,(4,)),
(str,(1,))])
@pytest.mark.parametrize('N_values',[5*6*7,6*7*8])
def test_add_get(self,default,data_type,shape,N_values):
def test_set_get(self,default,data_type,shape,N_values):
data = np.squeeze(np.random.randint(0,100,(N_values,)+shape)).astype(data_type)
new = default.add('data',data)
new = default.set('data',data)
assert (np.squeeze(data.reshape(N_values,-1)) == new.get('data')).all()
@pytest.mark.parametrize('shapes',[{'scalar':(1,),'vector':(3,),'tensor':(3,3)},
{'vector':(6,),'tensor':(3,3)},
{'tensor':(3,3),'scalar':(1,)}])
def test_add_table(self,default,shapes):
def test_set_table(self,default,shapes):
N = np.random.choice([default.N_points,default.N_cells])
d = dict()
for k,s in shapes.items():
d[k] = dict(shape = s,
data = np.random.random(N*np.prod(s)).reshape((N,-1)))
new = default.add(table=Table(shapes,np.column_stack([d[k]['data'] for k in shapes.keys()])))
new = default.set(table=Table(shapes,np.column_stack([d[k]['data'] for k in shapes.keys()])))
for k,s in shapes.items():
assert np.allclose(np.squeeze(d[k]['data']),new.get(k),rtol=1e-7)
def test_add_masked(self,default):
def test_set_masked(self,default):
data = np.random.rand(5*6*7,3)
masked = ma.MaskedArray(data,mask=data<.4,fill_value=42.)
mask_auto = default.add('D',masked)
mask_manual = default.add('D',np.where(masked.mask,masked.fill_value,masked))
mask_auto = default.set('D',masked)
mask_manual = default.set('D',np.where(masked.mask,masked.fill_value,masked))
assert mask_manual == mask_auto
@ -210,7 +210,7 @@ class TestVTK:
data = np.squeeze(np.random.randint(0,100,(N_values,)+shape)).astype(data_type)
ALPHABET = np.array(list(string.ascii_lowercase + ' '))
label = ''.join(np.random.choice(ALPHABET, size=10))
new = default.add(label,data)
new = default.set(label,data)
if N_values == default.N_points: assert label in new.labels['Point Data']
if N_values == default.N_cells: assert label in new.labels['Cell Data']
@ -225,7 +225,7 @@ class TestVTK:
@pytest.mark.xfail(int(vtk.vtkVersion.GetVTKVersion().split('.')[0])<8, reason='missing METADATA')
def test_compare_reference_polyData(self,update,ref_path,tmp_path):
points=np.dstack((np.linspace(0.,1.,10),np.linspace(0.,2.,10),np.linspace(-1.,1.,10))).squeeze()
polyData = VTK.from_poly_data(points).add('coordinates',points)
polyData = VTK.from_poly_data(points).set('coordinates',points)
if update:
polyData.save(ref_path/'polyData')
else:
@ -242,8 +242,8 @@ class TestVTK:
c = coords[:-1,:-1,:-1,:].reshape(-1,3,order='F')
n = coords[:,:,:,:].reshape(-1,3,order='F')
rectilinearGrid = VTK.from_rectilinear_grid(grid) \
.add('cell',np.ascontiguousarray(c)) \
.add('node',np.ascontiguousarray(n))
.set('cell',np.ascontiguousarray(c)) \
.set('node',np.ascontiguousarray(n))
if update:
rectilinearGrid.save(ref_path/'rectilinearGrid')
else: