From 0300912b3032eb7c1a309e620de01c0efecb4b3a Mon Sep 17 00:00:00 2001 From: Philip Eisenlohr Date: Sun, 13 Feb 2022 22:00:48 -0500 Subject: [PATCH] Table.__eq__ for proper comparison; logical masks for slicing now work --- python/damask/_table.py | 27 ++++++++++++++++++--------- python/tests/test_Table.py | 6 +++++- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/python/damask/_table.py b/python/damask/_table.py index 1572c4f76..81c23e73b 100644 --- a/python/damask/_table.py +++ b/python/damask/_table.py @@ -44,6 +44,13 @@ class Table: return '\n'.join(['# '+c for c in self.comments])+'\n'+data_repr + def __eq__(self, + other: object) -> bool: + """Compare to other Table.""" + return NotImplemented if not isinstance(other,Table) else \ + self.shapes == other.shapes and self.data.equals(other.data) + + def __getitem__(self, item: Union[slice, Tuple[slice, ...]]) -> 'Table': """ @@ -75,20 +82,22 @@ class Table: colB colA 0 1 0 2 7 6 - >>> tbl[1:2,'colB'] + >>> tbl[[True,False,False,True],'colB'] colB - 1 4 - 2 7 + 0 1 + 3 10 """ - item = (item,slice(None,None,None)) if isinstance(item,slice) else \ - item if isinstance(item[0],slice) else \ - (slice(None,None,None),item) - sliced = self.data.loc[item] - cols = np.array(sliced.columns if isinstance(sliced,pd.core.frame.DataFrame) else [item[1]]) + item_ = (item,slice(None,None,None)) if isinstance(item,(slice,np.ndarray)) else \ + (np.array(item),slice(None,None,None)) if isinstance(item,list) and np.array(item).dtype == np.bool_ else \ + (np.array(item[0]),item[1]) if isinstance(item[0],list) else \ + item if isinstance(item[0],(slice,np.ndarray)) else \ + (slice(None,None,None),item) + sliced = self.data.loc[item_] + cols = np.array(sliced.columns if isinstance(sliced,pd.core.frame.DataFrame) else [item_[1]]) _,idx = np.unique(cols,return_index=True) return self.__class__(data=sliced, - shapes = {k:self.shapes[k] for k in cols[np.sort(idx)]}, + shapes={k:self.shapes[k] for k in cols[np.sort(idx)]}, comments=self.comments) diff --git a/python/tests/test_Table.py b/python/tests/test_Table.py index 62eb6d63b..1f89026a3 100644 --- a/python/tests/test_Table.py +++ b/python/tests/test_Table.py @@ -59,10 +59,14 @@ class TestTable: @pytest.mark.parametrize('N',[1,3,4]) def test_slice(self,default,N): + mask = np.random.choice([True,False],len(default)) assert len(default[:N]) == 1+N assert len(default[:N,['F','s']]) == 1+N + assert len(default[mask,['F','s']]) == np.count_nonzero(mask) + assert default[mask,['F','s']] == default[mask][['F','s']] == default[['F','s']][mask] + assert default[np.logical_not(mask),['F','s']] != default[mask][['F','s']] assert default[N:].get('F').shape == (len(default)-N,3,3) - assert (default[:N,['v','s']].data == default['v','s'][:N].data).all().all() + assert default[:N,['v','s']].data.equals(default['v','s'][:N].data) @pytest.mark.parametrize('mode',['str','path']) def test_write_read(self,default,tmp_path,mode):