Correct shapes for gradient calculations
This commit is contained in:
parent
0650f46ab1
commit
307debebd4
|
@ -189,42 +189,32 @@ class TestGridFilters:
|
||||||
'0.0', '0.0', '0.0',
|
'0.0', '0.0', '0.0',
|
||||||
'0.0', '0.0', 'np.cos(np.pi*2*nodes[...,2]/size[2])*np.pi*2/size[2]']),
|
'0.0', '0.0', 'np.cos(np.pi*2*nodes[...,2]/size[2])*np.pi*2/size[2]']),
|
||||||
|
|
||||||
(['np.sin(np.pi*2*nodes[...,0]/size[0])','np.sin(np.pi*2*nodes[...,1]/size[1])','np.sin(np.pi*2*nodes[...,2]/size[2])'],
|
(['np.sin(np.pi*2*nodes[...,0]/size[0])','np.sin(np.pi*2*nodes[...,1]/size[1])',\
|
||||||
|
'np.sin(np.pi*2*nodes[...,2]/size[2])'],
|
||||||
['np.cos(np.pi*2*nodes[...,0]/size[0])*np.pi*2/size[0]', '0.0', '0.0',
|
['np.cos(np.pi*2*nodes[...,0]/size[0])*np.pi*2/size[0]', '0.0', '0.0',
|
||||||
'0.0', 'np.cos(np.pi*2*nodes[...,1]/size[1])*np.pi*2/size[1]', '0.0',
|
'0.0', 'np.cos(np.pi*2*nodes[...,1]/size[1])*np.pi*2/size[1]', '0.0',
|
||||||
'0.0', '0.0', 'np.cos(np.pi*2*nodes[...,2]/size[2])*np.pi*2/size[2]']),
|
'0.0', '0.0', 'np.cos(np.pi*2*nodes[...,2]/size[2])*np.pi*2/size[2]']),
|
||||||
|
|
||||||
(['np.sin(np.pi*2*nodes[...,0]/size[0])' ],
|
(['np.sin(np.pi*2*nodes[...,0]/size[0])' ],
|
||||||
['np.cos(np.pi*2*nodes[...,0]/size[0])*np.pi*2/size[0]', '0.0', '0.0' ])
|
['np.cos(np.pi*2*nodes[...,0]/size[0])*np.pi*2/size[0]', '0.0', '0.0' ]),
|
||||||
|
|
||||||
|
(['8.0' ],
|
||||||
|
['0.0', '0.0', '0.0' ])
|
||||||
]
|
]
|
||||||
|
|
||||||
# @pytest.mark.parametrize('field_def,grad_def',
|
|
||||||
# [(['0.0', 'np.cos(np.pi*2*nodes[...,1]/size[1])', '0.0' ],
|
|
||||||
# ['0.0', '0.0', '0.0',
|
|
||||||
# '0.0', '-np.pi*2/size[1]*np.sin(np.pi*2*nodes[...,1]/size[1])', '0.0',
|
|
||||||
# '0.0', '0.0', '0.0' ])
|
|
||||||
# ])
|
|
||||||
@pytest.mark.parametrize('field_def,grad_def',grad_test_data)
|
@pytest.mark.parametrize('field_def,grad_def',grad_test_data)
|
||||||
|
|
||||||
def test_grad(self,field_def,grad_def):
|
def test_grad(self,field_def,grad_def):
|
||||||
# size = np.random.random(3)+1.0
|
size = np.random.random(3)+1.0
|
||||||
# grid = np.random.randint(8,32,(3))
|
grid = np.random.randint(8,32,(3))
|
||||||
size = np.array([1.0,1.0,1.0])
|
|
||||||
grid = np.array([2,5,3])
|
|
||||||
|
|
||||||
nodes = grid_filters.cell_coord0(grid,size)
|
nodes = grid_filters.cell_coord0(grid,size)
|
||||||
# print('y nodes are',nodes[...,1])
|
|
||||||
# print('inner bracket is',np.pi*2*nodes[...,1]/size[1])
|
|
||||||
my_locals = locals() # needed for list comprehension
|
my_locals = locals() # needed for list comprehension
|
||||||
|
|
||||||
print('field length is',len(field_def))
|
|
||||||
field = np.stack([np.broadcast_to(eval(f,globals(),my_locals),grid) for f in field_def],axis=-1)
|
field = np.stack([np.broadcast_to(eval(f,globals(),my_locals),grid) for f in field_def],axis=-1)
|
||||||
print('field initial shape is',field.shape)
|
field = field.reshape(tuple(grid) + ((3,) if len(field_def)==3 else (1,)))
|
||||||
field = field.reshape(tuple(grid) + ((3,3) if len(field_def)==9 else (3,)))
|
|
||||||
print('field is',field.shape)
|
|
||||||
grad = np.stack([np.broadcast_to(eval(c,globals(),my_locals),grid) for c in grad_def], axis=-1)
|
grad = np.stack([np.broadcast_to(eval(c,globals(),my_locals),grid) for c in grad_def], axis=-1)
|
||||||
grad = grad.reshape(tuple(grid) + ((3,3) if len(grad_def)==9 else (3,)))
|
grad = grad.reshape(tuple(grid) + ((3,3) if len(grad_def)==9 else (3,)))
|
||||||
print('gradient is',grad.shape)
|
|
||||||
print('code gradient is',grid_filters.gradient(size,field))
|
|
||||||
assert np.allclose(grad,grid_filters.gradient(size,field))
|
assert np.allclose(grad,grid_filters.gradient(size,field))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue