just to make sure ...

This commit is contained in:
Martin Diehl 2019-12-21 23:43:56 +01:00
parent 343da2e54d
commit e08b096f08
1 changed files with 19 additions and 1 deletions

View File

@ -30,18 +30,36 @@ class TestGridFilters:
grid = np.random.randint(8,32,(3)) grid = np.random.randint(8,32,(3))
size = np.random.random(3) size = np.random.random(3)
origin = np.random.random(3) origin = np.random.random(3)
coord0 = eval('grid_filters.{}_coord0(grid,size,origin)'.format(mode)) # noqa coord0 = eval('grid_filters.{}_coord0(grid,size,origin)'.format(mode)) # noqa
_grid,_size,_origin = eval('grid_filters.{}_coord0_2_DNA(coord0.reshape((-1,3)))'.format(mode)) _grid,_size,_origin = eval('grid_filters.{}_coord0_2_DNA(coord0.reshape((-1,3)))'.format(mode))
assert np.allclose(grid,_grid) and np.allclose(size,_size) and np.allclose(origin,_origin) assert np.allclose(grid,_grid) and np.allclose(size,_size) and np.allclose(origin,_origin)
def test_displacement_fluct_equivalence(self): def test_displacement_fluct_equivalence(self):
"""Ensure that fluctuations are periodic."""
size = np.random.random(3) size = np.random.random(3)
grid = np.random.randint(8,32,(3)) grid = np.random.randint(8,32,(3))
F = np.random.random(tuple(grid)+(3,3)) F = np.random.random(tuple(grid)+(3,3))
assert np.allclose(grid_filters.node_displacement_fluct(size,F), assert np.allclose(grid_filters.node_displacement_fluct(size,F),
grid_filters.cell_2_node(grid_filters.cell_displacement_fluct(size,F))) grid_filters.cell_2_node(grid_filters.cell_displacement_fluct(size,F)))
def test_interpolation_nonperiodic(self):
size = np.random.random(3)
grid = np.random.randint(8,32,(3))
F = np.random.random(tuple(grid)+(3,3))
assert np.allclose(grid_filters.node_coord(size,F) [1:-1,1:-1,1:-1],grid_filters.cell_2_node(
grid_filters.cell_coord(size,F))[1:-1,1:-1,1:-1])
@pytest.mark.parametrize('mode',[('cell'),('node')])
def test_coord0_origin(self,mode):
origin= np.random.random(3)
size = np.random.random(3) # noqa
grid = np.random.randint(8,32,(3))
shifted = eval('grid_filters.{}_coord0(grid,size,origin)'.format(mode))
unshifted = eval('grid_filters.{}_coord0(grid,size)'.format(mode))
if mode == 'cell':
assert np.allclose(shifted,unshifted+np.broadcast_to(origin,tuple(grid[::-1]) +(3,)))
elif mode == 'node':
assert np.allclose(shifted,unshifted+np.broadcast_to(origin,tuple(grid[::-1]+1)+(3,)))
@pytest.mark.parametrize('mode',[('cell'),('node')]) @pytest.mark.parametrize('mode',[('cell'),('node')])
def test_displacement_avg_vanishes(self,mode): def test_displacement_avg_vanishes(self,mode):