more flexible shapeblending

This commit is contained in:
Philip Eisenlohr 2023-09-19 17:43:27 -04:00
parent 0604e510ec
commit b3f98ab877
2 changed files with 39 additions and 14 deletions

View File

@ -512,7 +512,8 @@ def shapeshifter(fro: _Tuple[int, ...],
return tuple(final_shape[::-1] if mode == 'left' else final_shape)
def shapeblender(a: _Tuple[int, ...],
b: _Tuple[int, ...]) -> _Tuple[int, ...]:
b: _Tuple[int, ...],
keep_ones: bool = True) -> _Tuple[int, ...]:
"""
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
@ -522,6 +523,9 @@ def shapeblender(a: _Tuple[int, ...],
Shape of first array.
b : tuple
Shape of second array.
keep_ones : bool, optional
Treat innermost '1's as literal value instead of dimensional placeholder.
Defaults to True.
Examples
--------
@ -531,13 +535,30 @@ def shapeblender(a: _Tuple[int, ...],
(1,2,3)
>>> shapeblender((1,),(2,2,1))
(1,2,2,1)
>>> shapeblender((1,),(2,2,1),False)
(2,2,1)
>>> shapeblender((3,2),(3,2))
(3,2)
"""
i = min(len(a),len(b))
while i > 0 and a[-i:] != b[:i]: i -= 1
return a + b[i:]
def is_broadcastable(a,b):
try:
_np.broadcast_shapes(a,b)
return True
except ValueError:
return False
a_,_b = a,b
if keep_ones:
i = min(len(a_),len(_b))
while i > 0 and a_[-i:] != _b[:i]: i -= 1
return a_ + _b[i:]
else:
a_ += max(0,len(_b)-len(a_))*(1,)
while not is_broadcastable(a_,_b):
a_ = a_ + ((1,) if len(a_)<=len(_b) else ())
_b = ((1,) if len(_b)<len(a_) else ()) + _b
return _np.broadcast_shapes(a_,_b)
def _docstringer(docstring: _Union[str, _Callable],

View File

@ -128,18 +128,22 @@ class TestUtil:
with pytest.raises(ValueError):
util.shapeshifter(fro,to,mode)
@pytest.mark.parametrize('a,b,answer',
@pytest.mark.parametrize('a,b,ones,answer',
[
((),(1,),(1,)),
((1,),(),(1,)),
((1,),(7,),(1,7)),
((2,),(2,2),(2,2)),
((1,2),(2,2),(1,2,2)),
((1,2,3),(2,3,4),(1,2,3,4)),
((1,2,3),(1,2,3),(1,2,3)),
((),(1,),True,(1,)),
((1,),(),False,(1,)),
((1,1),(7,),False,(1,7)),
((1,),(7,),False,(7,)),
((1,),(7,),True,(1,7)),
((2,),(2,2),False,(2,2)),
((1,2),(2,2),False,(2,2)),
((1,1,2),(2,2),False,(1,2,2)),
((1,1,2),(2,2),True,(1,1,2,2)),
((1,2,3),(2,3,4),False,(1,2,3,4)),
((1,2,3),(1,2,3),False,(1,2,3)),
])
def test_shapeblender(self,a,b,answer):
assert util.shapeblender(a,b) == answer
def test_shapeblender(self,a,b,ones,answer):
assert util.shapeblender(a,b,ones) == answer
@pytest.mark.parametrize('style',[util.emph,util.deemph,util.warn,util.strikeout])
def test_decorate(self,style):