add keywords to signature of Orientation functions that inherit from Rotation

This commit is contained in:
Martin Diehl 2023-02-23 18:49:08 +00:00
parent 133dae2cce
commit c40b446854
3 changed files with 106 additions and 70 deletions

View File

@ -1,6 +1,5 @@
import inspect
import copy
from typing import Optional, Union, Callable, Dict, Any, Tuple, TypeVar
from typing import Optional, Union, TypeVar
import numpy as np
@ -261,130 +260,87 @@ class Orientation(Rotation,Crystal):
Compound rotation self*other, i.e. first other then self rotation.
"""
if isinstance(other, (Orientation,Rotation)):
return self.copy(Rotation(self.quaternion)*Rotation(other.quaternion))
else:
if not isinstance(other, (Orientation,Rotation)):
raise TypeError('use "O@b", i.e. matmul, to apply Orientation "O" to object "b"')
@staticmethod
def _split_kwargs(kwargs: Dict[str, Any],
target: Callable) -> Tuple[Dict[str, Any], ...]:
"""
Separate keyword arguments in 'kwargs' targeted at 'target' from general keyword arguments of Orientation objects.
Parameters
----------
kwargs : dictionary
Contains all **kwargs.
target: method
Function to scan for kwarg signature.
Returns
-------
rot_kwargs: dictionary
Valid keyword arguments of 'target' function of Rotation class.
ori_kwargs: dictionary
Valid keyword arguments of Orientation object.
"""
kws: Tuple[Dict[str, Any], ...] = ()
for t in (target,Orientation.__init__):
kws += ({key: kwargs[key] for key in set(inspect.signature(t).parameters) & set(kwargs)},)
invalid_keys = set(kwargs)-(set(kws[0])|set(kws[1]))
if invalid_keys:
raise TypeError(f"{inspect.stack()[1][3]}() got an unexpected keyword argument '{invalid_keys.pop()}'")
return kws
return self.copy(Rotation(self.quaternion)*Rotation(other.quaternion))
@classmethod
@util.extend_docstring(Rotation.from_random,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_random, wrapped=__init__)
def from_random(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_random)
return cls(rotation=Rotation.from_random(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_quaternion,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_quaternion, wrapped=__init__)
def from_quaternion(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_quaternion)
return cls(rotation=Rotation.from_quaternion(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_Euler_angles,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_Euler_angles, wrapped=__init__)
def from_Euler_angles(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_Euler_angles)
return cls(rotation=Rotation.from_Euler_angles(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_axis_angle,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_axis_angle, wrapped=__init__)
def from_axis_angle(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_axis_angle)
return cls(rotation=Rotation.from_axis_angle(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_basis,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_basis, wrapped=__init__)
def from_basis(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_basis)
return cls(rotation=Rotation.from_basis(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_matrix,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_matrix, wrapped=__init__)
def from_matrix(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_matrix)
return cls(rotation=Rotation.from_matrix(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_Rodrigues_vector,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_Rodrigues_vector, wrapped=__init__)
def from_Rodrigues_vector(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_Rodrigues_vector)
return cls(rotation=Rotation.from_Rodrigues_vector(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_homochoric,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_homochoric, wrapped=__init__)
def from_homochoric(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_homochoric)
return cls(rotation=Rotation.from_homochoric(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_cubochoric,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_cubochoric, wrapped=__init__)
def from_cubochoric(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_cubochoric)
return cls(rotation=Rotation.from_cubochoric(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_spherical_component,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_spherical_component, wrapped=__init__)
def from_spherical_component(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_spherical_component)
return cls(rotation=Rotation.from_spherical_component(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod
@util.extend_docstring(Rotation.from_fiber_component,
extra_parameters=_parameter_doc)
@util.pass_on('rotation', Rotation.from_fiber_component, wrapped=__init__)
def from_fiber_component(cls, **kwargs) -> 'Orientation':
kwargs_rot,kwargs_ori = Orientation._split_kwargs(kwargs,Rotation.from_fiber_component)
return cls(rotation=Rotation.from_fiber_component(**kwargs_rot),**kwargs_ori)
return cls(**kwargs)
@classmethod

View File

@ -9,7 +9,8 @@ import re as _re
import signal as _signal
import fractions as _fractions
from collections import abc as _abc
from functools import reduce as _reduce, partial as _partial
from functools import reduce as _reduce, partial as _partial, wraps as _wraps
import inspect
from typing import Optional as _Optional, Callable as _Callable, Union as _Union, Iterable as _Iterable, \
Dict as _Dict, List as _List, Tuple as _Tuple, Literal as _Literal, \
Any as _Any, TextIO as _TextIO
@ -618,6 +619,48 @@ def extend_docstring(docstring: _Union[None, str, _Callable] = None,
return func
return _decorator
def pass_on(keyword: str,
target: _Callable,
wrapped: _Callable = None) -> _Callable: # type: ignore
"""
Decorator: Combine signatures of 'wrapped' and 'target' functions and pass on output of 'target' as 'keyword' argument.
Parameters
----------
keyword : str
Keyword added to **kwargs of the decorated function
passing on the result of 'target'.
target : callable
The output of this function is passed to the
decorated function as 'keyword' argument.
wrapped: callable, optional
Signature of 'wrapped' function combined with
that of 'target' yields the overall signature of decorated function.
Notes
-----
The keywords used by 'target' will be prioritized
if they overlap with those of the decorated function.
Functions 'target' and 'wrapped' are assumed to only have keyword arguments.
"""
def decorator(func):
@_wraps(func)
def wrapper(*args, **kwargs):
kw_wrapped = set(kwargs.keys()) - set(inspect.getfullargspec(target).args)
kwargs_wrapped = {kw: kwargs.pop(kw) for kw in kw_wrapped}
kwargs_wrapped[keyword] = target(**kwargs)
return func(*args, **kwargs_wrapped)
args_ = [] if wrapped is None or 'self' not in inspect.signature(wrapped).parameters \
else [inspect.signature(wrapped).parameters['self']]
for f in [target] if wrapped is None else [target,wrapped]:
for param in inspect.signature(f).parameters.values():
if param.name != keyword \
and param.name not in [p.name for p in args_]+['self','cls', 'args', 'kwargs']: args_.append(param)
wrapper.__signature__ = inspect.Signature(parameters=args_,return_annotation=inspect.signature(func).return_annotation)
return wrapper
return decorator
def DREAM3D_base_group(fname: _Union[str, _Path]) -> str:
"""

