more flexible shapeblending

This commit is contained in:
Philip Eisenlohr 2023-09-19 17:43:27 -04:00
parent 0604e510ec
commit b3f98ab877
2 changed files with 39 additions and 14 deletions

View File

@ -512,7 +512,8 @@ def shapeshifter(fro: _Tuple[int, ...],
return tuple(final_shape[::-1] if mode == 'left' else final_shape) return tuple(final_shape[::-1] if mode == 'left' else final_shape)
def shapeblender(a: _Tuple[int, ...], def shapeblender(a: _Tuple[int, ...],
b: _Tuple[int, ...]) -> _Tuple[int, ...]: b: _Tuple[int, ...],
keep_ones: bool = True) -> _Tuple[int, ...]:
""" """
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'. Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
@ -522,6 +523,9 @@ def shapeblender(a: _Tuple[int, ...],
Shape of first array. Shape of first array.
b : tuple b : tuple
Shape of second array. Shape of second array.
keep_ones : bool, optional
Treat innermost '1's as literal value instead of dimensional placeholder.
Defaults to True.
Examples Examples
-------- --------
@ -531,13 +535,30 @@ def shapeblender(a: _Tuple[int, ...],
(1,2,3) (1,2,3)
>>> shapeblender((1,),(2,2,1)) >>> shapeblender((1,),(2,2,1))
(1,2,2,1) (1,2,2,1)
>>> shapeblender((1,),(2,2,1),False)
(2,2,1)
>>> shapeblender((3,2),(3,2)) >>> shapeblender((3,2),(3,2))
(3,2) (3,2)
""" """
i = min(len(a),len(b)) def is_broadcastable(a,b):
while i > 0 and a[-i:] != b[:i]: i -= 1 try:
return a + b[i:] _np.broadcast_shapes(a,b)
return True
except ValueError:
return False
a_,_b = a,b
if keep_ones:
i = min(len(a_),len(_b))
while i > 0 and a_[-i:] != _b[:i]: i -= 1
return a_ + _b[i:]
else:
a_ += max(0,len(_b)-len(a_))*(1,)
while not is_broadcastable(a_,_b):
a_ = a_ + ((1,) if len(a_)<=len(_b) else ())
_b = ((1,) if len(_b)<len(a_) else ()) + _b
return _np.broadcast_shapes(a_,_b)
def _docstringer(docstring: _Union[str, _Callable], def _docstringer(docstring: _Union[str, _Callable],

View File

@ -128,18 +128,22 @@ class TestUtil:
with pytest.raises(ValueError): with pytest.raises(ValueError):
util.shapeshifter(fro,to,mode) util.shapeshifter(fro,to,mode)
@pytest.mark.parametrize('a,b,answer', @pytest.mark.parametrize('a,b,ones,answer',
[ [
((),(1,),(1,)), ((),(1,),True,(1,)),
((1,),(),(1,)), ((1,),(),False,(1,)),
((1,),(7,),(1,7)), ((1,1),(7,),False,(1,7)),
((2,),(2,2),(2,2)), ((1,),(7,),False,(7,)),
((1,2),(2,2),(1,2,2)), ((1,),(7,),True,(1,7)),
((1,2,3),(2,3,4),(1,2,3,4)), ((2,),(2,2),False,(2,2)),
((1,2,3),(1,2,3),(1,2,3)), ((1,2),(2,2),False,(2,2)),
((1,1,2),(2,2),False,(1,2,2)),
((1,1,2),(2,2),True,(1,1,2,2)),
((1,2,3),(2,3,4),False,(1,2,3,4)),
((1,2,3),(1,2,3),False,(1,2,3)),
]) ])
def test_shapeblender(self,a,b,answer): def test_shapeblender(self,a,b,ones,answer):
assert util.shapeblender(a,b) == answer assert util.shapeblender(a,b,ones) == answer
@pytest.mark.parametrize('style',[util.emph,util.deemph,util.warn,util.strikeout]) @pytest.mark.parametrize('style',[util.emph,util.deemph,util.warn,util.strikeout])
def test_decorate(self,style): def test_decorate(self,style):