diff --git a/python/damask/_rotation.py b/python/damask/_rotation.py index 387ee84cb..d639cb0e0 100644 --- a/python/damask/_rotation.py +++ b/python/damask/_rotation.py @@ -160,10 +160,10 @@ class Rotation: if self.shape == (): q = np.broadcast_to(self.quaternion,shape+(4,)) else: - q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape+(1,)), - np.broadcast_to(self.quaternion[...,1:2],shape+(1,)), - np.broadcast_to(self.quaternion[...,2:3],shape+(1,)), - np.broadcast_to(self.quaternion[...,3:4],shape+(1,))]) + q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape), + np.broadcast_to(self.quaternion[...,1:2],shape), + np.broadcast_to(self.quaternion[...,2:3],shape), + np.broadcast_to(self.quaternion[...,3:4],shape)]).reshape(shape+(4,)) return self.__class__(q) @@ -537,7 +537,7 @@ class Rotation: ) # reduce Euler angles to definition range eu[np.abs(eu)<1.e-6] = 0.0 - eu = np.where(eu<0, (eu+2.0*np.pi)%np.array([2.0*np.pi,np.pi,2.0*np.pi]),eu) + eu = np.where(eu<0, (eu+2.0*np.pi)%np.array([2.0*np.pi,np.pi,2.0*np.pi]),eu) # needed? return eu @staticmethod diff --git a/python/tests/test_Rotation.py b/python/tests/test_Rotation.py index a9768afb8..5acbca6f1 100644 --- a/python/tests/test_Rotation.py +++ b/python/tests/test_Rotation.py @@ -866,6 +866,16 @@ class TestRotation: with pytest.raises(ValueError): function(invalid_shape) + @pytest.mark.parametrize('shape',[None,(3,),(4,2)]) + def test_broadcast(self,shape): + rot = Rotation.from_random(shape) + new_shape = tuple(np.random.randint(8,32,(3))) if shape is None else \ + rot.shape + (np.random.randint(8,32),) + rot_broadcast = rot.broadcast_to(tuple(new_shape)) + for i in range(rot_broadcast.shape[-1]): + assert (rot_broadcast.quaternion[...,i,:], rot.quaternion) + + @pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])), (Rotation.from_quaternion, np.array([1,1,1,0])), (Rotation.from_Eulers, np.array([1,4,0])),