Merge branch 'typehints_table' into 'development'

03 Added typehints for table module

See merge request damask/DAMASK!499
This commit is contained in:
Philip Eisenlohr 2022-01-25 00:56:32 +00:00
commit 80526967c1
1 changed files with 45 additions and 44 deletions

View File

@ -1,15 +1,18 @@
import re import re
import copy import copy
from pathlib import Path
from typing import Union, Optional, Tuple, List
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from ._typehints import FileHandle
from . import util from . import util
class Table: class Table:
"""Manipulate multi-dimensional spreadsheet-like data.""" """Manipulate multi-dimensional spreadsheet-like data."""
def __init__(self,data,shapes,comments=None): def __init__(self, data: np.ndarray, shapes: dict, comments: Optional[Union[str, list]] = None):
""" """
New spreadsheet. New spreadsheet.
@ -30,7 +33,7 @@ class Table:
self._relabel('uniform') self._relabel('uniform')
def __repr__(self): def __repr__(self) -> str:
"""Brief overview.""" """Brief overview."""
self._relabel('shapes') self._relabel('shapes')
data_repr = self.data.__repr__() data_repr = self.data.__repr__()
@ -38,7 +41,7 @@ class Table:
return '\n'.join(['# '+c for c in self.comments])+'\n'+data_repr return '\n'.join(['# '+c for c in self.comments])+'\n'+data_repr
def __getitem__(self,item): def __getitem__(self, item: Union[slice, Tuple[slice, ...]]) -> 'Table':
""" """
Slice the Table according to item. Slice the Table according to item.
@ -85,19 +88,19 @@ class Table:
comments=self.comments) comments=self.comments)
def __len__(self): def __len__(self) -> int:
"""Number of rows.""" """Number of rows."""
return len(self.data) return len(self.data)
def __copy__(self): def __copy__(self) -> 'Table':
"""Create deep copy.""" """Create deep copy."""
return copy.deepcopy(self) return copy.deepcopy(self)
copy = __copy__ copy = __copy__
def _label(self,what,how): def _label(self, what: Union[str, List[str]], how: str) -> List[str]:
""" """
Expand labels according to data shape. Expand labels according to data shape.
@ -105,7 +108,7 @@ class Table:
---------- ----------
what : str or list what : str or list
Labels to expand. Labels to expand.
how : str how : {'uniform, 'shapes', 'linear'}
Mode of labeling. Mode of labeling.
'uniform' ==> v v v 'uniform' ==> v v v
'shapes' ==> 3:v v v 'shapes' ==> 3:v v v
@ -128,30 +131,34 @@ class Table:
return labels return labels
def _relabel(self,how): def _relabel(self, how: str):
""" """
Modify labeling of data in-place. Modify labeling of data in-place.
Parameters Parameters
---------- ----------
how : str how : {'uniform, 'shapes', 'linear'}
Mode of labeling. Mode of labeling.
'uniform' ==> v v v 'uniform' ==> v v v
'shapes' ==> 3:v v v 'shapes' ==> 3:v v v
'linear' ==> 1_v 2_v 3_v 'linear' ==> 1_v 2_v 3_v
""" """
self.data.columns = self._label(self.shapes,how) self.data.columns = self._label(self.shapes,how) #type: ignore
def _add_comment(self,label,shape,info): def _add_comment(self, label: str, shape: Tuple[int, ...], info: Optional[str]):
if info is not None: if info is not None:
specific = f'{label}{" "+str(shape) if np.prod(shape,dtype=int) > 1 else ""}: {info}' specific = f'{label}{" "+str(shape) if np.prod(shape,dtype=int) > 1 else ""}: {info}'
general = util.execution_stamp('Table') general = util.execution_stamp('Table')
self.comments.append(f'{specific} / {general}') self.comments.append(f'{specific} / {general}')
def isclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True): def isclose(self,
other: 'Table',
rtol: float = 1e-5,
atol: float = 1e-8,
equal_nan: bool = True) -> np.ndarray:
""" """
Report where values are approximately equal to corresponding ones of other Table. Report where values are approximately equal to corresponding ones of other Table.
@ -179,7 +186,11 @@ class Table:
equal_nan=equal_nan) equal_nan=equal_nan)
def allclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True): def allclose(self,
other: 'Table',
rtol: float = 1e-5,
atol: float = 1e-8,
equal_nan: bool = True) -> bool:
""" """
Test whether all values are approximately equal to corresponding ones of other Table. Test whether all values are approximately equal to corresponding ones of other Table.
@ -208,7 +219,7 @@ class Table:
@staticmethod @staticmethod
def load(fname): def load(fname: FileHandle) -> 'Table':
""" """
Load from ASCII table file. Load from ASCII table file.
@ -229,10 +240,7 @@ class Table:
Table data from file. Table data from file.
""" """
try: f = open(fname) if isinstance(fname, (str, Path)) else fname
f = open(fname)
except TypeError:
f = fname
f.seek(0) f.seek(0)
comments = [] comments = []
@ -261,7 +269,7 @@ class Table:
@staticmethod @staticmethod
def load_ang(fname): def load_ang(fname: FileHandle) -> 'Table':
""" """
Load from ang file. Load from ang file.
@ -286,10 +294,7 @@ class Table:
Table data from file. Table data from file.
""" """
try: f = open(fname) if isinstance(fname, (str, Path)) else fname
f = open(fname)
except TypeError:
f = fname
f.seek(0) f.seek(0)
content = f.readlines() content = f.readlines()
@ -312,11 +317,11 @@ class Table:
@property @property
def labels(self): def labels(self) -> List[Tuple[int, ...]]:
return list(self.shapes) return list(self.shapes)
def get(self,label): def get(self, label: str) -> np.ndarray:
""" """
Get column data. Get column data.
@ -336,7 +341,7 @@ class Table:
return data.astype(type(data.flatten()[0])) return data.astype(type(data.flatten()[0]))
def set(self,label,data,info=None): def set(self, label: str, data: np.ndarray, info: str = None) -> 'Table':
""" """
Set column data. Set column data.
@ -356,7 +361,7 @@ class Table:
""" """
dup = self.copy() dup = self.copy()
dup._add_comment(label,data.shape[1:],info) dup._add_comment(label, data.shape[1:], info)
m = re.match(r'(.*)\[((\d+,)*(\d+))\]',label) m = re.match(r'(.*)\[((\d+,)*(\d+))\]',label)
if m: if m:
key = m.group(1) key = m.group(1)
@ -369,7 +374,7 @@ class Table:
return dup return dup
def add(self,label,data,info=None): def add(self, label: str, data: np.ndarray, info: str = None) -> 'Table':
""" """
Add column data. Add column data.
@ -401,7 +406,7 @@ class Table:
return dup return dup
def delete(self,label): def delete(self, label: str) -> 'Table':
""" """
Delete column data. Delete column data.
@ -422,7 +427,7 @@ class Table:
return dup return dup
def rename(self,old,new,info=None): def rename(self, old: Union[str, List[str]], new: Union[str, List[str]], info: str = None) -> 'Table':
""" """
Rename column data. Rename column data.
@ -448,7 +453,7 @@ class Table:
return dup return dup
def sort_by(self,labels,ascending=True): def sort_by(self, labels: Union[str, List[str]], ascending: Union[bool, List[bool]] = True) -> 'Table':
""" """
Sort table by values of given labels. Sort table by values of given labels.
@ -481,7 +486,7 @@ class Table:
return dup return dup
def append(self,other): def append(self, other: 'Table') -> 'Table':
""" """
Append other table vertically (similar to numpy.vstack). Append other table vertically (similar to numpy.vstack).
@ -506,7 +511,7 @@ class Table:
return dup return dup
def join(self,other): def join(self, other: 'Table') -> 'Table':
""" """
Append other table horizontally (similar to numpy.hstack). Append other table horizontally (similar to numpy.hstack).
@ -533,7 +538,7 @@ class Table:
return dup return dup
def save(self,fname): def save(self, fname: FileHandle):
""" """
Save as plain text file. Save as plain text file.
@ -543,9 +548,8 @@ class Table:
Filename or file for writing. Filename or file for writing.
""" """
seen = set()
labels = [] labels = []
for l in [x for x in self.data.columns if not (x in seen or seen.add(x))]: for l in list(dict.fromkeys(self.data.columns)):
if self.shapes[l] == (1,): if self.shapes[l] == (1,):
labels.append(f'{l}') labels.append(f'{l}')
elif len(self.shapes[l]) == 1: elif len(self.shapes[l]) == 1:
@ -555,10 +559,7 @@ class Table:
labels += [f'{util.srepr(self.shapes[l],"x")}:{i+1}_{l}' \ labels += [f'{util.srepr(self.shapes[l],"x")}:{i+1}_{l}' \
for i in range(np.prod(self.shapes[l]))] for i in range(np.prod(self.shapes[l]))]
try: f = open(fname,'w',newline='\n') if isinstance(fname, (str, Path)) else fname
fhandle = open(fname,'w',newline='\n')
except TypeError:
fhandle = fname
fhandle.write('\n'.join([f'# {c}' for c in self.comments] + [' '.join(labels)])+'\n') f.write('\n'.join([f'# {c}' for c in self.comments] + [' '.join(labels)])+'\n')
self.data.to_csv(fhandle,sep=' ',na_rep='nan',index=False,header=False) self.data.to_csv(f,sep=' ',na_rep='nan',index=False,header=False)