diff --git a/python/damask/_table.py b/python/damask/_table.py index a65971f39..04ab426ce 100644 --- a/python/damask/_table.py +++ b/python/damask/_table.py @@ -1,5 +1,7 @@ import re import copy +from typing import Union, Optional, Tuple, List, TextIO, Set +import pathlib import pandas as pd import numpy as np @@ -9,7 +11,7 @@ from . import util class Table: """Manipulate multi-dimensional spreadsheet-like data.""" - def __init__(self,data,shapes,comments=None): + def __init__(self, data: np.ndarray, shapes: dict, comments: Optional[Union[str, list]] = None): """ New spreadsheet. @@ -30,7 +32,7 @@ class Table: self._relabel('uniform') - def __repr__(self): + def __repr__(self) -> str: """Brief overview.""" self._relabel('shapes') data_repr = self.data.__repr__() @@ -38,7 +40,7 @@ class Table: return '\n'.join(['# '+c for c in self.comments])+'\n'+data_repr - def __getitem__(self,item): + def __getitem__(self, item: Union[slice, Tuple[slice, ...]]) -> "Table": """ Slice the Table according to item. @@ -85,19 +87,19 @@ class Table: comments=self.comments) - def __len__(self): + def __len__(self) -> int: """Number of rows.""" return len(self.data) - def __copy__(self): + def __copy__(self) -> "Table": """Create deep copy.""" return copy.deepcopy(self) copy = __copy__ - def _label(self,what,how): + def _label(self, what: Union[str, List[str]], how: str) -> List[str]: """ Expand labels according to data shape. @@ -128,7 +130,7 @@ class Table: return labels - def _relabel(self,how): + def _relabel(self, how: str): """ Modify labeling of data in-place. @@ -141,17 +143,21 @@ class Table: 'linear' ==> 1_v 2_v 3_v """ - self.data.columns = self._label(self.shapes,how) + self.data.columns = self._label(self.shapes,how) #type: ignore - def _add_comment(self,label,shape,info): + def _add_comment(self, label: str, shape: Tuple[int, ...], info: Optional[str]): if info is not None: specific = f'{label}{" "+str(shape) if np.prod(shape,dtype=int) > 1 else ""}: {info}' general = util.execution_stamp('Table') self.comments.append(f'{specific} / {general}') - def isclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True): + def isclose(self, + other: "Table", + rtol: float = 1e-5, + atol: float = 1e-8, + equal_nan: bool = True) -> np.ndarray: """ Report where values are approximately equal to corresponding ones of other Table. @@ -179,7 +185,11 @@ class Table: equal_nan=equal_nan) - def allclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True): + def allclose(self, + other: "Table", + rtol: float = 1e-5, + atol: float = 1e-8, + equal_nan: bool = True) -> bool: """ Test whether all values are approximately equal to corresponding ones of other Table. @@ -208,7 +218,7 @@ class Table: @staticmethod - def load(fname): + def load(fname: Union[TextIO, str, pathlib.Path]) -> "Table": """ Load from ASCII table file. @@ -229,11 +239,13 @@ class Table: Table data from file. """ - try: - f = open(fname) - except TypeError: + if isinstance(fname, TextIO): f = fname f.seek(0) + elif isinstance(fname, (str, pathlib.Path)): + f = open(fname) + else: + raise TypeError comments = [] line = f.readline().strip() @@ -261,7 +273,7 @@ class Table: @staticmethod - def load_ang(fname): + def load_ang(fname: Union[TextIO, str, pathlib.Path]) -> "Table": """ Load from ang file. @@ -286,11 +298,13 @@ class Table: Table data from file. """ - try: - f = open(fname) - except TypeError: + if isinstance(fname, TextIO): f = fname f.seek(0) + elif isinstance(fname, (str, pathlib.Path)): + f = open(fname) + else: + raise TypeError content = f.readlines() @@ -312,11 +326,11 @@ class Table: @property - def labels(self): + def labels(self) -> List[Tuple[int, ...]]: return list(self.shapes) - def get(self,label): + def get(self, label: str) -> np.ndarray: """ Get column data. @@ -336,7 +350,7 @@ class Table: return data.astype(type(data.flatten()[0])) - def set(self,label,data,info=None): + def set(self, label: str, data: np.ndarray, info: str = None) -> "Table": """ Set column data. @@ -356,7 +370,7 @@ class Table: """ dup = self.copy() - dup._add_comment(label,data.shape[1:],info) + dup._add_comment(label, data.shape[1:],info) m = re.match(r'(.*)\[((\d+,)*(\d+))\]',label) if m: key = m.group(1) @@ -369,7 +383,7 @@ class Table: return dup - def add(self,label,data,info=None): + def add(self, label: str, data: np.ndarray, info: str = None) -> "Table": """ Add column data. @@ -401,7 +415,7 @@ class Table: return dup - def delete(self,label): + def delete(self, label: str) -> "Table": """ Delete column data. @@ -422,7 +436,7 @@ class Table: return dup - def rename(self,old,new,info=None): + def rename(self, old: Union[str, List[str]], new: Union[str, List[str]], info: str = None) -> "Table": """ Rename column data. @@ -448,7 +462,7 @@ class Table: return dup - def sort_by(self,labels,ascending=True): + def sort_by(self, labels: Union[str, List[str]], ascending: Union[bool, List[bool]] = True) -> "Table": """ Sort table by values of given labels. @@ -481,7 +495,7 @@ class Table: return dup - def append(self,other): + def append(self, other: "Table") -> "Table": """ Append other table vertically (similar to numpy.vstack). @@ -506,7 +520,7 @@ class Table: return dup - def join(self,other): + def join(self, other: "Table") -> "Table": """ Append other table horizontally (similar to numpy.hstack). @@ -533,7 +547,7 @@ class Table: return dup - def save(self,fname): + def save(self, fname: Union[TextIO, str, pathlib.Path]): """ Save as plain text file. @@ -543,9 +557,9 @@ class Table: Filename or file for writing. """ - seen = set() + seen: Set = set() labels = [] - for l in [x for x in self.data.columns if not (x in seen or seen.add(x))]: + for l in [x for x in self.data.columns if x not in seen]: if self.shapes[l] == (1,): labels.append(f'{l}') elif len(self.shapes[l]) == 1: @@ -555,10 +569,12 @@ class Table: labels += [f'{util.srepr(self.shapes[l],"x")}:{i+1}_{l}' \ for i in range(np.prod(self.shapes[l]))] - try: - fhandle = open(fname,'w',newline='\n') - except TypeError: + if isinstance(fname, TextIO): fhandle = fname + elif isinstance(fname, (str, pathlib.Path)): + fhandle = open(fname,'w',newline='\n') + else: + raise TypeError fhandle.write('\n'.join([f'# {c}' for c in self.comments] + [' '.join(labels)])+'\n') self.data.to_csv(fhandle,sep=' ',na_rep='nan',index=False,header=False)