Added generic type to rotation functions not overwritten by orientation
This commit is contained in:
parent
71bc92fed0
commit
3df411469b
|
@ -6,7 +6,7 @@ from . import tensor
|
|||
from . import util
|
||||
from . import grid_filters
|
||||
|
||||
from typing import Union, Sequence, Tuple, Literal, List
|
||||
from typing import Union, Sequence, Tuple, Literal, List, TypeVar
|
||||
from ._typehints import FloatSequence, IntSequence
|
||||
|
||||
_P = -1
|
||||
|
@ -16,6 +16,8 @@ _sc = np.pi**(1./6.)/6.**(1./6.)
|
|||
_beta = np.pi**(5./6.)/6.**(1./6.)/2.
|
||||
_R1 = (3.*np.pi/4.)**(1./3.)
|
||||
|
||||
MyType = TypeVar('MyType', bound='Rotation')
|
||||
|
||||
class Rotation:
|
||||
u"""
|
||||
Rotation with functionality for conversion between different representations.
|
||||
|
@ -92,8 +94,8 @@ class Rotation:
|
|||
+ str(self.quaternion)
|
||||
|
||||
|
||||
def __copy__(self,
|
||||
rotation: Union[FloatSequence, 'Rotation'] = None) -> 'Rotation':
|
||||
def __copy__(self: MyType,
|
||||
rotation: Union[FloatSequence, 'Rotation'] = None) -> MyType:
|
||||
"""Create deep copy."""
|
||||
dup = copy.deepcopy(self)
|
||||
if rotation is not None:
|
||||
|
@ -220,14 +222,15 @@ class Rotation:
|
|||
return 0 if self.shape == () else self.shape[0]
|
||||
|
||||
|
||||
def __invert__(self) -> 'Rotation':
|
||||
def __invert__(self: MyType) -> MyType:
|
||||
"""Inverse rotation (backward rotation)."""
|
||||
dup = self.copy()
|
||||
dup.quaternion[...,1:] *= -1
|
||||
return dup
|
||||
|
||||
|
||||
def __pow__(self, exp: int) -> 'Rotation':
|
||||
def __pow__(self: MyType,
|
||||
exp: int) -> MyType:
|
||||
"""
|
||||
Perform the rotation 'exp' times.
|
||||
|
||||
|
@ -241,7 +244,8 @@ class Rotation:
|
|||
p = self.quaternion[...,1:]/np.linalg.norm(self.quaternion[...,1:],axis=-1,keepdims=True)
|
||||
return self.copy(rotation=Rotation(np.block([np.cos(exp*phi),np.sin(exp*phi)*p]))._standardize())
|
||||
|
||||
def __ipow__(self, exp: int) -> 'Rotation':
|
||||
def __ipow__(self: MyType,
|
||||
exp: int) -> MyType:
|
||||
"""
|
||||
Perform the rotation 'exp' times (in-place).
|
||||
|
||||
|
@ -280,7 +284,8 @@ class Rotation:
|
|||
else:
|
||||
raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"')
|
||||
|
||||
def __imul__(self, other: 'Rotation') -> 'Rotation':
|
||||
def __imul__(self,
|
||||
other: 'Rotation') -> 'Rotation':
|
||||
"""
|
||||
Compose with other (in-place).
|
||||
|
||||
|
@ -293,7 +298,8 @@ class Rotation:
|
|||
return self*other
|
||||
|
||||
|
||||
def __truediv__(self, other: 'Rotation') -> 'Rotation':
|
||||
def __truediv__(self: 'Rotation',
|
||||
other: 'Rotation') -> 'Rotation':
|
||||
"""
|
||||
Compose with inverse of other.
|
||||
|
||||
|
@ -313,7 +319,8 @@ class Rotation:
|
|||
else:
|
||||
raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"')
|
||||
|
||||
def __itruediv__(self, other: 'Rotation') -> 'Rotation':
|
||||
def __itruediv__(self: 'Rotation',
|
||||
other: 'Rotation') -> 'Rotation':
|
||||
"""
|
||||
Compose with inverse of other (in-place).
|
||||
|
||||
|
@ -369,14 +376,14 @@ class Rotation:
|
|||
apply = __matmul__
|
||||
|
||||
|
||||
def _standardize(self) -> 'Rotation':
|
||||
def _standardize(self: MyType) -> MyType:
|
||||
"""Standardize quaternion (ensure positive real hemisphere)."""
|
||||
self.quaternion[self.quaternion[...,0] < 0.0] *= -1
|
||||
return self
|
||||
|
||||
|
||||
def append(self,
|
||||
other: Union['Rotation', List['Rotation']]) -> 'Rotation':
|
||||
def append(self: MyType,
|
||||
other: Union[MyType, List[MyType]]) -> MyType:
|
||||
"""
|
||||
Extend array along first dimension with other array(s).
|
||||
|
||||
|
@ -389,8 +396,8 @@ class Rotation:
|
|||
[self]+other if isinstance(other,list) else [self,other]))))
|
||||
|
||||
|
||||
def flatten(self,
|
||||
order: Literal['C','F','A'] = 'C') -> 'Rotation':
|
||||
def flatten(self: MyType,
|
||||
order: Literal['C','F','A'] = 'C') -> MyType:
|
||||
"""
|
||||
Flatten array.
|
||||
|
||||
|
@ -403,9 +410,9 @@ class Rotation:
|
|||
return self.copy(rotation=self.quaternion.reshape((-1,4),order=order))
|
||||
|
||||
|
||||
def reshape(self,
|
||||
def reshape(self: MyType,
|
||||
shape: Union[int, Tuple[int, ...]],
|
||||
order: Literal['C','F','A'] = 'C') -> 'Rotation':
|
||||
order: Literal['C','F','A'] = 'C') -> MyType:
|
||||
"""
|
||||
Reshape array.
|
||||
|
||||
|
@ -419,9 +426,9 @@ class Rotation:
|
|||
return self.copy(rotation=self.quaternion.reshape(tuple(shape)+(4,),order=order))
|
||||
|
||||
|
||||
def broadcast_to(self,
|
||||
def broadcast_to(self: MyType,
|
||||
shape: Union[int, Tuple[int, ...]],
|
||||
mode: Literal['left', 'right'] = 'right') -> 'Rotation':
|
||||
mode: Literal['left', 'right'] = 'right') -> MyType:
|
||||
"""
|
||||
Broadcast array.
|
||||
|
||||
|
|
Loading…
Reference in New Issue