fixed broadcasting + corresponding test
This commit is contained in:
parent
784d6d09d9
commit
f07eaf19d0
|
@ -161,10 +161,10 @@ class Rotation:
|
||||||
if self.shape == ():
|
if self.shape == ():
|
||||||
q = np.broadcast_to(self.quaternion,shape+(4,))
|
q = np.broadcast_to(self.quaternion,shape+(4,))
|
||||||
else:
|
else:
|
||||||
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape),
|
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape).reshape(shape+(1,)),
|
||||||
np.broadcast_to(self.quaternion[...,1:2],shape),
|
np.broadcast_to(self.quaternion[...,1:2],shape).reshape(shape+(1,)),
|
||||||
np.broadcast_to(self.quaternion[...,2:3],shape),
|
np.broadcast_to(self.quaternion[...,2:3],shape).reshape(shape+(1,)),
|
||||||
np.broadcast_to(self.quaternion[...,3:4],shape)]).reshape(shape+(4,))
|
np.broadcast_to(self.quaternion[...,3:4],shape).reshape(shape+(1,))])
|
||||||
return self.__class__(q)
|
return self.__class__(q)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -873,7 +873,7 @@ class TestRotation:
|
||||||
rot.shape + (np.random.randint(8,32),)
|
rot.shape + (np.random.randint(8,32),)
|
||||||
rot_broadcast = rot.broadcast_to(tuple(new_shape))
|
rot_broadcast = rot.broadcast_to(tuple(new_shape))
|
||||||
for i in range(rot_broadcast.shape[-1]):
|
for i in range(rot_broadcast.shape[-1]):
|
||||||
assert (rot_broadcast.quaternion[...,i,:], rot.quaternion)
|
assert np.allclose(rot_broadcast.quaternion[...,i,:], rot.quaternion)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
|
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
|
||||||
|
|
Loading…
Reference in New Issue