Added generic type to rotation functions not overwritten by orientation

This commit is contained in:
Daniel Otto de Mentock 2022-02-02 12:14:00 +01:00
parent 71bc92fed0
commit 3df411469b
1 changed files with 25 additions and 18 deletions

View File

@ -6,7 +6,7 @@ from . import tensor
from . import util
from . import grid_filters
from typing import Union, Sequence, Tuple, Literal, List
from typing import Union, Sequence, Tuple, Literal, List, TypeVar
from ._typehints import FloatSequence, IntSequence
_P = -1
@ -16,6 +16,8 @@ _sc = np.pi**(1./6.)/6.**(1./6.)
_beta = np.pi**(5./6.)/6.**(1./6.)/2.
_R1 = (3.*np.pi/4.)**(1./3.)
MyType = TypeVar('MyType', bound='Rotation')
class Rotation:
u"""
Rotation with functionality for conversion between different representations.
@ -92,8 +94,8 @@ class Rotation:
+ str(self.quaternion)
def __copy__(self,
rotation: Union[FloatSequence, 'Rotation'] = None) -> 'Rotation':
def __copy__(self: MyType,
rotation: Union[FloatSequence, 'Rotation'] = None) -> MyType:
"""Create deep copy."""
dup = copy.deepcopy(self)
if rotation is not None:
@ -220,14 +222,15 @@ class Rotation:
return 0 if self.shape == () else self.shape[0]
def __invert__(self) -> 'Rotation':
def __invert__(self: MyType) -> MyType:
"""Inverse rotation (backward rotation)."""
dup = self.copy()
dup.quaternion[...,1:] *= -1
return dup
def __pow__(self, exp: int) -> 'Rotation':
def __pow__(self: MyType,
exp: int) -> MyType:
"""
Perform the rotation 'exp' times.
@ -241,7 +244,8 @@ class Rotation:
p = self.quaternion[...,1:]/np.linalg.norm(self.quaternion[...,1:],axis=-1,keepdims=True)
return self.copy(rotation=Rotation(np.block([np.cos(exp*phi),np.sin(exp*phi)*p]))._standardize())
def __ipow__(self, exp: int) -> 'Rotation':
def __ipow__(self: MyType,
exp: int) -> MyType:
"""
Perform the rotation 'exp' times (in-place).
@ -280,7 +284,8 @@ class Rotation:
else:
raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"')
def __imul__(self, other: 'Rotation') -> 'Rotation':
def __imul__(self,
other: 'Rotation') -> 'Rotation':
"""
Compose with other (in-place).
@ -293,7 +298,8 @@ class Rotation:
return self*other
def __truediv__(self, other: 'Rotation') -> 'Rotation':
def __truediv__(self: 'Rotation',
other: 'Rotation') -> 'Rotation':
"""
Compose with inverse of other.
@ -313,7 +319,8 @@ class Rotation:
else:
raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"')
def __itruediv__(self, other: 'Rotation') -> 'Rotation':
def __itruediv__(self: 'Rotation',
other: 'Rotation') -> 'Rotation':
"""
Compose with inverse of other (in-place).
@ -369,14 +376,14 @@ class Rotation:
apply = __matmul__
def _standardize(self) -> 'Rotation':
def _standardize(self: MyType) -> MyType:
"""Standardize quaternion (ensure positive real hemisphere)."""
self.quaternion[self.quaternion[...,0] < 0.0] *= -1
return self
def append(self,
other: Union['Rotation', List['Rotation']]) -> 'Rotation':
def append(self: MyType,
other: Union[MyType, List[MyType]]) -> MyType:
"""
Extend array along first dimension with other array(s).
@ -389,8 +396,8 @@ class Rotation:
[self]+other if isinstance(other,list) else [self,other]))))
def flatten(self,
order: Literal['C','F','A'] = 'C') -> 'Rotation':
def flatten(self: MyType,
order: Literal['C','F','A'] = 'C') -> MyType:
"""
Flatten array.
@ -403,9 +410,9 @@ class Rotation:
return self.copy(rotation=self.quaternion.reshape((-1,4),order=order))
def reshape(self,
def reshape(self: MyType,
shape: Union[int, Tuple[int, ...]],
order: Literal['C','F','A'] = 'C') -> 'Rotation':
order: Literal['C','F','A'] = 'C') -> MyType:
"""
Reshape array.
@ -419,9 +426,9 @@ class Rotation:
return self.copy(rotation=self.quaternion.reshape(tuple(shape)+(4,),order=order))
def broadcast_to(self,
def broadcast_to(self: MyType,
shape: Union[int, Tuple[int, ...]],
mode: Literal['left', 'right'] = 'right') -> 'Rotation':
mode: Literal['left', 'right'] = 'right') -> MyType:
"""
Broadcast array.