shape property and numpy-like broadcasting

this makes it easy to apply a single rotation to a field
This commit is contained in:
Martin Diehl 2020-05-02 15:50:46 +02:00
parent f0bb50b97d
commit ef4a4dad4a
1 changed files with 28 additions and 11 deletions

View File

@ -53,6 +53,12 @@ class Rotation:
""" """
self.quaternion = quaternion.copy() self.quaternion = quaternion.copy()
@property
def shape(self):
return self.quaternion.shape[:-1]
def __copy__(self): def __copy__(self):
"""Copy.""" """Copy."""
return self.__class__(self.quaternion) return self.__class__(self.quaternion)
@ -123,35 +129,35 @@ class Rotation:
details to be discussed details to be discussed
""" """
shape = self.quaternion.shape[:-1]
if isinstance(other, Rotation): # rotate a rotation if isinstance(other, Rotation): # rotate a rotation
q_m = self.quaternion[...,0].reshape(shape+(1,)) q_m = self.quaternion[...,0].reshape(self.shape+(1,))
p_m = self.quaternion[...,1:] p_m = self.quaternion[...,1:]
q_o = other.quaternion[...,0].reshape(shape+(1,)) q_o = other.quaternion[...,0].reshape(self.shape+(1,))
p_o = other.quaternion[...,1:] p_o = other.quaternion[...,1:]
q = (q_m*q_o - np.einsum('...i,...i',p_m,p_o).reshape(shape+(1,))) q = (q_m*q_o - np.einsum('...i,...i',p_m,p_o).reshape(self.shape+(1,)))
p = q_m*p_m + q_o*p_m + _P * np.cross(p_m,p_o) p = q_m*p_m + q_o*p_m + _P * np.cross(p_m,p_o)
return self.__class__(np.block([q,p])).standardize() return self.__class__(np.block([q,p])).standardize()
elif isinstance(other,np.ndarray): elif isinstance(other,np.ndarray):
if shape + (3,) == other.shape: if self.shape + (3,) == other.shape:
q_m = self.quaternion[...,0] q_m = self.quaternion[...,0]
p_m = self.quaternion[...,1:] p_m = self.quaternion[...,1:]
A = q_m**2.0 - np.einsum('...i,...i',p_m,p_m) A = q_m**2.0 - np.einsum('...i,...i',p_m,p_m)
B = 2.0 * np.einsum('...i,...i',p_m,p_m) B = 2.0 * np.einsum('...i,...i',p_m,p_m)
C = 2.0 * _P * q_m C = 2.0 * _P * q_m
return np.block([(A * other[...,i]).reshape(shape+(1,)) + return np.block([(A * other[...,i]).reshape(self.shape+(1,)) +
(B * p_m[...,i]).reshape(shape+(1,)) + (B * p_m[...,i]).reshape(self.shape+(1,)) +
(C * ( p_m[...,(i+1)%3]*other[...,(i+2)%3]\ (C * ( p_m[...,(i+1)%3]*other[...,(i+2)%3]\
- p_m[...,(i+2)%3]*other[...,(i+1)%3])).reshape(shape+(1,)) - p_m[...,(i+2)%3]*other[...,(i+1)%3])).reshape(self.shape+(1,))
for i in [0,1,2]]) for i in [0,1,2]])
if shape + (3,3) == other.shape: if self.shape + (3,3) == other.shape:
R = self.asMatrix() R = self.asMatrix()
return np.einsum('...im,...jn,...mn',R,R,other) return np.einsum('...im,...jn,...mn',R,R,other)
if shape + (3,3,3,3) == other.shape: if self.shape + (3,3,3,3) == other.shape:
R = self.asMatrix() R = self.asMatrix()
return np.einsum('...im,...jn,...ko,...lp,...mnop',R,R,R,R,other) return np.einsum('...im,...jn,...ko,...lp,...mnop',R,R,R,R,other)
else:
raise ValueError
def inverse(self): def inverse(self):
"""In-place inverse rotation/backward rotation.""" """In-place inverse rotation/backward rotation."""
@ -186,6 +192,17 @@ class Rotation:
return other*self.inversed() return other*self.inversed()
def broadcast_to(self,shape):
if self.shape == ():
q = np.broadcast_to(self.quaternion,shape+(4,))
else:
q = np.block([np.broadcast_to(self.quaternion[...,0:1],shape+(1,)),
np.broadcast_to(self.quaternion[...,1:2],shape+(1,)),
np.broadcast_to(self.quaternion[...,2:3],shape+(1,)),
np.broadcast_to(self.quaternion[...,3:4],shape+(1,))])
return self.__class__(q)
def average(self,other): def average(self,other):
""" """
Calculate the average rotation. Calculate the average rotation.