View File

@ -1,5 +1,6 @@
import sys
import random
import pydoc
import pytest
import numpy as np
@ -341,3 +342,39 @@ p2 : str, optional
"""
assert expected == util._docstringer(original_func,return_type=decorated_func)
assert expected == util._docstringer(original_func,return_type=TestClassDecorated.decorated_func_bound)
def test_passon_result(self):
def testfunction_inner(a=None,b=None):
return a+b
@util.pass_on('inner_result',testfunction_inner)
def testfunction_outer(**kwargs):
return kwargs['inner_result']+";"+kwargs['c']+kwargs['d']
assert testfunction_outer(a='1',b='2',c='3',d='4',e='5') == '12;34'
def test_passon_signature(self):
def testfunction_inner(a='1',b='2'):
return a+b
def testfunction_extra(e='5',f='6'):
return e+f
@util.pass_on('inner_result', testfunction_inner, wrapped=testfunction_extra)
def testfunction_outer(**kwargs):
return kwargs['inner_result']+";"+kwargs['c']+kwargs['d']
assert [(param.name, param.default) for param in testfunction_outer.__signature__.parameters.values()] == \
[('a', '1'), ('b', '2'), ('e', '5'), ('f', '6')]
def test_passon_help(self):
def testfunction_inner(a=None,b=None):
return a+b
def testfunction_extra(*,c=None,d=None):
return c+d
@util.pass_on('inner_result', testfunction_inner, wrapped=testfunction_extra)
def testfunction_outer(**kwargs) -> int:
return kwargs['inner_result']+kwargs['c']+kwargs['d']
assert pydoc.render_doc(testfunction_outer, renderer=pydoc.plaintext).split("\n")[-2] ==\
'testfunction_outer(a=None, b=None, *, c=None, d=None) -> int'