added typehints for table module

This commit is contained in:
Daniel Otto de Mentock 2022-01-12 17:10:13 +01:00
parent 771e8acdb9
commit ffa80f6bef
1 changed files with 51 additions and 35 deletions

View File

@ -1,5 +1,7 @@
import re import re
import copy import copy
from typing import Union, Optional, Tuple, List, TextIO, Set
import pathlib
import pandas as pd import pandas as pd
import numpy as np import numpy as np
@ -9,7 +11,7 @@ from . import util
class Table: class Table:
"""Manipulate multi-dimensional spreadsheet-like data.""" """Manipulate multi-dimensional spreadsheet-like data."""
def __init__(self,data,shapes,comments=None): def __init__(self, data: np.ndarray, shapes: dict, comments: Optional[Union[str, list]] = None):
""" """
New spreadsheet. New spreadsheet.
@ -30,7 +32,7 @@ class Table:
self._relabel('uniform') self._relabel('uniform')
def __repr__(self): def __repr__(self) -> str:
"""Brief overview.""" """Brief overview."""
self._relabel('shapes') self._relabel('shapes')
data_repr = self.data.__repr__() data_repr = self.data.__repr__()
@ -38,7 +40,7 @@ class Table:
return '\n'.join(['# '+c for c in self.comments])+'\n'+data_repr return '\n'.join(['# '+c for c in self.comments])+'\n'+data_repr
def __getitem__(self,item): def __getitem__(self, item: Union[slice, Tuple[slice, ...]]) -> "Table":
""" """
Slice the Table according to item. Slice the Table according to item.
@ -85,19 +87,19 @@ class Table:
comments=self.comments) comments=self.comments)
def __len__(self): def __len__(self) -> int:
"""Number of rows.""" """Number of rows."""
return len(self.data) return len(self.data)
def __copy__(self): def __copy__(self) -> "Table":
"""Create deep copy.""" """Create deep copy."""
return copy.deepcopy(self) return copy.deepcopy(self)
copy = __copy__ copy = __copy__
def _label(self,what,how): def _label(self, what: Union[str, List[str]], how: str) -> List[str]:
""" """
Expand labels according to data shape. Expand labels according to data shape.
@ -128,7 +130,7 @@ class Table:
return labels return labels
def _relabel(self,how): def _relabel(self, how: str):
""" """
Modify labeling of data in-place. Modify labeling of data in-place.
@ -141,17 +143,21 @@ class Table:
'linear' ==> 1_v 2_v 3_v 'linear' ==> 1_v 2_v 3_v
""" """
self.data.columns = self._label(self.shapes,how) self.data.columns = self._label(self.shapes,how) #type: ignore
def _add_comment(self,label,shape,info): def _add_comment(self, label: str, shape: Tuple[int, ...], info: Optional[str]):
if info is not None: if info is not None:
specific = f'{label}{" "+str(shape) if np.prod(shape,dtype=int) > 1 else ""}: {info}' specific = f'{label}{" "+str(shape) if np.prod(shape,dtype=int) > 1 else ""}: {info}'
general = util.execution_stamp('Table') general = util.execution_stamp('Table')
self.comments.append(f'{specific} / {general}') self.comments.append(f'{specific} / {general}')
def isclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True): def isclose(self,
other: "Table",
rtol: float = 1e-5,
atol: float = 1e-8,
equal_nan: bool = True) -> np.ndarray:
""" """
Report where values are approximately equal to corresponding ones of other Table. Report where values are approximately equal to corresponding ones of other Table.
@ -179,7 +185,11 @@ class Table:
equal_nan=equal_nan) equal_nan=equal_nan)
def allclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True): def allclose(self,
other: "Table",
rtol: float = 1e-5,
atol: float = 1e-8,
equal_nan: bool = True) -> bool:
""" """
Test whether all values are approximately equal to corresponding ones of other Table. Test whether all values are approximately equal to corresponding ones of other Table.
@ -208,7 +218,7 @@ class Table:
@staticmethod @staticmethod
def load(fname): def load(fname: Union[TextIO, str, pathlib.Path]) -> "Table":
""" """
Load from ASCII table file. Load from ASCII table file.
@ -229,11 +239,13 @@ class Table:
Table data from file. Table data from file.
""" """
try: if isinstance(fname, TextIO):
f = open(fname)
except TypeError:
f = fname f = fname
f.seek(0) f.seek(0)
elif isinstance(fname, (str, pathlib.Path)):
f = open(fname)
else:
raise TypeError
comments = [] comments = []
line = f.readline().strip() line = f.readline().strip()
@ -261,7 +273,7 @@ class Table:
@staticmethod @staticmethod
def load_ang(fname): def load_ang(fname: Union[TextIO, str, pathlib.Path]) -> "Table":
""" """
Load from ang file. Load from ang file.
@ -286,11 +298,13 @@ class Table:
Table data from file. Table data from file.
""" """
try: if isinstance(fname, TextIO):
f = open(fname)
except TypeError:
f = fname f = fname
f.seek(0) f.seek(0)
elif isinstance(fname, (str, pathlib.Path)):
f = open(fname)
else:
raise TypeError
content = f.readlines() content = f.readlines()
@ -312,11 +326,11 @@ class Table:
@property @property
def labels(self): def labels(self) -> List[Tuple[int, ...]]:
return list(self.shapes) return list(self.shapes)
def get(self,label): def get(self, label: str) -> np.ndarray:
""" """
Get column data. Get column data.
@ -336,7 +350,7 @@ class Table:
return data.astype(type(data.flatten()[0])) return data.astype(type(data.flatten()[0]))
def set(self,label,data,info=None): def set(self, label: str, data: np.ndarray, info: str = None) -> "Table":
""" """
Set column data. Set column data.
@ -369,7 +383,7 @@ class Table:
return dup return dup
def add(self,label,data,info=None): def add(self, label: str, data: np.ndarray, info: str = None) -> "Table":
""" """
Add column data. Add column data.
@ -401,7 +415,7 @@ class Table:
return dup return dup
def delete(self,label): def delete(self, label: str) -> "Table":
""" """
Delete column data. Delete column data.
@ -422,7 +436,7 @@ class Table:
return dup return dup
def rename(self,old,new,info=None): def rename(self, old: Union[str, List[str]], new: Union[str, List[str]], info: str = None) -> "Table":
""" """
Rename column data. Rename column data.
@ -448,7 +462,7 @@ class Table:
return dup return dup
def sort_by(self,labels,ascending=True): def sort_by(self, labels: Union[str, List[str]], ascending: Union[bool, List[bool]] = True) -> "Table":
""" """
Sort table by values of given labels. Sort table by values of given labels.
@ -481,7 +495,7 @@ class Table:
return dup return dup
def append(self,other): def append(self, other: "Table") -> "Table":
""" """
Append other table vertically (similar to numpy.vstack). Append other table vertically (similar to numpy.vstack).
@ -506,7 +520,7 @@ class Table:
return dup return dup
def join(self,other): def join(self, other: "Table") -> "Table":
""" """
Append other table horizontally (similar to numpy.hstack). Append other table horizontally (similar to numpy.hstack).
@ -533,7 +547,7 @@ class Table:
return dup return dup
def save(self,fname): def save(self, fname: Union[TextIO, str, pathlib.Path]):
""" """
Save as plain text file. Save as plain text file.
@ -543,9 +557,9 @@ class Table:
Filename or file for writing. Filename or file for writing.
""" """
seen = set() seen: Set = set()
labels = [] labels = []
for l in [x for x in self.data.columns if not (x in seen or seen.add(x))]: for l in [x for x in self.data.columns if x not in seen]:
if self.shapes[l] == (1,): if self.shapes[l] == (1,):
labels.append(f'{l}') labels.append(f'{l}')
elif len(self.shapes[l]) == 1: elif len(self.shapes[l]) == 1:
@ -555,10 +569,12 @@ class Table:
labels += [f'{util.srepr(self.shapes[l],"x")}:{i+1}_{l}' \ labels += [f'{util.srepr(self.shapes[l],"x")}:{i+1}_{l}' \
for i in range(np.prod(self.shapes[l]))] for i in range(np.prod(self.shapes[l]))]
try: if isinstance(fname, TextIO):
fhandle = open(fname,'w',newline='\n')
except TypeError:
fhandle = fname fhandle = fname
elif isinstance(fname, (str, pathlib.Path)):
fhandle = open(fname,'w',newline='\n')
else:
raise TypeError
fhandle.write('\n'.join([f'# {c}' for c in self.comments] + [' '.join(labels)])+'\n') fhandle.write('\n'.join([f'# {c}' for c in self.comments] + [' '.join(labels)])+'\n')
self.data.to_csv(fhandle,sep=' ',na_rep='nan',index=False,header=False) self.data.to_csv(fhandle,sep=' ',na_rep='nan',index=False,header=False)