fixed getters for visible entities

storing time information as dictionary simplifies many operations
This commit is contained in:
Martin Diehl 2023-12-17 06:51:38 +01:00
parent 38f9c1977c
commit 8def54c862
No known key found for this signature in database
GPG Key ID: 1FD50837275A0A9B
2 changed files with 18 additions and 17 deletions

View File

@ -120,7 +120,7 @@ class Result:
r = re.compile(rf'{prefix_inc}([0-9]+)') r = re.compile(rf'{prefix_inc}([0-9]+)')
self._increments = sorted([i for i in f.keys() if r.match(i)],key=util.natural_sort) self._increments = sorted([i for i in f.keys() if r.match(i)],key=util.natural_sort)
self._times = np.around([f[i].attrs['t/s'] for i in self._increments],12) self._times = {int(i.split('_')[1]):np.around(f[i].attrs['t/s'],12) for i in self._increments}
if len(self._increments) == 0: if len(self._increments) == 0:
raise ValueError('incomplete DADF5 file') raise ValueError('incomplete DADF5 file')
@ -226,7 +226,8 @@ class Result:
self._increments[c] if isinstance(c,int) and c<0 else self._increments[c] if isinstance(c,int) and c<0 else
f'{prefix_inc}{c}' for c in choice] f'{prefix_inc}{c}' for c in choice]
elif what == 'times': elif what == 'times':
atol = 1e-2 * np.min(np.diff(self._times)) times = list(self._times.values())
atol = 1e-2 * np.min(np.diff(times))
what = 'increments' what = 'increments'
if choice == ['*']: if choice == ['*']:
choice = self._increments choice = self._increments
@ -234,10 +235,10 @@ class Result:
iterator = np.array(choice).astype(float) iterator = np.array(choice).astype(float)
choice = [] choice = []
for c in iterator: for c in iterator:
idx = np.searchsorted(self._times,c,side='left') idx = np.searchsorted(times,c,side='left')
if idx<len(self._times) and np.isclose(c,self._times[idx],rtol=0,atol=atol): if idx<len(self._times) and np.isclose(c,times[idx],rtol=0,atol=atol):
choice.append(self._increments[idx]) choice.append(self._increments[idx])
elif idx>0 and np.isclose(c,self._times[idx-1],rtol=0,atol=atol): elif idx>0 and np.isclose(c,times[idx-1],rtol=0,atol=atol):
choice.append(self._increments[idx-1]) choice.append(self._increments[idx-1])
valid = _match(choice,getattr(self,'_'+what)) valid = _match(choice,getattr(self,'_'+what))
@ -296,9 +297,9 @@ class Result:
Time of each increment within the given bounds. Time of each increment within the given bounds.
""" """
s,e = (self._times[ 0] if start is None else start, s,e = (self.times[ 0] if start is None else start,
self._times[-1] if end is None else end) self.times[-1] if end is None else end)
return [t for t in self._times if s <= t <= e] return [t for t in self.times if s <= t <= e]
def view(self,*, def view(self,*,
@ -533,7 +534,7 @@ class Result:
msg = [] msg = []
with h5py.File(self.fname,'r') as f: with h5py.File(self.fname,'r') as f:
for inc in self.visible['increments']: for inc in self.visible['increments']:
msg += [f'\n{inc} ({self._times[self._increments.index(inc)]} s)'] msg += [f'\n{inc} ({self._times[int(inc.split("_")[1])]} s)']
for ty in ['phase','homogenization']: for ty in ['phase','homogenization']:
msg += [f' {ty}'] msg += [f' {ty}']
for label in self.visible[ty+'s']: for label in self.visible[ty+'s']:
@ -575,19 +576,19 @@ class Result:
@property @property
def times(self): def times(self):
return NotImplementedError return [self._times[i] for i in self.increments]
@property @property
def phases(self): def phases(self):
return [copy.deepcopy(self.visible['phases'])] return self.visible['phases']
@property @property
def homogenizations(self): def homogenizations(self):
return [copy.deepcopy(self.visible['homogenizations'])] return self.visible['homogenizations']
@property @property
def fields(self): def fields(self):
return [copy.deepcopy(self.visible['fields'])] return self.visible['fields']
@property @property
@ -1769,11 +1770,10 @@ class Result:
time.attrib = {'TimeType': 'List'} time.attrib = {'TimeType': 'List'}
time_data = ET.SubElement(time, 'DataItem') time_data = ET.SubElement(time, 'DataItem')
times = [self._times[self._increments.index(i)] for i in self.visible['increments']]
time_data.attrib = {'Format': 'XML', time_data.attrib = {'Format': 'XML',
'NumberType': 'Float', 'NumberType': 'Float',
'Dimensions': f'{len(times)}'} 'Dimensions': f'{len(self.times)}'}
time_data.text = ' '.join(map(str,times)) time_data.text = ' '.join(map(str,self.times))
attributes = [] attributes = []
data_items = [] data_items = []

View File

@ -126,7 +126,8 @@ class TestResult:
@pytest.mark.parametrize('sign',[+1,-1]) @pytest.mark.parametrize('sign',[+1,-1])
def test_view_approxtimes(self,default,inc,sign): def test_view_approxtimes(self,default,inc,sign):
eps = sign*1e-3 eps = sign*1e-3
assert [default._increments[inc]] == default.view(times=default._times[inc]+eps).visible['increments'] times = list(default._times.values())
assert [default._increments[inc]] == default.view(times=times[inc]+eps).visible['increments']
def test_add_invalid(self,default): def test_add_invalid(self,default):
default.add_absolute('xxxx') default.add_absolute('xxxx')