diff --git a/python/damask/_rotation.py b/python/damask/_rotation.py index f442561cf..cb5b2fbd1 100644 --- a/python/damask/_rotation.py +++ b/python/damask/_rotation.py @@ -53,6 +53,12 @@ class Rotation: """ self.quaternion = quaternion.copy() + + @property + def shape(self): + return self.quaternion.shape[:-1] + + def __copy__(self): """Copy.""" return self.__class__(self.quaternion) @@ -123,35 +129,35 @@ class Rotation: details to be discussed """ - shape = self.quaternion.shape[:-1] - if isinstance(other, Rotation): # rotate a rotation - q_m = self.quaternion[...,0].reshape(shape+(1,)) + q_m = self.quaternion[...,0].reshape(self.shape+(1,)) p_m = self.quaternion[...,1:] - q_o = other.quaternion[...,0].reshape(shape+(1,)) + q_o = other.quaternion[...,0].reshape(self.shape+(1,)) p_o = other.quaternion[...,1:] - q = (q_m*q_o - np.einsum('...i,...i',p_m,p_o).reshape(shape+(1,))) + q = (q_m*q_o - np.einsum('...i,...i',p_m,p_o).reshape(self.shape+(1,))) p = q_m*p_m + q_o*p_m + _P * np.cross(p_m,p_o) return self.__class__(np.block([q,p])).standardize() elif isinstance(other,np.ndarray): - if shape + (3,) == other.shape: + if self.shape + (3,) == other.shape: q_m = self.quaternion[...,0] p_m = self.quaternion[...,1:] A = q_m**2.0 - np.einsum('...i,...i',p_m,p_m) B = 2.0 * np.einsum('...i,...i',p_m,p_m) C = 2.0 * _P * q_m - return np.block([(A * other[...,i]).reshape(shape+(1,)) + - (B * p_m[...,i]).reshape(shape+(1,)) + + return np.block([(A * other[...,i]).reshape(self.shape+(1,)) + + (B * p_m[...,i]).reshape(self.shape+(1,)) + (C * ( p_m[...,(i+1)%3]*other[...,(i+2)%3]\ - - p_m[...,(i+2)%3]*other[...,(i+1)%3])).reshape(shape+(1,)) + - p_m[...,(i+2)%3]*other[...,(i+1)%3])).reshape(self.shape+(1,)) for i in [0,1,2]]) - if shape + (3,3) == other.shape: + if self.shape + (3,3) == other.shape: R = self.asMatrix() return np.einsum('...im,...jn,...mn',R,R,other) - if shape + (3,3,3,3) == other.shape: + if self.shape + (3,3,3,3) == other.shape: R = self.asMatrix() return np.einsum('...im,...jn,...ko,...lp,...mnop',R,R,R,R,other) + else: + raise ValueError def inverse(self): """In-place inverse rotation/backward rotation.""" @@ -186,6 +192,17 @@ class Rotation: return other*self.inversed() + def broadcast_to(self,shape): + if self.shape == (): + q = np.broadcast_to(self.quaternion,shape+(4,)) + else: + q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape+(1,)), + np.broadcast_to(self.quaternion[...,1:2],shape+(1,)), + np.broadcast_to(self.quaternion[...,2:3],shape+(1,)), + np.broadcast_to(self.quaternion[...,3:4],shape+(1,))]) + return self.__class__(q) + + def average(self,other): """ Calculate the average rotation.