more tests

now 95% test coverage of Rotation class
This commit is contained in:
Martin Diehl 2020-05-20 19:10:16 +02:00
parent 128a96e7f6
commit 353fd3ceb6
2 changed files with 81 additions and 7 deletions

View File

@ -88,7 +88,8 @@ class Rotation:
Todo Todo
---- ----
Document details active/passive) Check rotation of 4th order tensor
""" """
if isinstance(other, Rotation): if isinstance(other, Rotation):
q_m = self.quaternion[...,0:1] q_m = self.quaternion[...,0:1]

View File

@ -510,6 +510,54 @@ def _get_pyramid_order(xyz,direction=None):
#################################################################################################### ####################################################################################################
#################################################################################################### ####################################################################################################
def mul(me, other):
"""
Multiplication.
Parameters
----------
other : numpy.ndarray or Rotation
Vector, second or fourth order tensor, or rotation object that is rotated.
Todo
----
Document details active/passive)
consider rotation of (3,3,3,3)-matrix
"""
if me.quaternion.shape != (4,):
raise NotImplementedError('Support for multiple rotations missing')
if isinstance(other, Rotation):
me_q = me.quaternion[0]
me_p = me.quaternion[1:]
other_q = other.quaternion[0]
other_p = other.quaternion[1:]
R = me.__class__(np.append(me_q*other_q - np.dot(me_p,other_p),
me_q*other_p + other_q*me_p + _P * np.cross(me_p,other_p)))
return R.standardize()
elif isinstance(other, np.ndarray):
if other.shape == (3,):
A = me.quaternion[0]**2.0 - np.dot(me.quaternion[1:],me.quaternion[1:])
B = 2.0 * np.dot(me.quaternion[1:],other)
C = 2.0 * _P*me.quaternion[0]
return A*other + B*me.quaternion[1:] + C * np.cross(me.quaternion[1:],other)
elif other.shape == (3,3,):
R = me.as_matrix()
return np.dot(R,np.dot(other,R.T))
elif other.shape == (3,3,3,3,):
R = me.as_matrix()
return np.einsum('ia,jb,kc,ld,abcd->ijkl',R,R,R,R,other)
RR = np.outer(R, R)
RRRR = np.outer(RR, RR).reshape(4 * (3,3))
axes = ((0, 2, 4, 6), (0, 1, 2, 3))
return np.tensordot(RRRR, other, axes)
else:
raise ValueError('Can only rotate vectors, 2nd order ternsors, and 4th order tensors')
else:
raise TypeError('Cannot rotate {}'.format(type(other)))
class TestRotation: class TestRotation:
@ -616,7 +664,7 @@ class TestRotation:
@pytest.mark.parametrize('forward,backward',[(Rotation._cu2qu,Rotation._qu2cu), @pytest.mark.parametrize('forward,backward',[(Rotation._cu2qu,Rotation._qu2cu),
(Rotation._cu2om,Rotation._om2cu), (Rotation._cu2om,Rotation._om2cu),
#(Rotation._cu2eu,Rotation._eu2cu), (Rotation._cu2eu,Rotation._eu2cu),
(Rotation._cu2ax,Rotation._ax2cu), (Rotation._cu2ax,Rotation._ax2cu),
(Rotation._cu2ro,Rotation._ro2cu), (Rotation._cu2ro,Rotation._ro2cu),
(Rotation._cu2ho,Rotation._ho2cu)]) (Rotation._cu2ho,Rotation._ho2cu)])
@ -626,6 +674,8 @@ class TestRotation:
m = rot.as_cubochoric() m = rot.as_cubochoric()
o = backward(forward(m)) o = backward(forward(m))
ok = np.allclose(m,o,atol=atol) ok = np.allclose(m,o,atol=atol)
if np.count_nonzero(np.isclose(np.abs(o),np.pi**(2./3.)*.5)):
ok = ok or np.allclose(m*-1.,o,atol=atol)
print(m,o,rot.as_quaternion()) print(m,o,rot.as_quaternion())
assert ok and np.max(np.abs(o)) < np.pi**(2./3.) * 0.5 + 1.e-9 assert ok and np.max(np.abs(o)) < np.pi**(2./3.) * 0.5 + 1.e-9
@ -785,21 +835,32 @@ class TestRotation:
assert ok and np.linalg.norm(o) < (3.*np.pi/4.)**(1./3.) + 1.e-9 assert ok and np.linalg.norm(o) < (3.*np.pi/4.)**(1./3.) + 1.e-9
@pytest.mark.parametrize('P',[1,-1]) @pytest.mark.parametrize('P',[1,-1])
def test_quaternion(self,default,P): @pytest.mark.parametrize('accept_homomorph',[True,False])
c = np.array([1,P*-1,P*-1,P*-1]) def test_quaternion(self,default,P,accept_homomorph):
c = np.array([1,P*-1,P*-1,P*-1]) * (-1 if accept_homomorph else 1)
for rot in default: for rot in default:
m = rot.as_cubochoric() m = rot.as_cubochoric()
o = Rotation.from_quaternion(rot.as_quaternion()*c,False,P).as_cubochoric() o = Rotation.from_quaternion(rot.as_quaternion()*c,accept_homomorph,P).as_cubochoric()
ok = np.allclose(m,o,atol=atol) ok = np.allclose(m,o,atol=atol)
if np.count_nonzero(np.isclose(np.abs(o),np.pi**(2./3.)*.5)):
ok = ok or np.allclose(m*-1.,o,atol=atol)
print(m,o,rot.as_quaternion()) print(m,o,rot.as_quaternion())
assert ok and o.max() < np.pi**(2./3.)*0.5+1.e-9 assert ok and o.max() < np.pi**(2./3.)*0.5+1.e-9
@pytest.mark.parametrize('reciprocal',[True,False])
def test_basis(self,default,reciprocal):
for rot in default:
om = rot.as_matrix() + 0.1*np.eye(3)
rot = Rotation.from_basis(om,False,reciprocal=reciprocal)
assert np.isclose(np.linalg.det(rot.as_matrix()),1.0)
@pytest.mark.parametrize('function',[Rotation.from_quaternion, @pytest.mark.parametrize('function',[Rotation.from_quaternion,
Rotation.from_Eulers, Rotation.from_Eulers,
Rotation.from_axis_angle, Rotation.from_axis_angle,
Rotation.from_matrix, Rotation.from_matrix,
Rotation.from_Rodrigues, Rotation.from_Rodrigues,
Rotation.from_homochoric]) Rotation.from_homochoric,
Rotation.from_cubochoric])
def test_invalid_shape(self,function): def test_invalid_shape(self,function):
invalid_shape = np.random.random(np.random.randint(8,32,(3))) invalid_shape = np.random.random(np.random.randint(8,32,(3)))
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -813,7 +874,8 @@ class TestRotation:
(Rotation.from_matrix, np.random.rand(3,3)), (Rotation.from_matrix, np.random.rand(3,3)),
(Rotation.from_Rodrigues, np.array([1,0,0,-1])), (Rotation.from_Rodrigues, np.array([1,0,0,-1])),
(Rotation.from_Rodrigues, np.array([1,1,0,1])), (Rotation.from_Rodrigues, np.array([1,1,0,1])),
(Rotation.from_homochoric, np.array([2,2,2])) ]) (Rotation.from_homochoric, np.array([2,2,2])),
(Rotation.from_cubochoric, np.array([1.1,0,0])) ])
def test_invalid_value(self,function,invalid): def test_invalid_value(self,function,invalid):
with pytest.raises(ValueError): with pytest.raises(ValueError):
function(invalid) function(invalid)
@ -833,6 +895,17 @@ class TestRotation:
assert np.all(np.take_along_axis(np.take_along_axis(a,f,-1),b,-1) == a) assert np.all(np.take_along_axis(np.take_along_axis(a,f,-1),b,-1) == a)
@pytest.mark.parametrize('data',[np.random.rand(5,3),
np.random.rand(5,3,3),
np.random.rand(5,3,3,3,3)])
def test_rotate_vectorization(self,default,data):
for rot in default:
v = rot.broadcast_to((5,)) @ data
for i in range(data.shape[0]):
print(i-data[i])
assert np.allclose(mul(rot,data[i]),v[i])
@pytest.mark.parametrize('data',[np.random.rand(3), @pytest.mark.parametrize('data',[np.random.rand(3),
np.random.rand(3,3), np.random.rand(3,3),
np.random.rand(3,3,3,3)]) np.random.rand(3,3,3,3)])