added getitem and where functionality to Table

This commit is contained in:
Philip Eisenlohr 2020-12-02 19:25:54 -05:00
parent e7e8d7be21
commit 4877334986
2 changed files with 27 additions and 3 deletions

View File

@ -33,6 +33,10 @@ class Table:
"""Brief overview.""" """Brief overview."""
return '\n'.join(['# '+c for c in self.comments])+'\n'+self.data.__repr__() return '\n'.join(['# '+c for c in self.comments])+'\n'+self.data.__repr__()
def __getitem__(self,item):
"""Return slice according to item."""
return self.__class__(data=self.data[item],shapes=self.shapes,comments=self.comments)
def __len__(self): def __len__(self):
"""Number of rows.""" """Number of rows."""
return len(self.data) return len(self.data)
@ -45,6 +49,15 @@ class Table:
"""Copy Table.""" """Copy Table."""
return self.__copy__() return self.__copy__()
def where(self,expression):
"""
Return boolean array corresponding to interpolated expression being True.
Table columns are addressed as #column# and will have appropriate shapes.
"""
return eval(re.sub('#(.+?)#',r'self.get("\1")',expression))
def _label_discrete(self): def _label_discrete(self):
"""Label data individually, e.g. v v v ==> 1_v 2_v 3_v.""" """Label data individually, e.g. v v v ==> 1_v 2_v 3_v."""

View File

@ -8,7 +8,7 @@ from damask import Table
def default(): def default():
"""Simple Table.""" """Simple Table."""
x = np.ones((5,13),dtype=float) x = np.ones((5,13),dtype=float)
return Table(x,{'F':(3,3),'v':(3,),'s':(1,)},['test data','contains only ones']) return Table(x,{'F':(3,3),'v':(3,),'s':(1,)},['test data','contains five rows of only ones'])
@pytest.fixture @pytest.fixture
def reference_dir(reference_dir_base): def reference_dir(reference_dir_base):
@ -20,8 +20,9 @@ class TestTable:
def test_repr(self,default): def test_repr(self,default):
print(default) print(default)
def test_len(self): @pytest.mark.parametrize('N',[10,40])
len(Table(np.random.rand(7,3),{'X':3})) == 7 def test_len(self,N):
len(Table(np.random.rand(N,3),{'X':3})) == N
def test_get_scalar(self,default): def test_get_scalar(self,default):
d = default.get('s') d = default.get('s')
@ -39,6 +40,16 @@ class TestTable:
d = default.get('5_F') d = default.get('5_F')
assert np.allclose(d,1.0) and d.shape[1:] == (1,) assert np.allclose(d,1.0) and d.shape[1:] == (1,)
@pytest.mark.parametrize('N',[10,40])
def test_getitem(self,N):
assert len(Table(np.random.rand(N,1),{'X':1})[:N//2]) == N//2
@pytest.mark.parametrize('N',[10,40])
@pytest.mark.parametrize('limit',[0.1,0.6])
def test_where(self,N,limit):
r = Table(np.random.rand(N,1),{'X':1})
assert np.all(r[r.where(f'#X# > {limit}')].get('X') > limit)
@pytest.mark.parametrize('mode',['str','path']) @pytest.mark.parametrize('mode',['str','path'])
def test_write_read(self,default,tmp_path,mode): def test_write_read(self,default,tmp_path,mode):
default.save(tmp_path/'default.txt') default.save(tmp_path/'default.txt')