fixed broadcasting + corresponding test
This commit is contained in:
parent
784d6d09d9
commit
f07eaf19d0
|
@ -12,7 +12,7 @@ _R1 = (3.*np.pi/4.)**(1./3.)
|
|||
class Rotation:
|
||||
u"""
|
||||
Orientation stored with functionality for conversion to different representations.
|
||||
|
||||
|
||||
The following conventions apply:
|
||||
|
||||
- coordinate frames are right-handed.
|
||||
|
@ -161,10 +161,10 @@ class Rotation:
|
|||
if self.shape == ():
|
||||
q = np.broadcast_to(self.quaternion,shape+(4,))
|
||||
else:
|
||||
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape),
|
||||
np.broadcast_to(self.quaternion[...,1:2],shape),
|
||||
np.broadcast_to(self.quaternion[...,2:3],shape),
|
||||
np.broadcast_to(self.quaternion[...,3:4],shape)]).reshape(shape+(4,))
|
||||
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape).reshape(shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,1:2],shape).reshape(shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,2:3],shape).reshape(shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,3:4],shape).reshape(shape+(1,))])
|
||||
return self.__class__(q)
|
||||
|
||||
|
||||
|
|
|
@ -873,7 +873,7 @@ class TestRotation:
|
|||
rot.shape + (np.random.randint(8,32),)
|
||||
rot_broadcast = rot.broadcast_to(tuple(new_shape))
|
||||
for i in range(rot_broadcast.shape[-1]):
|
||||
assert (rot_broadcast.quaternion[...,i,:], rot.quaternion)
|
||||
assert np.allclose(rot_broadcast.quaternion[...,i,:], rot.quaternion)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
|
||||
|
|
Loading…
Reference in New Issue