testing broadcasting
This commit is contained in:
parent
06e4327c0b
commit
b33de48528
|
@ -160,10 +160,10 @@ class Rotation:
|
||||||
if self.shape == ():
|
if self.shape == ():
|
||||||
q = np.broadcast_to(self.quaternion,shape+(4,))
|
q = np.broadcast_to(self.quaternion,shape+(4,))
|
||||||
else:
|
else:
|
||||||
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape+(1,)),
|
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape),
|
||||||
np.broadcast_to(self.quaternion[...,1:2],shape+(1,)),
|
np.broadcast_to(self.quaternion[...,1:2],shape),
|
||||||
np.broadcast_to(self.quaternion[...,2:3],shape+(1,)),
|
np.broadcast_to(self.quaternion[...,2:3],shape),
|
||||||
np.broadcast_to(self.quaternion[...,3:4],shape+(1,))])
|
np.broadcast_to(self.quaternion[...,3:4],shape)]).reshape(shape+(4,))
|
||||||
return self.__class__(q)
|
return self.__class__(q)
|
||||||
|
|
||||||
|
|
||||||
|
@ -537,7 +537,7 @@ class Rotation:
|
||||||
)
|
)
|
||||||
# reduce Euler angles to definition range
|
# reduce Euler angles to definition range
|
||||||
eu[np.abs(eu)<1.e-6] = 0.0
|
eu[np.abs(eu)<1.e-6] = 0.0
|
||||||
eu = np.where(eu<0, (eu+2.0*np.pi)%np.array([2.0*np.pi,np.pi,2.0*np.pi]),eu)
|
eu = np.where(eu<0, (eu+2.0*np.pi)%np.array([2.0*np.pi,np.pi,2.0*np.pi]),eu) # needed?
|
||||||
return eu
|
return eu
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -866,6 +866,16 @@ class TestRotation:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
function(invalid_shape)
|
function(invalid_shape)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('shape',[None,(3,),(4,2)])
|
||||||
|
def test_broadcast(self,shape):
|
||||||
|
rot = Rotation.from_random(shape)
|
||||||
|
new_shape = tuple(np.random.randint(8,32,(3))) if shape is None else \
|
||||||
|
rot.shape + (np.random.randint(8,32),)
|
||||||
|
rot_broadcast = rot.broadcast_to(tuple(new_shape))
|
||||||
|
for i in range(rot_broadcast.shape[-1]):
|
||||||
|
assert (rot_broadcast.quaternion[...,i,:], rot.quaternion)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
|
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
|
||||||
(Rotation.from_quaternion, np.array([1,1,1,0])),
|
(Rotation.from_quaternion, np.array([1,1,1,0])),
|
||||||
(Rotation.from_Eulers, np.array([1,4,0])),
|
(Rotation.from_Eulers, np.array([1,4,0])),
|
||||||
|
|
Loading…
Reference in New Issue