fixed broadcasting + corresponding test
This commit is contained in:
parent
784d6d09d9
commit
f07eaf19d0
|
@ -161,10 +161,10 @@ class Rotation:
|
|||
if self.shape == ():
|
||||
q = np.broadcast_to(self.quaternion,shape+(4,))
|
||||
else:
|
||||
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape),
|
||||
np.broadcast_to(self.quaternion[...,1:2],shape),
|
||||
np.broadcast_to(self.quaternion[...,2:3],shape),
|
||||
np.broadcast_to(self.quaternion[...,3:4],shape)]).reshape(shape+(4,))
|
||||
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape).reshape(shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,1:2],shape).reshape(shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,2:3],shape).reshape(shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,3:4],shape).reshape(shape+(1,))])
|
||||
return self.__class__(q)
|
||||
|
||||
|
||||
|
|
|
@ -873,7 +873,7 @@ class TestRotation:
|
|||
rot.shape + (np.random.randint(8,32),)
|
||||
rot_broadcast = rot.broadcast_to(tuple(new_shape))
|
||||
for i in range(rot_broadcast.shape[-1]):
|
||||
assert (rot_broadcast.quaternion[...,i,:], rot.quaternion)
|
||||
assert np.allclose(rot_broadcast.quaternion[...,i,:], rot.quaternion)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
|
||||
|
|
Loading…
Reference in New Issue