Merge branch 'improved-table-slicing' into 'development'

Extended slicing functionality of Table

See merge request damask/DAMASK!522
This commit is contained in:
Franz Roters 2022-02-15 12:48:39 +00:00
commit d4f7114164
2 changed files with 23 additions and 10 deletions

View File

@ -44,6 +44,13 @@ class Table:
return '\n'.join(['# '+c for c in self.comments])+'\n'+data_repr
def __eq__(self,
other: object) -> bool:
"""Compare to other Table."""
return NotImplemented if not isinstance(other,Table) else \
self.shapes == other.shapes and self.data.equals(other.data)
def __getitem__(self,
item: Union[slice, Tuple[slice, ...]]) -> 'Table':
"""
@ -75,20 +82,22 @@ class Table:
colB colA
0 1 0
2 7 6
>>> tbl[1:2,'colB']
>>> tbl[[True,False,False,True],'colB']
colB
1 4
2 7
0 1
3 10
"""
item = (item,slice(None,None,None)) if isinstance(item,slice) else \
item if isinstance(item[0],slice) else \
item_ = (item,slice(None,None,None)) if isinstance(item,(slice,np.ndarray)) else \
(np.array(item),slice(None,None,None)) if isinstance(item,list) and np.array(item).dtype == np.bool_ else \
(np.array(item[0]),item[1]) if isinstance(item[0],list) else \
item if isinstance(item[0],(slice,np.ndarray)) else \
(slice(None,None,None),item)
sliced = self.data.loc[item]
cols = np.array(sliced.columns if isinstance(sliced,pd.core.frame.DataFrame) else [item[1]])
sliced = self.data.loc[item_]
cols = np.array(sliced.columns if isinstance(sliced,pd.core.frame.DataFrame) else [item_[1]])
_,idx = np.unique(cols,return_index=True)
return self.__class__(data=sliced,
shapes = {k:self.shapes[k] for k in cols[np.sort(idx)]},
shapes={k:self.shapes[k] for k in cols[np.sort(idx)]},
comments=self.comments)

View File

@ -59,10 +59,14 @@ class TestTable:
@pytest.mark.parametrize('N',[1,3,4])
def test_slice(self,default,N):
mask = np.random.choice([True,False],len(default))
assert len(default[:N]) == 1+N
assert len(default[:N,['F','s']]) == 1+N
assert len(default[mask,['F','s']]) == np.count_nonzero(mask)
assert default[mask,['F','s']] == default[mask][['F','s']] == default[['F','s']][mask]
assert default[np.logical_not(mask),['F','s']] != default[mask][['F','s']]
assert default[N:].get('F').shape == (len(default)-N,3,3)
assert (default[:N,['v','s']].data == default['v','s'][:N].data).all().all()
assert default[:N,['v','s']].data.equals(default['v','s'][:N].data)
@pytest.mark.parametrize('mode',['str','path'])
def test_write_read(self,default,tmp_path,mode):