added getitem and where functionality to Table
This commit is contained in:
parent
e7e8d7be21
commit
4877334986
|
@ -33,6 +33,10 @@ class Table:
|
||||||
"""Brief overview."""
|
"""Brief overview."""
|
||||||
return '\n'.join(['# '+c for c in self.comments])+'\n'+self.data.__repr__()
|
return '\n'.join(['# '+c for c in self.comments])+'\n'+self.data.__repr__()
|
||||||
|
|
||||||
|
def __getitem__(self,item):
|
||||||
|
"""Return slice according to item."""
|
||||||
|
return self.__class__(data=self.data[item],shapes=self.shapes,comments=self.comments)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""Number of rows."""
|
"""Number of rows."""
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
@ -45,6 +49,15 @@ class Table:
|
||||||
"""Copy Table."""
|
"""Copy Table."""
|
||||||
return self.__copy__()
|
return self.__copy__()
|
||||||
|
|
||||||
|
def where(self,expression):
|
||||||
|
"""
|
||||||
|
Return boolean array corresponding to interpolated expression being True.
|
||||||
|
|
||||||
|
Table columns are addressed as #column# and will have appropriate shapes.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return eval(re.sub('#(.+?)#',r'self.get("\1")',expression))
|
||||||
|
|
||||||
|
|
||||||
def _label_discrete(self):
|
def _label_discrete(self):
|
||||||
"""Label data individually, e.g. v v v ==> 1_v 2_v 3_v."""
|
"""Label data individually, e.g. v v v ==> 1_v 2_v 3_v."""
|
||||||
|
|
|
@ -8,7 +8,7 @@ from damask import Table
|
||||||
def default():
|
def default():
|
||||||
"""Simple Table."""
|
"""Simple Table."""
|
||||||
x = np.ones((5,13),dtype=float)
|
x = np.ones((5,13),dtype=float)
|
||||||
return Table(x,{'F':(3,3),'v':(3,),'s':(1,)},['test data','contains only ones'])
|
return Table(x,{'F':(3,3),'v':(3,),'s':(1,)},['test data','contains five rows of only ones'])
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def reference_dir(reference_dir_base):
|
def reference_dir(reference_dir_base):
|
||||||
|
@ -20,8 +20,9 @@ class TestTable:
|
||||||
def test_repr(self,default):
|
def test_repr(self,default):
|
||||||
print(default)
|
print(default)
|
||||||
|
|
||||||
def test_len(self):
|
@pytest.mark.parametrize('N',[10,40])
|
||||||
len(Table(np.random.rand(7,3),{'X':3})) == 7
|
def test_len(self,N):
|
||||||
|
len(Table(np.random.rand(N,3),{'X':3})) == N
|
||||||
|
|
||||||
def test_get_scalar(self,default):
|
def test_get_scalar(self,default):
|
||||||
d = default.get('s')
|
d = default.get('s')
|
||||||
|
@ -39,6 +40,16 @@ class TestTable:
|
||||||
d = default.get('5_F')
|
d = default.get('5_F')
|
||||||
assert np.allclose(d,1.0) and d.shape[1:] == (1,)
|
assert np.allclose(d,1.0) and d.shape[1:] == (1,)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('N',[10,40])
|
||||||
|
def test_getitem(self,N):
|
||||||
|
assert len(Table(np.random.rand(N,1),{'X':1})[:N//2]) == N//2
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('N',[10,40])
|
||||||
|
@pytest.mark.parametrize('limit',[0.1,0.6])
|
||||||
|
def test_where(self,N,limit):
|
||||||
|
r = Table(np.random.rand(N,1),{'X':1})
|
||||||
|
assert np.all(r[r.where(f'#X# > {limit}')].get('X') > limit)
|
||||||
|
|
||||||
@pytest.mark.parametrize('mode',['str','path'])
|
@pytest.mark.parametrize('mode',['str','path'])
|
||||||
def test_write_read(self,default,tmp_path,mode):
|
def test_write_read(self,default,tmp_path,mode):
|
||||||
default.save(tmp_path/'default.txt')
|
default.save(tmp_path/'default.txt')
|
||||||
|
|
Loading…
Reference in New Issue