vtk.comments as directly accessed property

This commit is contained in:
Philip Eisenlohr 2022-02-14 09:19:09 -05:00
parent 1d5abc206a
commit 2ce464c48e
5 changed files with 21 additions and 29 deletions

View File

@ -173,8 +173,8 @@ class Grid:
Parameters
----------
fname : str or pathlib.Path
Grid file to read. Valid extension is .vti, which will be appended
if not given.
Grid file to read.
Valid extension is .vti, which will be appended if not given.
Returns
-------
@ -183,9 +183,9 @@ class Grid:
"""
v = VTK.load(fname if str(fname).endswith(('.vti','.vtr')) else str(fname)+'.vti') # compatibility hack
comments = v.get_comments()
cells = np.array(v.vtk_data.GetDimensions())-1
bbox = np.array(v.vtk_data.GetBounds()).reshape(3,2).T
comments = v.comments
return Grid(material = v.get('material').reshape(cells,order='F'),
size = bbox[1] - bbox[0],
@ -645,7 +645,7 @@ class Grid:
"""
v = VTK.from_image_data(self.cells,self.size,self.origin)
v.add(self.material.flatten(order='F'),'material')
v.add_comments(self.comments)
v.comments += self.comments
v.save(fname,parallel=False,compress=compress)

View File

@ -1623,7 +1623,7 @@ class Result:
else:
raise ValueError(f'invalid mode {mode}')
v.set_comments(util.execution_stamp('Result','export_VTK'))
v.comments = util.execution_stamp('Result','export_VTK')
N_digits = int(np.floor(np.log10(max(1,int(self.increments[-1][10:])))))+1
@ -1639,7 +1639,7 @@ class Result:
if self.version_minor >= 13:
creator = f.attrs['creator'] if h5py3 else f.attrs['creator'].decode()
created = f.attrs['created'] if h5py3 else f.attrs['created'].decode()
v.add_comments(f'{creator} ({created})')
v.comments += f'{creator} ({created})'
for inc in util.show_progress(self.visible['increments']):

View File

@ -2,7 +2,7 @@ import os
import warnings
import multiprocessing as mp
from pathlib import Path
from typing import Union, Literal, List
from typing import Union, Literal, List, Sequence
import numpy as np
import vtk
@ -386,7 +386,8 @@ class VTK:
raise ValueError(f'Array "{label}" not found.')
def get_comments(self) -> List[str]:
@property
def comments(self) -> List[str]:
"""Return the comments."""
fielddata = self.vtk_data.GetFieldData()
for a in range(fielddata.GetNumberOfArrays()):
@ -395,9 +396,9 @@ class VTK:
return [comments.GetValue(i) for i in range(comments.GetNumberOfValues())]
return []
def set_comments(self,
comments: Union[str, List[str]]):
@comments.setter
def comments(self,
comments: Union[str, Sequence[str]]):
"""
Set comments.
@ -407,6 +408,11 @@ class VTK:
Comments.
"""
if isinstance(comments,list):
i = 0
while -i < len(comments) and len(comments[i-1]) == 1: i -= 1 # repack any trailing characters
comments = comments[:i] + [''.join(comments[i:])] # that resulted from autosplitting of str to list
s = vtk.vtkStringArray()
s.SetName('comments')
for c in [comments] if isinstance(comments,str) else comments:
@ -414,20 +420,6 @@ class VTK:
self.vtk_data.GetFieldData().AddArray(s)
def add_comments(self,
comments: Union[str, List[str]]):
"""
Add comments.
Parameters
----------
comments : str or list of str
Comments to add.
"""
self.set_comments(self.get_comments() + ([comments] if isinstance(comments,str) else comments))
def __repr__(self) -> str:
"""ASCII representation of the VTK data."""
writer = vtk.vtkDataSetWriter()

View File

@ -385,7 +385,7 @@ class TestResult:
result.export_VTK(output,parallel=False)
fname = fname.split('.')[0]+f'_inc{(inc if type(inc) == int else inc[0]):0>2}.vti'
v = VTK.load(tmp_path/fname)
v.set_comments('n/a')
v.comments = 'n/a'
v.save(tmp_path/fname,parallel=False)
with open(fname) as f:
cur = hashlib.md5(f.read().encode()).hexdigest()

View File

@ -162,10 +162,10 @@ class TestVTK:
def test_comments(self,tmp_path,default):
default.add_comments(['this is a comment'])
default.comments += 'this is a comment'
default.save(tmp_path/'with_comments',parallel=False)
new = VTK.load(tmp_path/'with_comments.vti')
assert new.get_comments() == ['this is a comment']
assert new.comments == ['this is a comment']
@pytest.mark.xfail(int(vtk.vtkVersion.GetVTKVersion().split('.')[0])<8, reason='missing METADATA')
def test_compare_reference_polyData(self,update,ref_path,tmp_path):