type hints for tensor

adjustments to Rotation/Orientation needed to enable type checking with
mypy
This commit is contained in:
Martin Diehl 2021-10-31 22:37:54 +01:00
parent 9f152aca0d
commit 509835bf0b
3 changed files with 14 additions and 9 deletions

View File

@ -125,7 +125,7 @@ class Orientation(Rotation,Crystal):
"""Create deep copy."""
dup = copy.deepcopy(self)
if rotation is not None:
dup.quaternion = Orientation(rotation,family='cubic').quaternion
dup.quaternion = Rotation(rotation).quaternion
return dup
copy = __copy__

View File

@ -1,3 +1,5 @@
import copy
import numpy as np
from . import tensor
@ -85,9 +87,12 @@ class Rotation:
+ str(self.quaternion)
def __copy__(self,**kwargs):
def __copy__(self,rotation=None):
"""Create deep copy."""
return self.__class__(rotation=kwargs['rotation'] if 'rotation' in kwargs else self.quaternion)
dup = copy.deepcopy(self)
if rotation is not None:
dup.quaternion = Rotation(rotation).quaternion
return dup
copy = __copy__

View File

@ -8,7 +8,7 @@ All routines operate on numpy.ndarrays of shape (...,3,3).
import numpy as _np
def deviatoric(T):
def deviatoric(T: _np.ndarray) -> _np.ndarray:
"""
Calculate deviatoric part of a tensor.
@ -26,7 +26,7 @@ def deviatoric(T):
return T - spherical(T,tensor=True)
def eigenvalues(T_sym):
def eigenvalues(T_sym: _np.ndarray) -> _np.ndarray:
"""
Eigenvalues, i.e. principal components, of a symmetric tensor.
@ -45,7 +45,7 @@ def eigenvalues(T_sym):
return _np.linalg.eigvalsh(symmetric(T_sym))
def eigenvectors(T_sym,RHS=False):
def eigenvectors(T_sym: _np.ndarray, RHS: bool = False) -> _np.ndarray:
"""
Eigenvectors of a symmetric tensor.
@ -70,7 +70,7 @@ def eigenvectors(T_sym,RHS=False):
return v
def spherical(T,tensor=True):
def spherical(T: _np.ndarray, tensor: bool = True) -> _np.ndarray:
"""
Calculate spherical part of a tensor.
@ -92,7 +92,7 @@ def spherical(T,tensor=True):
return _np.einsum('...jk,...',_np.eye(3),sph) if tensor else sph
def symmetric(T):
def symmetric(T: _np.ndarray) -> _np.ndarray:
"""
Symmetrize tensor.
@ -110,7 +110,7 @@ def symmetric(T):
return (T+transpose(T))*0.5
def transpose(T):
def transpose(T: _np.ndarray) -> _np.ndarray:
"""
Transpose tensor.