Merge remote-tracking branch 'origin/keyword-view' into polishing

This commit is contained in:
Martin Diehl 2022-03-09 15:36:00 +01:00
commit b916712946
2 changed files with 75 additions and 114 deletions

View File

@ -4,7 +4,6 @@ import fnmatch
import os import os
import copy import copy
import datetime import datetime
import warnings
import xml.etree.ElementTree as ET # noqa import xml.etree.ElementTree as ET # noqa
import xml.dom.minidom import xml.dom.minidom
from pathlib import Path from pathlib import Path
@ -28,21 +27,6 @@ h5py3 = h5py.__version__[0] == '3'
chunk_size = 1024**2//8 # for compression in HDF5 chunk_size = 1024**2//8 # for compression in HDF5
def _view_transition(what,datasets,increments,times,phases,homogenizations,fields):
if (datasets is not None and what is None) or (what is not None and datasets is None):
raise ValueError('"what" and "datasets" need to be used as a pair')
if datasets is not None or what is not None:
warnings.warn('arguments "what" and "datasets" will be removed in DAMASK v3.0.0-alpha7', DeprecationWarning,2)
return what,datasets
if sum(1 for _ in filter(None.__ne__, [increments,times,phases,homogenizations,fields])) > 1:
raise ValueError('only one out of "increments", "times", "phases", "homogenizations", and "fields" can be used')
else:
if increments is not None: return "increments", increments
if times is not None: return "times", times
if phases is not None: return "phases", phases
if homogenizations is not None: return "homogenizations", homogenizations
if fields is not None: return "fields", fields
def _read(dataset): def _read(dataset):
"""Read a dataset and its metadata into a numpy.ndarray.""" """Read a dataset and its metadata into a numpy.ndarray."""
metadata = {k:(v.decode() if not h5py3 and type(v) is bytes else v) for k,v in dataset.attrs.items()} metadata = {k:(v.decode() if not h5py3 and type(v) is bytes else v) for k,v in dataset.attrs.items()}
@ -185,7 +169,13 @@ class Result:
return util.srepr([util.deemph(header)] + first + in_between + last) return util.srepr([util.deemph(header)] + first + in_between + last)
def _manage_view(self,action,what,datasets): def _manage_view(self,
action,
increments=None,
times=None,
phases=None,
homogenizations=None,
fields=None):
""" """
Manages the visibility of the groups. Manages the visibility of the groups.
@ -193,11 +183,6 @@ class Result:
---------- ----------
action : str action : str
Select from 'set', 'add', and 'del'. Select from 'set', 'add', and 'del'.
what : str
Attribute to change (must be from self.visible).
datasets : (list of) int (for increments), (list of) float (for times), (list of) str, or bool
Name of datasets; supports '?' and '*' wildcards.
True is equivalent to '*', False is equivalent to [].
Returns Returns
------- -------
@ -205,22 +190,28 @@ class Result:
Modified or new view on the DADF5 file. Modified or new view on the DADF5 file.
""" """
if increments is not None and times is not None:
raise ValueError('"increments" and "times" are mutually exclusive')
dup = self.copy()
for what,datasets in zip(['increments','times','phases','homogenizations','fields'],
[ increments, times, phases, homogenizations, fields ]):
if datasets is None:
continue
# allow True/False and string arguments # allow True/False and string arguments
if datasets is True: elif datasets is True:
datasets = '*' datasets = '*'
elif datasets is False or datasets is None: elif datasets is False:
datasets = [] datasets = []
choice = list(datasets).copy() if hasattr(datasets,'__iter__') and not isinstance(datasets,str) else \ choice = list(datasets).copy() if hasattr(datasets,'__iter__') and not isinstance(datasets,str) else \
[datasets] [datasets]
what_ = what if what.endswith('s') else what+'s' if what == 'increments':
if what_ == 'increments':
choice = [c if isinstance(c,str) and c.startswith('increment_') else choice = [c if isinstance(c,str) and c.startswith('increment_') else
self.increments[c] if isinstance(c,int) and c<0 else self.increments[c] if isinstance(c,int) and c<0 else
f'increment_{c}' for c in choice] f'increment_{c}' for c in choice]
elif what_ == 'times': elif what == 'times':
what_ = 'increments' what = 'increments'
if choice == ['*']: if choice == ['*']:
choice = self.increments choice = self.increments
else: else:
@ -234,18 +225,17 @@ class Result:
elif np.isclose(c,self.times[idx+1]): elif np.isclose(c,self.times[idx+1]):
choice.append(self.increments[idx+1]) choice.append(self.increments[idx+1])
valid = _match(choice,getattr(self,what_)) valid = _match(choice,getattr(self,what))
existing = set(self.visible[what_]) existing = set(self.visible[what])
dup = self.copy()
if action == 'set': if action == 'set':
dup.visible[what_] = sorted(set(valid), key=util.natural_sort) dup.visible[what] = sorted(set(valid), key=util.natural_sort)
elif action == 'add': elif action == 'add':
add = existing.union(valid) add = existing.union(valid)
dup.visible[what_] = sorted(add, key=util.natural_sort) dup.visible[what] = sorted(add, key=util.natural_sort)
elif action == 'del': elif action == 'del':
diff = existing.difference(valid) diff = existing.difference(valid)
dup.visible[what_] = sorted(diff, key=util.natural_sort) dup.visible[what] = sorted(diff, key=util.natural_sort)
return dup return dup
@ -298,7 +288,7 @@ class Result:
return selected return selected
def view(self,what=None,datasets=None,*, def view(self,*,
increments=None, increments=None,
times=None, times=None,
phases=None, phases=None,
@ -313,11 +303,6 @@ class Result:
Parameters Parameters
---------- ----------
what : {'increments', 'times', 'phases', 'homogenizations', 'fields'}
Attribute to change. DEPRECATED.
datasets : (list of) int (for increments), (list of) float (for times), (list of) str, or bool
Name of datasets; supports '?' and '*' wildcards. DEPRECATED.
True is equivalent to '*', False is equivalent to [].
increments: (list of) int, (list of) str, or bool, optional. increments: (list of) int, (list of) str, or bool, optional.
Number(s) of increments to select. Number(s) of increments to select.
times: (list of) float, (list of) str, or bool, optional. times: (list of) float, (list of) str, or bool, optional.
@ -351,24 +336,16 @@ class Result:
>>> r_t10to40 = r.view(times=r.times_in_range(10.0,40.0)) >>> r_t10to40 = r.view(times=r.times_in_range(10.0,40.0))
""" """
v = _view_transition(what,datasets,increments,times,phases,homogenizations,fields) dup = self._manage_view('set',increments,times,phases,homogenizations,fields)
if protected is not None: if protected is not None:
if v is None:
dup = self.copy()
else:
what_,datasets_ = v
dup = self._manage_view('set',what_,datasets_)
if not protected: if not protected:
print(util.warn('Warning: Modification of existing datasets allowed!')) print(util.warn('Warning: Modification of existing datasets allowed!'))
dup._protected = protected dup._protected = protected
else:
what_,datasets_ = v
dup = self._manage_view('set',what_,datasets_)
return dup return dup
def view_more(self,what=None,datasets=None,*, def view_more(self,*,
increments=None, increments=None,
times=None, times=None,
phases=None, phases=None,
@ -382,11 +359,6 @@ class Result:
Parameters Parameters
---------- ----------
what : {'increments', 'times', 'phases', 'homogenizations', 'fields'}
Attribute to change. DEPRECATED.
datasets : (list of) int (for increments), (list of) float (for times), (list of) str, or bool
Name of datasets; supports '?' and '*' wildcards. DEPRECATED.
True is equivalent to '*', False is equivalent to [].
increments: (list of) int, (list of) str, or bool, optional. increments: (list of) int, (list of) str, or bool, optional.
Number(s) of increments to select. Number(s) of increments to select.
times: (list of) float, (list of) str, or bool, optional. times: (list of) float, (list of) str, or bool, optional.
@ -413,11 +385,10 @@ class Result:
>>> r_first_and_last = r.first.view_more(increments=-1) >>> r_first_and_last = r.first.view_more(increments=-1)
""" """
what_, datasets_ = _view_transition(what,datasets,increments,times,phases,homogenizations,fields) return self._manage_view('add',increments,times,phases,homogenizations,fields)
return self._manage_view('add',what_,datasets_)
def view_less(self,what=None,datasets=None,*, def view_less(self,*,
increments=None, increments=None,
times=None, times=None,
phases=None, phases=None,
@ -431,11 +402,6 @@ class Result:
Parameters Parameters
---------- ----------
what : {'increments', 'times', 'phases', 'homogenizations', 'fields'}
Attribute to change. DEPRECATED.
datasets : (list of) int (for increments), (list of) float (for times), (list of) str, or bool
Name of datasets; supports '?' and '*' wildcards. DEPRECATED.
True is equivalent to '*', False is equivalent to [].
increments: (list of) int, (list of) str, or bool, optional. increments: (list of) int, (list of) str, or bool, optional.
Number(s) of increments to select. Number(s) of increments to select.
times: (list of) float, (list of) str, or bool, optional. times: (list of) float, (list of) str, or bool, optional.
@ -461,8 +427,7 @@ class Result:
>>> r_deformed = r_all.view_less(increments=0) >>> r_deformed = r_all.view_less(increments=0)
""" """
what_, datasets_ = _view_transition(what,datasets,increments,times,phases,homogenizations,fields) return self._manage_view('del',increments,times,phases,homogenizations,fields)
return self._manage_view('del',what_,datasets_)
def rename(self,name_src,name_dst): def rename(self,name_src,name_dst):
@ -1839,9 +1804,9 @@ class Result:
d = obj.attrs['description'] if h5py3 else obj.attrs['description'].decode() d = obj.attrs['description'] if h5py3 else obj.attrs['description'].decode()
if not Path(name).exists() or overwrite: if not Path(name).exists() or overwrite:
with open(name,'w') as f_out: f_out.write(obj[0].decode()) with open(name,'w') as f_out: f_out.write(obj[0].decode())
print(f"Exported {d} to '{name}'.") print(f'Exported {d} to "{name}".')
else: else:
print(f"'{name}' exists, {d} not exported.") print(f'"{name}" exists, {d} not exported.')
elif type(obj) == h5py.Group: elif type(obj) == h5py.Group:
os.makedirs(name, exist_ok=True) os.makedirs(name, exist_ok=True)

View File

@ -69,8 +69,8 @@ class TestResult:
@pytest.mark.parametrize('what',['increments','times','phases','fields']) # ToDo: discuss homogenizations @pytest.mark.parametrize('what',['increments','times','phases','fields']) # ToDo: discuss homogenizations
def test_view_none(self,default,what): def test_view_none(self,default,what):
n0 = default.view(what,False) n0 = default.view(**{what:False})
n1 = default.view(what,[]) n1 = default.view(**{what:[]})
label = 'increments' if what == 'times' else what label = 'increments' if what == 'times' else what
@ -79,29 +79,25 @@ class TestResult:
@pytest.mark.parametrize('what',['increments','times','phases','fields']) # ToDo: discuss homogenizations @pytest.mark.parametrize('what',['increments','times','phases','fields']) # ToDo: discuss homogenizations
def test_view_more(self,default,what): def test_view_more(self,default,what):
empty = default.view(what,False) empty = default.view(**{what:False})
a = empty.view_more(what,'*').get('F') a = empty.view_more(**{what:'*'}).get('F')
b = empty.view_more(what,True).get('F') b = empty.view_more(**{what:True}).get('F')
assert dict_equal(a,b) assert dict_equal(a,b)
@pytest.mark.parametrize('what',['increments','times','phases','fields']) # ToDo: discuss homogenizations @pytest.mark.parametrize('what',['increments','times','phases','fields']) # ToDo: discuss homogenizations
def test_view_less(self,default,what): def test_view_less(self,default,what):
full = default.view(what,True) full = default.view(**{what:True})
n0 = full.view_less(what,'*') n0 = full.view_less(**{what:'*'})
n1 = full.view_less(what,True) n1 = full.view_less(**{what:True})
label = 'increments' if what == 'times' else what label = 'increments' if what == 'times' else what
assert n0.get('F') is n1.get('F') is None and \ assert n0.get('F') is n1.get('F') is None and \
len(n0.visible[label]) == len(n1.visible[label]) == 0 len(n0.visible[label]) == len(n1.visible[label]) == 0
def test_view_invalid(self,default):
with pytest.raises(AttributeError):
default.view('invalid',True)
def test_add_invalid(self,default): def test_add_invalid(self,default):
default.add_absolute('xxxx') default.add_absolute('xxxx')
@ -469,7 +465,7 @@ class TestResult:
def test_get(self,update,request,ref_path,view,output,flatten,prune): def test_get(self,update,request,ref_path,view,output,flatten,prune):
result = Result(ref_path/'4grains2x4x3_compressionY.hdf5') result = Result(ref_path/'4grains2x4x3_compressionY.hdf5')
for key,value in view.items(): for key,value in view.items():
result = result.view(key,value) result = result.view(**{key:value})
fname = request.node.name fname = request.node.name
cur = result.get(output,flatten,prune) cur = result.get(output,flatten,prune)
@ -494,7 +490,7 @@ class TestResult:
def test_place(self,update,request,ref_path,view,output,flatten,prune,constituents): def test_place(self,update,request,ref_path,view,output,flatten,prune,constituents):
result = Result(ref_path/'4grains2x4x3_compressionY.hdf5') result = Result(ref_path/'4grains2x4x3_compressionY.hdf5')
for key,value in view.items(): for key,value in view.items():
result = result.view(key,value) result = result.view(**{key:value})
fname = request.node.name fname = request.node.name
cur = result.place(output,flatten,prune,constituents) cur = result.place(output,flatten,prune,constituents)