type hints for tensor

adjustments to Rotation/Orientation needed to enable type checking with
mypy
This commit is contained in:
Martin Diehl 2021-10-31 22:37:54 +01:00
parent 9f152aca0d
commit 509835bf0b
3 changed files with 14 additions and 9 deletions

View File

@ -125,7 +125,7 @@ class Orientation(Rotation,Crystal):
"""Create deep copy.""" """Create deep copy."""
dup = copy.deepcopy(self) dup = copy.deepcopy(self)
if rotation is not None: if rotation is not None:
dup.quaternion = Orientation(rotation,family='cubic').quaternion dup.quaternion = Rotation(rotation).quaternion
return dup return dup
copy = __copy__ copy = __copy__

View File

@ -1,3 +1,5 @@
import copy
import numpy as np import numpy as np
from . import tensor from . import tensor
@ -85,9 +87,12 @@ class Rotation:
+ str(self.quaternion) + str(self.quaternion)
def __copy__(self,**kwargs): def __copy__(self,rotation=None):
"""Create deep copy.""" """Create deep copy."""
return self.__class__(rotation=kwargs['rotation'] if 'rotation' in kwargs else self.quaternion) dup = copy.deepcopy(self)
if rotation is not None:
dup.quaternion = Rotation(rotation).quaternion
return dup
copy = __copy__ copy = __copy__

View File

@ -8,7 +8,7 @@ All routines operate on numpy.ndarrays of shape (...,3,3).
import numpy as _np import numpy as _np
def deviatoric(T): def deviatoric(T: _np.ndarray) -> _np.ndarray:
""" """
Calculate deviatoric part of a tensor. Calculate deviatoric part of a tensor.
@ -26,7 +26,7 @@ def deviatoric(T):
return T - spherical(T,tensor=True) return T - spherical(T,tensor=True)
def eigenvalues(T_sym): def eigenvalues(T_sym: _np.ndarray) -> _np.ndarray:
""" """
Eigenvalues, i.e. principal components, of a symmetric tensor. Eigenvalues, i.e. principal components, of a symmetric tensor.
@ -45,7 +45,7 @@ def eigenvalues(T_sym):
return _np.linalg.eigvalsh(symmetric(T_sym)) return _np.linalg.eigvalsh(symmetric(T_sym))
def eigenvectors(T_sym,RHS=False): def eigenvectors(T_sym: _np.ndarray, RHS: bool = False) -> _np.ndarray:
""" """
Eigenvectors of a symmetric tensor. Eigenvectors of a symmetric tensor.
@ -70,7 +70,7 @@ def eigenvectors(T_sym,RHS=False):
return v return v
def spherical(T,tensor=True): def spherical(T: _np.ndarray, tensor: bool = True) -> _np.ndarray:
""" """
Calculate spherical part of a tensor. Calculate spherical part of a tensor.
@ -92,7 +92,7 @@ def spherical(T,tensor=True):
return _np.einsum('...jk,...',_np.eye(3),sph) if tensor else sph return _np.einsum('...jk,...',_np.eye(3),sph) if tensor else sph
def symmetric(T): def symmetric(T: _np.ndarray) -> _np.ndarray:
""" """
Symmetrize tensor. Symmetrize tensor.
@ -110,7 +110,7 @@ def symmetric(T):
return (T+transpose(T))*0.5 return (T+transpose(T))*0.5
def transpose(T): def transpose(T: _np.ndarray) -> _np.ndarray:
""" """
Transpose tensor. Transpose tensor.