testing broadcasting
This commit is contained in:
parent
06e4327c0b
commit
b33de48528
|
@ -160,10 +160,10 @@ class Rotation:
|
|||
if self.shape == ():
|
||||
q = np.broadcast_to(self.quaternion,shape+(4,))
|
||||
else:
|
||||
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,1:2],shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,2:3],shape+(1,)),
|
||||
np.broadcast_to(self.quaternion[...,3:4],shape+(1,))])
|
||||
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape),
|
||||
np.broadcast_to(self.quaternion[...,1:2],shape),
|
||||
np.broadcast_to(self.quaternion[...,2:3],shape),
|
||||
np.broadcast_to(self.quaternion[...,3:4],shape)]).reshape(shape+(4,))
|
||||
return self.__class__(q)
|
||||
|
||||
|
||||
|
@ -537,7 +537,7 @@ class Rotation:
|
|||
)
|
||||
# reduce Euler angles to definition range
|
||||
eu[np.abs(eu)<1.e-6] = 0.0
|
||||
eu = np.where(eu<0, (eu+2.0*np.pi)%np.array([2.0*np.pi,np.pi,2.0*np.pi]),eu)
|
||||
eu = np.where(eu<0, (eu+2.0*np.pi)%np.array([2.0*np.pi,np.pi,2.0*np.pi]),eu) # needed?
|
||||
return eu
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -866,6 +866,16 @@ class TestRotation:
|
|||
with pytest.raises(ValueError):
|
||||
function(invalid_shape)
|
||||
|
||||
@pytest.mark.parametrize('shape',[None,(3,),(4,2)])
|
||||
def test_broadcast(self,shape):
|
||||
rot = Rotation.from_random(shape)
|
||||
new_shape = tuple(np.random.randint(8,32,(3))) if shape is None else \
|
||||
rot.shape + (np.random.randint(8,32),)
|
||||
rot_broadcast = rot.broadcast_to(tuple(new_shape))
|
||||
for i in range(rot_broadcast.shape[-1]):
|
||||
assert (rot_broadcast.quaternion[...,i,:], rot.quaternion)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
|
||||
(Rotation.from_quaternion, np.array([1,1,1,0])),
|
||||
(Rotation.from_Eulers, np.array([1,4,0])),
|
||||
|
|
Loading…
Reference in New Issue