avoid random order when using sets

This commit is contained in:
Martin Diehl 2021-04-03 11:08:22 +02:00
parent 885aeb62e5
commit 27f2e3b26e
2 changed files with 41 additions and 32 deletions

View File

@ -25,11 +25,27 @@ from . import util
h5py3 = h5py.__version__[0] == '3' h5py3 = h5py.__version__[0] == '3'
def _read(dataset): def _read(dataset):
metadata = {k:(v if h5py3 else v.decode()) for k,v in dataset.attrs.items()} metadata = {k:(v if h5py3 else v.decode()) for k,v in dataset.attrs.items()}
dtype = np.dtype(dataset.dtype,metadata=metadata) dtype = np.dtype(dataset.dtype,metadata=metadata)
return np.array(dataset,dtype=dtype) return np.array(dataset,dtype=dtype)
def _match(requested,existing):
def flatten_list(list_of_lists):
return [e for e_ in list_of_lists for e in e_]
if requested is True:
requested = '*'
elif requested is False or requested is None:
requested = []
requested_ = requested if hasattr(requested,'__iter__') and not isinstance(requested,str) else \
[requested]
return sorted(set(flatten_list([glob.fnmatch.filter(existing,r) for r in requested_])),
key=util.natural_sort)
class Result: class Result:
""" """
@ -70,8 +86,7 @@ class Result:
self.origin = f['geometry'].attrs['origin'] self.origin = f['geometry'].attrs['origin']
r=re.compile('inc[0-9]+' if self.version_minor < 12 else 'increment_[0-9]+') r=re.compile('inc[0-9]+' if self.version_minor < 12 else 'increment_[0-9]+')
increments_unsorted = {int(i[10:]):i for i in f.keys() if r.match(i)} self.increments = sorted([i for i in f.keys() if r.match(i)],key=util.natural_sort)
self.increments = [increments_unsorted[i] for i in sorted(increments_unsorted)]
self.times = [round(f[i].attrs['time/s'],12) for i in self.increments] if self.version_minor < 12 else \ self.times = [round(f[i].attrs['time/s'],12) for i in self.increments] if self.version_minor < 12 else \
[round(f[i].attrs['t/s'],12) for i in self.increments] [round(f[i].attrs['t/s'],12) for i in self.increments]
@ -81,15 +96,17 @@ class Result:
self.homogenizations = [m.decode() for m in np.unique(f[f'{grp}/homogenization'] self.homogenizations = [m.decode() for m in np.unique(f[f'{grp}/homogenization']
['Name' if self.version_minor < 12 else 'label'])] ['Name' if self.version_minor < 12 else 'label'])]
self.homogenizations = sorted(self.homogenizations,key=util.natural_sort)
self.phases = [c.decode() for c in np.unique(f[f'{grp}/phase'] self.phases = [c.decode() for c in np.unique(f[f'{grp}/phase']
['Name' if self.version_minor < 12 else 'label'])] ['Name' if self.version_minor < 12 else 'label'])]
self.phases = sorted(self.phases,key=util.natural_sort)
self.fields = [] self.fields = []
for c in self.phases: for c in self.phases:
self.fields += f['/'.join([self.increments[0],'phase',c])].keys() self.fields += f['/'.join([self.increments[0],'phase',c])].keys()
for m in self.homogenizations: for m in self.homogenizations:
self.fields += f['/'.join([self.increments[0],'homogenization',m])].keys() self.fields += f['/'.join([self.increments[0],'homogenization',m])].keys()
self.fields = list(set(self.fields)) # make unique self.fields = sorted(set(self.fields),key=util.natural_sort) # make unique
self.visible = {'increments': self.increments, self.visible = {'increments': self.increments,
'phases': self.phases, 'phases': self.phases,
@ -135,14 +152,10 @@ class Result:
True is equivalent to [*], False is equivalent to []. True is equivalent to [*], False is equivalent to [].
""" """
def natural_sort(key):
convert = lambda text: int(text) if text.isdigit() else text
return [ convert(c) for c in re.split('([0-9]+)', key) ]
# allow True/False and string arguments # allow True/False and string arguments
if datasets is True: if datasets is True:
datasets = ['*'] datasets = '*'
elif datasets is False: elif datasets is False or datasets is None:
datasets = [] datasets = []
choice = datasets if hasattr(datasets,'__iter__') and not isinstance(datasets,str) else \ choice = datasets if hasattr(datasets,'__iter__') and not isinstance(datasets,str) else \
[datasets] [datasets]
@ -166,19 +179,17 @@ class Result:
elif np.isclose(c,self.times[idx+1]): elif np.isclose(c,self.times[idx+1]):
choice.append(self.increments[idx+1]) choice.append(self.increments[idx+1])
valid = [e for e_ in [glob.fnmatch.filter(getattr(self,what),s) for s in choice] for e in e_] valid = _match(choice,getattr(self,what))
existing = set(self.visible[what]) existing = set(self.visible[what])
if action == 'set': if action == 'set':
self.visible[what] = valid self.visible[what] = sorted(set(valid), key=util.natural_sort)
elif action == 'add': elif action == 'add':
add = existing.union(valid) add = existing.union(valid)
add_sorted = sorted(add, key=natural_sort) self.visible[what] = sorted(add, key=util.natural_sort)
self.visible[what] = add_sorted
elif action == 'del': elif action == 'del':
diff = existing.difference(valid) diff = existing.difference(valid)
diff_sorted = sorted(diff, key=natural_sort) self.visible[what] = sorted(diff, key=util.natural_sort)
self.visible[what] = diff_sorted
def _get_attribute(self,path,attr): def _get_attribute(self,path,attr):
@ -1285,7 +1296,6 @@ class Result:
ln = 3 if self.version_minor < 12 else 10 # compatibility hack ln = 3 if self.version_minor < 12 else 10 # compatibility hack
N_digits = int(np.floor(np.log10(max(1,int(self.increments[-1][ln:])))))+1 N_digits = int(np.floor(np.log10(max(1,int(self.increments[-1][ln:])))))+1
output_ = set([output] if isinstance(output,str) else output)
constituents_ = constituents if isinstance(constituents,Iterable) else \ constituents_ = constituents if isinstance(constituents,Iterable) else \
(range(self.N_constituents) if constituents is None else [constituents]) (range(self.N_constituents) if constituents is None else [constituents])
@ -1322,8 +1332,7 @@ class Result:
if field not in f['/'.join((inc,ty,label))].keys(): continue if field not in f['/'.join((inc,ty,label))].keys(): continue
outs = {} outs = {}
for out in f['/'.join((inc,ty,label,field))].keys() if '*' in output_ else \ for out in _match(output,f['/'.join((inc,ty,label,field))].keys()):
output_.intersection(f['/'.join((inc,ty,label,field))].keys()):
data = ma.array(_read(f['/'.join((inc,ty,label,field,out))])) data = ma.array(_read(f['/'.join((inc,ty,label,field,out))]))
if ty == 'phase': if ty == 'phase':
@ -1375,23 +1384,20 @@ class Result:
""" """
r = {} r = {}
output_ = set([output] if isinstance(output,str) else output)
with h5py.File(self.fname,'r') as f: with h5py.File(self.fname,'r') as f:
for inc in util.show_progress(self.visible['increments']): for inc in util.show_progress(self.visible['increments']):
r[inc] = {'phase':{},'homogenization':{},'geometry':{}} r[inc] = {'phase':{},'homogenization':{},'geometry':{}}
for out in f['/'.join((inc,'geometry'))].keys() if '*' in output_ else \ for out in _match(output,f['/'.join((inc,'geometry'))].keys()):
output_.intersection(f['/'.join((inc,'geometry'))].keys()):
r[inc]['geometry'][out] = _read(f['/'.join((inc,'geometry',out))]) r[inc]['geometry'][out] = _read(f['/'.join((inc,'geometry',out))])
for ty in ['phase','homogenization']: for ty in ['phase','homogenization']:
for label in self.visible[ty+'s']: for label in self.visible[ty+'s']:
r[inc][ty][label] = {} r[inc][ty][label] = {}
for field in set(self.visible['fields']).union(f['/'.join((inc,ty,label))].keys()): for field in _match(self.visible['fields'],f['/'.join((inc,ty,label))].keys()):
r[inc][ty][label][field] = {} r[inc][ty][label][field] = {}
for out in f['/'.join((inc,ty,label,field))].keys() if '*' in output_ else \ for out in _match(output,f['/'.join((inc,ty,label,field))].keys()):
output_.intersection(f['/'.join((inc,ty,label,field))].keys()):
r[inc][ty][label][field][out] = _read(f['/'.join((inc,ty,label,field,out))]) r[inc][ty][label][field][out] = _read(f['/'.join((inc,ty,label,field,out))])
if prune: r = util.dict_prune(r) if prune: r = util.dict_prune(r)
@ -1435,7 +1441,6 @@ class Result:
""" """
r = {} r = {}
output_ = set([output] if isinstance(output,str) else output)
constituents_ = constituents if isinstance(constituents,Iterable) else \ constituents_ = constituents if isinstance(constituents,Iterable) else \
(range(self.N_constituents) if constituents is None else [constituents]) (range(self.N_constituents) if constituents is None else [constituents])
@ -1464,18 +1469,16 @@ class Result:
for inc in util.show_progress(self.visible['increments']): for inc in util.show_progress(self.visible['increments']):
r[inc] = {'phase':{},'homogenization':{},'geometry':{}} r[inc] = {'phase':{},'homogenization':{},'geometry':{}}
for out in f['/'.join((inc,'geometry'))].keys() if '*' in output_ else \ for out in _match(output,f['/'.join((inc,'geometry'))].keys()):
output_.intersection(f['/'.join((inc,'geometry'))].keys()):
r[inc]['geometry'][out] = _read(f['/'.join((inc,'geometry',out))]) r[inc]['geometry'][out] = _read(f['/'.join((inc,'geometry',out))])
for ty in ['phase','homogenization']: for ty in ['phase','homogenization']:
for label in self.visible[ty+'s']: for label in self.visible[ty+'s']:
for field in set(self.visible['fields']).union(f['/'.join((inc,ty,label))].keys()): for field in _match(self.visible['fields'],f['/'.join((inc,ty,label))].keys()):
if field not in r[inc][ty].keys(): if field not in r[inc][ty].keys():
r[inc][ty][field] = {} r[inc][ty][field] = {}
for out in f['/'.join((inc,ty,label,field))].keys() if '*' in output_ else \ for out in _match(output,f['/'.join((inc,ty,label,field))].keys()):
output_.intersection(f['/'.join((inc,ty,label,field))].keys()):
data = ma.array(_read(f['/'.join((inc,ty,label,field,out))])) data = ma.array(_read(f['/'.join((inc,ty,label,field,out))]))
if ty == 'phase': if ty == 'phase':

View File

@ -17,6 +17,7 @@ __all__=[
'srepr', 'srepr',
'emph','deemph','warn','strikeout', 'emph','deemph','warn','strikeout',
'execute', 'execute',
'natural_sort',
'show_progress', 'show_progress',
'scale_to_coprime', 'scale_to_coprime',
'project_stereographic', 'project_stereographic',
@ -113,6 +114,11 @@ def execute(cmd,wd='./',env=None):
return process.stdout, process.stderr return process.stdout, process.stderr
def natural_sort(key):
convert = lambda text: int(text) if text.isdigit() else text
return [ convert(c) for c in re.split('([0-9]+)', key) ]
def show_progress(iterable,N_iter=None,prefix='',bar_length=50): def show_progress(iterable,N_iter=None,prefix='',bar_length=50):
""" """
Decorate a loop with a status bar. Decorate a loop with a status bar.