Robust translation for view.times

This commit is contained in:
Philip Eisenlohr 2022-11-11 06:03:14 +00:00
parent 5bb31dd4d2
commit 349a39609e
2 changed files with 25 additions and 17 deletions

View File

@ -11,7 +11,7 @@ from pathlib import Path
from functools import partial
from collections import defaultdict
from collections.abc import Iterable
from typing import Union, Callable, Any, Sequence, Literal, Dict, List, Tuple
from typing import Union, Callable, Any, Sequence, Literal, Dict, List, Tuple, Optional
import h5py
import numpy as np
@ -122,7 +122,7 @@ class Result:
r = re.compile(rf'{prefix_inc}([0-9]+)')
self.increments = sorted([i for i in f.keys() if r.match(i)],key=util.natural_sort)
self.times = [round(f[i].attrs['t/s'],12) for i in self.increments]
self.times = np.around([f[i].attrs['t/s'] for i in self.increments],12)
if len(self.increments) == 0:
raise ValueError('incomplete DADF5 file')
@ -228,19 +228,19 @@ class Result:
self.increments[c] if isinstance(c,int) and c<0 else
f'{prefix_inc}{c}' for c in choice]
elif what == 'times':
atol = 1e-2 * np.min(np.diff(self.times))
what = 'increments'
if choice == ['*']:
choice = self.increments
else:
iterator = map(float,choice) # type: ignore
iterator = np.array(choice).astype(float)
choice = []
for c in iterator:
idx = np.searchsorted(self.times,c)
if idx >= len(self.times): continue
if np.isclose(c,self.times[idx]):
idx = np.searchsorted(self.times,c,side='left')
if idx<len(self.times) and np.isclose(c,self.times[idx],rtol=0,atol=atol):
choice.append(self.increments[idx])
elif np.isclose(c,self.times[idx+1]):
choice.append(self.increments[idx+1]) # type: ignore
elif idx>0 and np.isclose(c,self.times[idx-1],rtol=0,atol=atol):
choice.append(self.increments[idx-1])
valid = _match(choice,getattr(self,what))
existing = set(self.visible[what])
@ -248,11 +248,9 @@ class Result:
if action == 'set':
dup.visible[what] = sorted(set(valid), key=util.natural_sort)
elif action == 'add':
add = existing.union(valid)
dup.visible[what] = sorted(add, key=util.natural_sort)
dup.visible[what] = sorted(existing.union(valid), key=util.natural_sort)
elif action == 'del':
diff = existing.difference(valid)
dup.visible[what] = sorted(diff, key=util.natural_sort)
dup.visible[what] = sorted(existing.difference(valid), key=util.natural_sort)
return dup
@ -1546,7 +1544,7 @@ class Result:
def get(self,
output: Union[str, List[str]] = '*',
flatten: bool = True,
prune: bool = True):
prune: bool = True) -> Optional[Dict[str,Any]]:
"""
Collect data per phase/homogenization reflecting the group/folder structure in the DADF5 file.
@ -1568,7 +1566,7 @@ class Result:
Datasets structured by phase/homogenization and according to selected view.
"""
r = {} # type: ignore
r: Dict[str,Any] = {}
with h5py.File(self.fname,'r') as f:
for inc in util.show_progress(self.visible['increments']):
@ -1597,7 +1595,7 @@ class Result:
prune: bool = True,
constituents: IntSequence = None,
fill_float: float = np.nan,
fill_int: int = 0):
fill_int: int = 0) -> Optional[Dict[str,Any]]:
"""
Merge data into spatial order that is compatible with the damask.VTK geometry representation.
@ -1634,9 +1632,9 @@ class Result:
Datasets structured by spatial position and according to selected view.
"""
r = {} # type: ignore
r: Dict[str,Any] = {}
constituents_ = list(map(int,constituents)) if isinstance(constituents,Iterable) else \
constituents_ = map(int,constituents) if isinstance(constituents,Iterable) else \
(range(self.N_constituents) if constituents is None else [constituents]) # type: ignore
suffixes = [''] if self.N_constituents == 1 or isinstance(constituents,int) else \

View File

@ -100,6 +100,16 @@ class TestResult:
assert n0.get('F') is n1.get('F') is None and \
len(n0.visible[label]) == len(n1.visible[label]) == 0
def test_view_invalid_incstimes(self,default):
with pytest.raises(ValueError):
default.view(increments=0,times=0)
@pytest.mark.parametrize('inc',[0,10])
@pytest.mark.parametrize('sign',[+1,-1])
def test_view_approxtimes(self,default,inc,sign):
eps = sign*1e-3
assert [default.increments[inc]] == default.view(times=default.times[inc]+eps).visible['increments']
def test_add_invalid(self,default):
default.add_absolute('xxxx')