add keywords to signature of Orientation functions that inherit from Rotation

This commit is contained in:
Martin Diehl 2023-02-23 18:49:08 +00:00
parent 133dae2cce
commit c40b446854
3 changed files with 106 additions and 70 deletions

View File

@ -1,6 +1,5 @@
import inspect
import copy import copy
from typing import Optional, Union, Callable, Dict, Any, Tuple, TypeVar from typing import Optional, Union, TypeVar
import numpy as np import numpy as np
@ -261,130 +260,87 @@ class Orientation(Rotation,Crystal):
Compound rotation self*other, i.e. first other then self rotation. Compound rotation self*other, i.e. first other then self rotation.
""" """
if isinstance(other, (Orientation,Rotation)): if not isinstance(other, (Orientation,Rotation)):
return self.copy(Rotation(self.quaternion)*Rotation(other.quaternion))
else:
raise TypeError('use "O@b", i.e. matmul, to apply Orientation "O" to object "b"') raise TypeError('use "O@b", i.e. matmul, to apply Orientation "O" to object "b"')
return self.copy(Rotation(self.quaternion)*Rotation(other.quaternion))
@staticmethod
def _split_kwargs(kwargs: Dict[str, Any],
target: Callable) -> Tuple[Dict[str, Any], ...]:
"""
Separate keyword arguments in 'kwargs' targeted at 'target' from general keyword arguments of Orientation objects.
Parameters
----------
kwargs : dictionary
Contains all **kwargs.
target: method
Function to scan for kwarg signature.
Returns
-------
rot_kwargs: dictionary
Valid keyword arguments of 'target' function of Rotation class.
ori_kwargs: dictionary
Valid keyword arguments of Orientation object.
"""
kws: Tuple[Dict[str, Any], ...] = ()
for t in (target,Orientation.__init__):
kws += ({key: kwargs[key] for key in set(inspect.signature(t).parameters) & set(kwargs)},)
invalid_keys = set(kwargs)-(set(kws[0])|set(kws[1]))
if invalid_keys:
raise TypeError(f"{inspect.stack()[1][3]}() got an unexpected keyword argument '{invalid_keys.pop()}'")
return kws
@classmethod @classmethod
@util.extend_docstring(Rotation.from_random, @util.extend_docstring(Rotation.from_random,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_random, wrapped=__init__)
def from_random(cls, **kwargs) -> 'Orientation': def from_random(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_random) return cls(**kwargs)
return cls(rotation=Rotation.from_random(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_quaternion, @util.extend_docstring(Rotation.from_quaternion,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_quaternion, wrapped=__init__)
def from_quaternion(cls, **kwargs) -> 'Orientation': def from_quaternion(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_quaternion) return cls(**kwargs)
return cls(rotation=Rotation.from_quaternion(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_Euler_angles, @util.extend_docstring(Rotation.from_Euler_angles,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_Euler_angles, wrapped=__init__)
def from_Euler_angles(cls, **kwargs) -> 'Orientation': def from_Euler_angles(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_Euler_angles) return cls(**kwargs)
return cls(rotation=Rotation.from_Euler_angles(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_axis_angle, @util.extend_docstring(Rotation.from_axis_angle,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_axis_angle, wrapped=__init__)
def from_axis_angle(cls, **kwargs) -> 'Orientation': def from_axis_angle(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_axis_angle) return cls(**kwargs)
return cls(rotation=Rotation.from_axis_angle(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_basis, @util.extend_docstring(Rotation.from_basis,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_basis, wrapped=__init__)
def from_basis(cls, **kwargs) -> 'Orientation': def from_basis(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_basis) return cls(**kwargs)
return cls(rotation=Rotation.from_basis(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_matrix, @util.extend_docstring(Rotation.from_matrix,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_matrix, wrapped=__init__)
def from_matrix(cls, **kwargs) -> 'Orientation': def from_matrix(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_matrix) return cls(**kwargs)
return cls(rotation=Rotation.from_matrix(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_Rodrigues_vector, @util.extend_docstring(Rotation.from_Rodrigues_vector,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_Rodrigues_vector, wrapped=__init__)
def from_Rodrigues_vector(cls, **kwargs) -> 'Orientation': def from_Rodrigues_vector(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_Rodrigues_vector) return cls(**kwargs)
return cls(rotation=Rotation.from_Rodrigues_vector(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_homochoric, @util.extend_docstring(Rotation.from_homochoric,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_homochoric, wrapped=__init__)
def from_homochoric(cls, **kwargs) -> 'Orientation': def from_homochoric(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_homochoric) return cls(**kwargs)
return cls(rotation=Rotation.from_homochoric(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_cubochoric, @util.extend_docstring(Rotation.from_cubochoric,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_cubochoric, wrapped=__init__)
def from_cubochoric(cls, **kwargs) -> 'Orientation': def from_cubochoric(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_cubochoric) return cls(**kwargs)
return cls(rotation=Rotation.from_cubochoric(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_spherical_component, @util.extend_docstring(Rotation.from_spherical_component,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_spherical_component, wrapped=__init__)
def from_spherical_component(cls, **kwargs) -> 'Orientation': def from_spherical_component(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_spherical_component) return cls(**kwargs)
return cls(rotation=Rotation.from_spherical_component(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod
@util.extend_docstring(Rotation.from_fiber_component, @util.extend_docstring(Rotation.from_fiber_component,
extra_parameters=_parameter_doc) extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_fiber_component, wrapped=__init__)
def from_fiber_component(cls, **kwargs) -> 'Orientation': def from_fiber_component(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_fiber_component) return cls(**kwargs)
return cls(rotation=Rotation.from_fiber_component(**kwargs_rot),**kwargs_ori)
@classmethod @classmethod

View File

@ -9,7 +9,8 @@ import re as _re
import signal as _signal import signal as _signal
import fractions as _fractions import fractions as _fractions
from collections import abc as _abc from collections import abc as _abc
from functools import reduce as _reduce, partial as _partial from functools import reduce as _reduce, partial as _partial, wraps as _wraps
import inspect
from typing import Optional as _Optional, Callable as _Callable, Union as _Union, Iterable as _Iterable, \ from typing import Optional as _Optional, Callable as _Callable, Union as _Union, Iterable as _Iterable, \
Dict as _Dict, List as _List, Tuple as _Tuple, Literal as _Literal, \ Dict as _Dict, List as _List, Tuple as _Tuple, Literal as _Literal, \
Any as _Any, TextIO as _TextIO Any as _Any, TextIO as _TextIO
@ -618,6 +619,48 @@ def extend_docstring(docstring: _Union[None, str, _Callable] = None,
return func return func
return _decorator return _decorator
def pass_on(keyword: str,
target: _Callable,
wrapped: _Callable = None) -> _Callable: # type: ignore
"""
Decorator: Combine signatures of 'wrapped' and 'target' functions and pass on output of 'target' as 'keyword' argument.
Parameters
----------
keyword : str
Keyword added to **kwargs of the decorated function
passing on the result of 'target'.
target : callable
The output of this function is passed to the
decorated function as 'keyword' argument.
wrapped: callable, optional
Signature of 'wrapped' function combined with
that of 'target' yields the overall signature of decorated function.
Notes
-----
The keywords used by 'target' will be prioritized
if they overlap with those of the decorated function.
Functions 'target' and 'wrapped' are assumed to only have keyword arguments.
"""
def decorator(func):
@_wraps(func)
def wrapper(*args, **kwargs):
kw_wrapped = set(kwargs.keys()) - set(inspect.getfullargspec(target).args)
kwargs_wrapped = {kw: kwargs.pop(kw) for kw in kw_wrapped}
kwargs_wrapped[keyword] = target(**kwargs)
return func(*args, **kwargs_wrapped)
args_ = [] if wrapped is None or 'self' not in inspect.signature(wrapped).parameters \
else [inspect.signature(wrapped).parameters['self']]
for f in [target] if wrapped is None else [target,wrapped]:
for param in inspect.signature(f).parameters.values():
if param.name != keyword \
and param.name not in [p.name for p in args_]+['self','cls', 'args', 'kwargs']: args_.append(param)
wrapper.__signature__ = inspect.Signature(parameters=args_,return_annotation=inspect.signature(func).return_annotation)
return wrapper
return decorator
def DREAM3D_base_group(fname: _Union[str, _Path]) -> str: def DREAM3D_base_group(fname: _Union[str, _Path]) -> str:
""" """

View File

@ -1,5 +1,6 @@
import sys import sys
import random import random
import pydoc
import pytest import pytest
import numpy as np import numpy as np
@ -341,3 +342,39 @@ p2 : str, optional
""" """
assert expected == util._docstringer(original_func,return_type=decorated_func) assert expected == util._docstringer(original_func,return_type=decorated_func)
assert expected == util._docstringer(original_func,return_type=TestClassDecorated.decorated_func_bound) assert expected == util._docstringer(original_func,return_type=TestClassDecorated.decorated_func_bound)
def test_passon_result(self):
def testfunction_inner(a=None,b=None):
return a+b
@util.pass_on('inner_result',testfunction_inner)
def testfunction_outer(**kwargs):
return kwargs['inner_result']+";"+kwargs['c']+kwargs['d']
assert testfunction_outer(a='1',b='2',c='3',d='4',e='5') == '12;34'
def test_passon_signature(self):
def testfunction_inner(a='1',b='2'):
return a+b
def testfunction_extra(e='5',f='6'):
return e+f
@util.pass_on('inner_result', testfunction_inner, wrapped=testfunction_extra)
def testfunction_outer(**kwargs):
return kwargs['inner_result']+";"+kwargs['c']+kwargs['d']
assert [(param.name, param.default) for param in testfunction_outer.__signature__.parameters.values()] == \
[('a', '1'), ('b', '2'), ('e', '5'), ('f', '6')]
def test_passon_help(self):
def testfunction_inner(a=None,b=None):
return a+b
def testfunction_extra(*,c=None,d=None):
return c+d
@util.pass_on('inner_result', testfunction_inner, wrapped=testfunction_extra)
def testfunction_outer(**kwargs) -> int:
return kwargs['inner_result']+kwargs['c']+kwargs['d']
assert pydoc.render_doc(testfunction_outer, renderer=pydoc.plaintext).split("\n")[-2] ==\
'testfunction_outer(a=None, b=None, *, c=None, d=None) -> int'