polishing

classes should return 'MyType' for inheritance without hassle
This commit is contained in:
Martin Diehl 2022-02-13 01:24:02 +01:00
parent bdc951c39b
commit 2907facfd3
4 changed files with 47 additions and 52 deletions

View File

@ -1,6 +1,6 @@
import inspect import inspect
import copy import copy
from typing import Union, Callable, List, Dict, Any, Tuple from typing import Union, Callable, List, Dict, Any, Tuple, TypeVar
import numpy as np import numpy as np
@ -11,7 +11,6 @@ from . import util
from . import tensor from . import tensor
_parameter_doc = \ _parameter_doc = \
""" """
family : {'triclinic', 'monoclinic', 'orthorhombic', 'tetragonal', 'hexagonal', 'cubic'}, optional. family : {'triclinic', 'monoclinic', 'orthorhombic', 'tetragonal', 'hexagonal', 'cubic'}, optional.
@ -36,6 +35,7 @@ _parameter_doc = \
""" """
MyType = TypeVar('MyType', bound='Orientation')
class Orientation(Rotation,Crystal): class Orientation(Rotation,Crystal):
""" """
@ -124,8 +124,8 @@ class Orientation(Rotation,Crystal):
return '\n'.join([Crystal.__repr__(self), return '\n'.join([Crystal.__repr__(self),
Rotation.__repr__(self)]) Rotation.__repr__(self)])
def __copy__(self, def __copy__(self: MyType,
rotation: Union[FloatSequence, Rotation] = None) -> 'Orientation': rotation: Union[FloatSequence, Rotation] = None) -> MyType:
"""Create deep copy.""" """Create deep copy."""
dup = copy.deepcopy(self) dup = copy.deepcopy(self)
if rotation is not None: if rotation is not None:
@ -189,7 +189,7 @@ class Orientation(Rotation,Crystal):
Returns Returns
------- -------
mask : numpy.ndarray of bool mask : numpy.ndarray of bool, shape (self.shape)
Mask indicating where corresponding orientations are close. Mask indicating where corresponding orientations are close.
""" """
@ -230,8 +230,8 @@ class Orientation(Rotation,Crystal):
return bool(np.all(self.isclose(other,rtol,atol,equal_nan))) return bool(np.all(self.isclose(other,rtol,atol,equal_nan)))
def __mul__(self, def __mul__(self: MyType,
other: Union[Rotation, 'Orientation']) -> 'Orientation': other: Union[Rotation, 'Orientation']) -> MyType:
""" """
Compose this orientation with other. Compose this orientation with other.
@ -246,8 +246,8 @@ class Orientation(Rotation,Crystal):
Compound rotation self*other, i.e. first other then self rotation. Compound rotation self*other, i.e. first other then self rotation.
""" """
if isinstance(other,Orientation) or isinstance(other,Rotation): if isinstance(other, (Orientation,Rotation)):
return self.copy(rotation=Rotation.__mul__(self,Rotation(other.quaternion))) return self.copy(Rotation(self.quaternion)*Rotation(other.quaternion))
else: else:
raise TypeError('use "O@b", i.e. matmul, to apply Orientation "O" to object "b"') raise TypeError('use "O@b", i.e. matmul, to apply Orientation "O" to object "b"')
@ -382,11 +382,11 @@ class Orientation(Rotation,Crystal):
x = o.to_frame(uvw=uvw) x = o.to_frame(uvw=uvw)
z = o.to_frame(hkl=hkl) z = o.to_frame(hkl=hkl)
om = np.stack([x,np.cross(z,x),z],axis=-2) om = np.stack([x,np.cross(z,x),z],axis=-2)
return o.copy(rotation=Rotation.from_matrix(tensor.transpose(om/np.linalg.norm(om,axis=-1,keepdims=True)))) return o.copy(Rotation.from_matrix(tensor.transpose(om/np.linalg.norm(om,axis=-1,keepdims=True))))
@property @property
def equivalent(self) -> 'Orientation': def equivalent(self: MyType) -> MyType:
""" """
Orientations that are symmetrically equivalent. Orientations that are symmetrically equivalent.
@ -396,11 +396,11 @@ class Orientation(Rotation,Crystal):
""" """
sym_ops = self.symmetry_operations sym_ops = self.symmetry_operations
o = sym_ops.broadcast_to(sym_ops.shape+self.shape,mode='right') o = sym_ops.broadcast_to(sym_ops.shape+self.shape,mode='right')
return self.copy(rotation=o*Rotation(self.quaternion).broadcast_to(o.shape,mode='left')) return self.copy(o*Rotation(self.quaternion).broadcast_to(o.shape,mode='left'))
@property @property
def reduced(self) -> 'Orientation': def reduced(self: MyType) -> MyType:
"""Select symmetrically equivalent orientation that falls into fundamental zone according to symmetry.""" """Select symmetrically equivalent orientation that falls into fundamental zone according to symmetry."""
eq = self.equivalent eq = self.equivalent
ok = eq.in_FZ ok = eq.in_FZ
@ -616,11 +616,8 @@ class Orientation(Rotation,Crystal):
np.argmin(m,axis=0)[np.newaxis,...,np.newaxis], np.argmin(m,axis=0)[np.newaxis,...,np.newaxis],
axis=0), axis=0),
axis=0)) axis=0))
return ( return ((self.copy(Rotation(r).average(weights)),self.copy(Rotation(r))) if return_cloud else
(self.copy(rotation=Rotation(r).average(weights)), self.copy(Rotation(r).average(weights))
self.copy(rotation=Rotation(r)))
if return_cloud else
self.copy(rotation=Rotation(r).average(weights))
) )
@ -930,7 +927,7 @@ class Orientation(Rotation,Crystal):
if active == '*': active = [len(a) for a in kinematics['direction']] if active == '*': active = [len(a) for a in kinematics['direction']]
if not active: if not active:
raise RuntimeError # ToDo raise ValueError('Schmid matrix not defined')
d = self.to_frame(uvw=np.vstack([kinematics['direction'][i][:n] for i,n in enumerate(active)])) d = self.to_frame(uvw=np.vstack([kinematics['direction'][i][:n] for i,n in enumerate(active)]))
p = self.to_frame(hkl=np.vstack([kinematics['plane'][i][:n] for i,n in enumerate(active)])) p = self.to_frame(hkl=np.vstack([kinematics['plane'][i][:n] for i,n in enumerate(active)]))
P = np.einsum('...i,...j',d/np.linalg.norm(d,axis=1,keepdims=True), P = np.einsum('...i,...j',d/np.linalg.norm(d,axis=1,keepdims=True),
@ -941,8 +938,8 @@ class Orientation(Rotation,Crystal):
@ np.broadcast_to(P.reshape(util.shapeshifter(P.shape,shape)),shape) @ np.broadcast_to(P.reshape(util.shapeshifter(P.shape,shape)),shape)
def related(self, def related(self: MyType,
model: str) -> 'Orientation': model: str) -> MyType:
""" """
Orientations derived from the given relationship. Orientations derived from the given relationship.

View File

@ -1,14 +1,13 @@
import copy import copy
from typing import Union, Sequence, Tuple, Literal, List, TypeVar
import numpy as np import numpy as np
from ._typehints import FloatSequence, IntSequence, NumpyRngSeed
from . import tensor from . import tensor
from . import util from . import util
from . import grid_filters from . import grid_filters
from typing import Union, Sequence, Tuple, Literal, List, TypeVar
from ._typehints import FloatSequence, IntSequence, NumpyRngSeed
_P = -1 _P = -1
# parameters for conversion from/to cubochoric # parameters for conversion from/to cubochoric
@ -109,7 +108,7 @@ class Rotation:
item: Union[Tuple[int], int, bool, np.bool_, np.ndarray]): item: Union[Tuple[int], int, bool, np.bool_, np.ndarray]):
"""Return slice according to item.""" """Return slice according to item."""
return self.copy() if self.shape == () else \ return self.copy() if self.shape == () else \
self.copy(rotation=self.quaternion[item+(slice(None),)] if isinstance(item,tuple) else self.quaternion[item]) self.copy(self.quaternion[item+(slice(None),)] if isinstance(item,tuple) else self.quaternion[item])
def __eq__(self, def __eq__(self,
@ -162,7 +161,7 @@ class Rotation:
Returns Returns
------- -------
mask : numpy.ndarray of bool mask : numpy.ndarray of bool, shape (self.shape)
Mask indicating where corresponding rotations are close. Mask indicating where corresponding rotations are close.
""" """
@ -233,13 +232,13 @@ class Rotation:
Parameters Parameters
---------- ----------
exp : scalar exp : float
Exponent. Exponent.
""" """
phi = np.arccos(self.quaternion[...,0:1]) phi = np.arccos(self.quaternion[...,0:1])
p = self.quaternion[...,1:]/np.linalg.norm(self.quaternion[...,1:],axis=-1,keepdims=True) p = self.quaternion[...,1:]/np.linalg.norm(self.quaternion[...,1:],axis=-1,keepdims=True)
return self.copy(rotation=Rotation(np.block([np.cos(exp*phi),np.sin(exp*phi)*p]))._standardize()) return self.copy(Rotation(np.block([np.cos(exp*phi),np.sin(exp*phi)*p]))._standardize())
def __ipow__(self: MyType, def __ipow__(self: MyType,
exp: Union[float, int]) -> MyType: exp: Union[float, int]) -> MyType:
@ -248,7 +247,7 @@ class Rotation:
Parameters Parameters
---------- ----------
exp : scalar exp : float
Exponent. Exponent.
""" """
@ -278,7 +277,7 @@ class Rotation:
p_o = other.quaternion[...,1:] p_o = other.quaternion[...,1:]
q = (q_m*q_o - np.einsum('...i,...i',p_m,p_o).reshape(self.shape+(1,))) q = (q_m*q_o - np.einsum('...i,...i',p_m,p_o).reshape(self.shape+(1,)))
p = q_m*p_o + q_o*p_m + _P * np.cross(p_m,p_o) p = q_m*p_o + q_o*p_m + _P * np.cross(p_m,p_o)
return Rotation(np.block([q,p]))._standardize() #type: ignore return self.copy(Rotation(np.block([q,p]))._standardize())
else: else:
raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"') raise TypeError('Use "R@b", i.e. matmul, to apply rotation "R" to object "b"')
@ -391,7 +390,7 @@ class Rotation:
other : (list of) damask.Rotation other : (list of) damask.Rotation
""" """
return self.copy(rotation=np.vstack(tuple(map(lambda x:x.quaternion, return self.copy(np.vstack(tuple(map(lambda x:x.quaternion,
[self]+other if isinstance(other,list) else [self,other])))) [self]+other if isinstance(other,list) else [self,other]))))
@ -415,7 +414,7 @@ class Rotation:
Rotation flattened to single dimension. Rotation flattened to single dimension.
""" """
return self.copy(rotation=self.quaternion.reshape((-1,4),order=order)) return self.copy(self.quaternion.reshape((-1,4),order=order))
def reshape(self: MyType, def reshape(self: MyType,
@ -443,7 +442,7 @@ class Rotation:
""" """
if isinstance(shape,(int,np.integer)): shape = (shape,) if isinstance(shape,(int,np.integer)): shape = (shape,)
return self.copy(rotation=self.quaternion.reshape(tuple(shape)+(4,),order=order)) return self.copy(self.quaternion.reshape(tuple(shape)+(4,),order=order))
def broadcast_to(self: MyType, def broadcast_to(self: MyType,
@ -467,18 +466,18 @@ class Rotation:
""" """
if isinstance(shape,(int,np.integer)): shape = (shape,) if isinstance(shape,(int,np.integer)): shape = (shape,)
return self.copy(rotation=np.broadcast_to(self.quaternion.reshape(util.shapeshifter(self.shape,shape,mode)+(4,)), return self.copy(np.broadcast_to(self.quaternion.reshape(util.shapeshifter(self.shape,shape,mode)+(4,)),
shape+(4,))) shape+(4,)))
def average(self, def average(self: MyType,
weights: FloatSequence = None) -> 'Rotation': weights: FloatSequence = None) -> MyType:
""" """
Average along last array dimension. Average along last array dimension.
Parameters Parameters
---------- ----------
weights : numpy.ndarray, optional weights : numpy.ndarray, shape (self.shape), optional
Relative weight of each rotation. Relative weight of each rotation.
Returns Returns
@ -501,13 +500,13 @@ class Rotation:
eig, vec = np.linalg.eig(np.sum(_M(self.quaternion) * weights_[...,np.newaxis,np.newaxis],axis=-3) \ eig, vec = np.linalg.eig(np.sum(_M(self.quaternion) * weights_[...,np.newaxis,np.newaxis],axis=-3) \
/np.sum( weights_[...,np.newaxis,np.newaxis],axis=-3)) /np.sum( weights_[...,np.newaxis,np.newaxis],axis=-3))
return Rotation.from_quaternion(np.real( return self.copy(Rotation.from_quaternion(np.real(
np.squeeze( np.squeeze(
np.take_along_axis(vec, np.take_along_axis(vec,
eig.argmax(axis=-1)[...,np.newaxis,np.newaxis], eig.argmax(axis=-1)[...,np.newaxis,np.newaxis],
axis=-1), axis=-1),
axis=-1)), axis=-1)),
accept_homomorph = True) accept_homomorph = True))
def misorientation(self: MyType, def misorientation(self: MyType,
@ -730,7 +729,7 @@ class Rotation:
Sign convention. Defaults to -1. Sign convention. Defaults to -1.
""" """
qu: np.ndarray = np.array(q,dtype=float) qu = np.array(q,dtype=float)
if qu.shape[:-2:-1] != (4,): if qu.shape[:-2:-1] != (4,):
raise ValueError('Invalid shape.') raise ValueError('Invalid shape.')
if abs(P) != 1: if abs(P) != 1:
@ -996,7 +995,7 @@ class Rotation:
Defaults to None, i.e. unpredictable entropy will be pulled from the OS. Defaults to None, i.e. unpredictable entropy will be pulled from the OS.
""" """
rng: np.random.Generator = np.random.default_rng(rng_seed) rng = np.random.default_rng(rng_seed)
r = rng.random(3 if shape is None else tuple(shape)+(3,) if hasattr(shape, '__iter__') else (shape,3)) #type: ignore r = rng.random(3 if shape is None else tuple(shape)+(3,) if hasattr(shape, '__iter__') else (shape,3)) #type: ignore
A = np.sqrt(r[...,2]) A = np.sqrt(r[...,2])
@ -1131,8 +1130,8 @@ class Rotation:
""" """
rng = np.random.default_rng(rng_seed) rng = np.random.default_rng(rng_seed)
sigma_: np.ndarray; alpha_: np.ndarray; beta_: np.ndarray sigma_,alpha_,beta_ = (np.radians(coordinate) for coordinate in (sigma,alpha,beta)) if degrees else \
sigma_,alpha_,beta_ = (np.radians(coordinate) for coordinate in (sigma,alpha,beta)) if degrees else (sigma,alpha,beta) #type: ignore map(np.array, (sigma,alpha,beta))
d_cr = np.array([np.sin(alpha_[0])*np.cos(alpha_[1]), np.sin(alpha_[0])*np.sin(alpha_[1]), np.cos(alpha_[0])]) d_cr = np.array([np.sin(alpha_[0])*np.cos(alpha_[1]), np.sin(alpha_[0])*np.sin(alpha_[1]), np.cos(alpha_[0])])
d_lab = np.array([np.sin( beta_[0])*np.cos( beta_[1]), np.sin( beta_[0])*np.sin( beta_[1]), np.cos( beta_[0])]) d_lab = np.array([np.sin( beta_[0])*np.cos( beta_[1]), np.sin( beta_[0])*np.sin( beta_[1]), np.cos( beta_[0])])

View File

@ -9,10 +9,9 @@ import numpy as np
FloatSequence = Union[np.ndarray,Sequence[float]] FloatSequence = Union[np.ndarray,Sequence[float]]
IntSequence = Union[np.ndarray,Sequence[int]] IntSequence = Union[np.ndarray,Sequence[int]]
FileHandle = Union[TextIO, str, Path] FileHandle = Union[TextIO, str, Path]
NumpyRngSeed = Union[int, IntSequence, np.random.SeedSequence, np.random.Generator]
CrystalFamily = Union[None,Literal['triclinic', 'monoclinic', 'orthorhombic', 'tetragonal', 'hexagonal', 'cubic']] CrystalFamily = Union[None,Literal['triclinic', 'monoclinic', 'orthorhombic', 'tetragonal', 'hexagonal', 'cubic']]
CrystalLattice = Union[None,Literal['aP', 'mP', 'mS', 'oP', 'oS', 'oI', 'oF', 'tP', 'tI', 'hP', 'cP', 'cI', 'cF']] CrystalLattice = Union[None,Literal['aP', 'mP', 'mS', 'oP', 'oS', 'oI', 'oF', 'tP', 'tI', 'hP', 'cP', 'cI', 'cF']]
CrystalKinematics = Literal['slip', 'twin'] CrystalKinematics = Literal['slip', 'twin']
NumpyRngSeed = Union[int, IntSequence, np.random.SeedSequence, np.random.Generator]
# BitGenerator does not exists in older numpy versions # BitGenerator does not exists in older numpy versions
#NumpyRngSeed = Union[int, IntSequence, np.random.SeedSequence, np.random.BitGenerator, np.random.Generator] #NumpyRngSeed = Union[int, IntSequence, np.random.SeedSequence, np.random.BitGenerator, np.random.Generator]