testing broadcasting

This commit is contained in:
Martin Diehl 2020-05-22 15:12:37 +02:00
parent 06e4327c0b
commit b33de48528
2 changed files with 15 additions and 5 deletions

View File

@ -160,10 +160,10 @@ class Rotation:
if self.shape == ():
q = np.broadcast_to(self.quaternion,shape+(4,))
else:
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape+(1,)),
np.broadcast_to(self.quaternion[...,1:2],shape+(1,)),
np.broadcast_to(self.quaternion[...,2:3],shape+(1,)),
np.broadcast_to(self.quaternion[...,3:4],shape+(1,))])
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape),
np.broadcast_to(self.quaternion[...,1:2],shape),
np.broadcast_to(self.quaternion[...,2:3],shape),
np.broadcast_to(self.quaternion[...,3:4],shape)]).reshape(shape+(4,))
return self.__class__(q)
@ -537,7 +537,7 @@ class Rotation:
)
# reduce Euler angles to definition range
eu[np.abs(eu)<1.e-6] = 0.0
eu = np.where(eu<0, (eu+2.0*np.pi)%np.array([2.0*np.pi,np.pi,2.0*np.pi]),eu)
eu = np.where(eu<0, (eu+2.0*np.pi)%np.array([2.0*np.pi,np.pi,2.0*np.pi]),eu) # needed?
return eu
@staticmethod

View File

@ -866,6 +866,16 @@ class TestRotation:
with pytest.raises(ValueError):
function(invalid_shape)
@pytest.mark.parametrize('shape',[None,(3,),(4,2)])
def test_broadcast(self,shape):
rot = Rotation.from_random(shape)
new_shape = tuple(np.random.randint(8,32,(3))) if shape is None else \
rot.shape + (np.random.randint(8,32),)
rot_broadcast = rot.broadcast_to(tuple(new_shape))
for i in range(rot_broadcast.shape[-1]):
assert (rot_broadcast.quaternion[...,i,:], rot.quaternion)
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
(Rotation.from_quaternion, np.array([1,1,1,0])),
(Rotation.from_Eulers, np.array([1,4,0])),