fixed broadcasting + corresponding test

This commit is contained in:
Martin Diehl 2020-05-25 16:37:08 +02:00
parent 784d6d09d9
commit f07eaf19d0
2 changed files with 6 additions and 6 deletions

View File

@ -161,10 +161,10 @@ class Rotation:
if self.shape == ():
q = np.broadcast_to(self.quaternion,shape+(4,))
else:
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape),
np.broadcast_to(self.quaternion[...,1:2],shape),
np.broadcast_to(self.quaternion[...,2:3],shape),
np.broadcast_to(self.quaternion[...,3:4],shape)]).reshape(shape+(4,))
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape).reshape(shape+(1,)),
np.broadcast_to(self.quaternion[...,1:2],shape).reshape(shape+(1,)),
np.broadcast_to(self.quaternion[...,2:3],shape).reshape(shape+(1,)),
np.broadcast_to(self.quaternion[...,3:4],shape).reshape(shape+(1,))])
return self.__class__(q)

View File

@ -873,7 +873,7 @@ class TestRotation:
rot.shape + (np.random.randint(8,32),)
rot_broadcast = rot.broadcast_to(tuple(new_shape))
for i in range(rot_broadcast.shape[-1]):
assert (rot_broadcast.quaternion[...,i,:], rot.quaternion)
assert np.allclose(rot_broadcast.quaternion[...,i,:], rot.quaternion)
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),