added generic types to remaining non-overwritten rotation functions (exception __mul__)

This commit is contained in:
Daniel Otto de Mentock 2022-02-04 09:57:42 +01:00
parent d1f9e98e3c
commit 9dad54304c
2 changed files with 15 additions and 16 deletions

View File

@ -165,7 +165,6 @@ class Orientation(Rotation,Crystal):
Orientation to check for equality. Orientation to check for equality.
""" """
eq = self.__eq__(other) eq = self.__eq__(other)
if not isinstance(eq, bool): if not isinstance(eq, bool):
return eq return eq
@ -501,8 +500,8 @@ class Orientation(Rotation,Crystal):
return np.ones_like(rho[...,0],dtype=bool) return np.ones_like(rho[...,0],dtype=bool)
def disorientation(self, def disorientation(self,
other, other: "Orientation",
return_operators = False): return_operators: bool = False) -> object:
""" """
Calculate disorientation between myself and given other orientation. Calculate disorientation between myself and given other orientation.
@ -575,9 +574,9 @@ class Orientation(Rotation,Crystal):
r = np.where(np.any(forward[...,np.newaxis],axis=(0,1),keepdims=True), r = np.where(np.any(forward[...,np.newaxis],axis=(0,1),keepdims=True),
r_.quaternion, r_.quaternion,
_r.quaternion) _r.quaternion)
loc: Tuple[float] = np.where(ok) loc = np.where(ok)
sort: np.ndarray = 0 if len(loc) == 2 else np.lexsort(loc[:1:-1]) sort = 0 if len(loc) == 2 else np.lexsort(loc[:1:-1])
quat: np.ndarray = r[ok][sort].reshape(blend+(4,)) quat = r[ok][sort].reshape(blend+(4,))
return ( return (
(self.copy(rotation=quat), (self.copy(rotation=quat),
(np.vstack(loc[:2]).T)[sort].reshape(blend+(2,))) (np.vstack(loc[:2]).T)[sort].reshape(blend+(2,)))

View File

@ -259,7 +259,7 @@ class Rotation:
return self**exp return self**exp
def __mul__(self, other: 'Rotation') -> 'Rotation': def __mul__(self: MyType, other: MyType) -> MyType:
""" """
Compose with other. Compose with other.
@ -281,12 +281,12 @@ class Rotation:
p_o = other.quaternion[...,1:] p_o = other.quaternion[...,1:]
q = (q_m*q_o - np.einsum('...i,...i',p_m,p_o).reshape(self.shape+(1,))) q = (q_m*q_o - np.einsum('...i,...i',p_m,p_o).reshape(self.shape+(1,)))
p = q_m*p_o + q_o*p_m + _P * np.cross(p_m,p_o) p = q_m*p_o + q_o*p_m + _P * np.cross(p_m,p_o)
return Rotation(np.block([q,p]))._standardize() return Rotation(np.block([q,p]))._standardize() #type: ignore
else: else:
raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"') raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"')
def __imul__(self, def __imul__(self: MyType,
other: 'Rotation') -> 'Rotation': other: MyType) -> MyType:
""" """
Compose with other (in-place). Compose with other (in-place).
@ -299,8 +299,8 @@ class Rotation:
return self*other return self*other
def __truediv__(self: 'Rotation', def __truediv__(self: MyType,
other: 'Rotation') -> 'Rotation': other: MyType) -> MyType:
""" """
Compose with inverse of other. Compose with inverse of other.
@ -320,8 +320,8 @@ class Rotation:
else: else:
raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"') raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"')
def __itruediv__(self: 'Rotation', def __itruediv__(self: MyType,
other: 'Rotation') -> 'Rotation': other: MyType) -> MyType:
""" """
Compose with inverse of other (in-place). Compose with inverse of other (in-place).
@ -492,8 +492,8 @@ class Rotation:
accept_homomorph = True) accept_homomorph = True)
def misorientation(self, def misorientation(self: MyType,
other: 'Rotation') -> 'Rotation': other: MyType) -> MyType:
""" """
Calculate misorientation to other Rotation. Calculate misorientation to other Rotation.