more flexible shapeblending
This commit is contained in:
parent
0604e510ec
commit
b3f98ab877
|
@ -512,7 +512,8 @@ def shapeshifter(fro: _Tuple[int, ...],
|
|||
return tuple(final_shape[::-1] if mode == 'left' else final_shape)
|
||||
|
||||
def shapeblender(a: _Tuple[int, ...],
|
||||
b: _Tuple[int, ...]) -> _Tuple[int, ...]:
|
||||
b: _Tuple[int, ...],
|
||||
keep_ones: bool = True) -> _Tuple[int, ...]:
|
||||
"""
|
||||
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
|
||||
|
||||
|
@ -522,6 +523,9 @@ def shapeblender(a: _Tuple[int, ...],
|
|||
Shape of first array.
|
||||
b : tuple
|
||||
Shape of second array.
|
||||
keep_ones : bool, optional
|
||||
Treat innermost '1's as literal value instead of dimensional placeholder.
|
||||
Defaults to True.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
@ -531,13 +535,30 @@ def shapeblender(a: _Tuple[int, ...],
|
|||
(1,2,3)
|
||||
>>> shapeblender((1,),(2,2,1))
|
||||
(1,2,2,1)
|
||||
>>> shapeblender((1,),(2,2,1),False)
|
||||
(2,2,1)
|
||||
>>> shapeblender((3,2),(3,2))
|
||||
(3,2)
|
||||
|
||||
"""
|
||||
i = min(len(a),len(b))
|
||||
while i > 0 and a[-i:] != b[:i]: i -= 1
|
||||
return a + b[i:]
|
||||
def is_broadcastable(a,b):
|
||||
try:
|
||||
_np.broadcast_shapes(a,b)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
a_,_b = a,b
|
||||
if keep_ones:
|
||||
i = min(len(a_),len(_b))
|
||||
while i > 0 and a_[-i:] != _b[:i]: i -= 1
|
||||
return a_ + _b[i:]
|
||||
else:
|
||||
a_ += max(0,len(_b)-len(a_))*(1,)
|
||||
while not is_broadcastable(a_,_b):
|
||||
a_ = a_ + ((1,) if len(a_)<=len(_b) else ())
|
||||
_b = ((1,) if len(_b)<len(a_) else ()) + _b
|
||||
return _np.broadcast_shapes(a_,_b)
|
||||
|
||||
|
||||
def _docstringer(docstring: _Union[str, _Callable],
|
||||
|
|
|
@ -128,18 +128,22 @@ class TestUtil:
|
|||
with pytest.raises(ValueError):
|
||||
util.shapeshifter(fro,to,mode)
|
||||
|
||||
@pytest.mark.parametrize('a,b,answer',
|
||||
@pytest.mark.parametrize('a,b,ones,answer',
|
||||
[
|
||||
((),(1,),(1,)),
|
||||
((1,),(),(1,)),
|
||||
((1,),(7,),(1,7)),
|
||||
((2,),(2,2),(2,2)),
|
||||
((1,2),(2,2),(1,2,2)),
|
||||
((1,2,3),(2,3,4),(1,2,3,4)),
|
||||
((1,2,3),(1,2,3),(1,2,3)),
|
||||
((),(1,),True,(1,)),
|
||||
((1,),(),False,(1,)),
|
||||
((1,1),(7,),False,(1,7)),
|
||||
((1,),(7,),False,(7,)),
|
||||
((1,),(7,),True,(1,7)),
|
||||
((2,),(2,2),False,(2,2)),
|
||||
((1,2),(2,2),False,(2,2)),
|
||||
((1,1,2),(2,2),False,(1,2,2)),
|
||||
((1,1,2),(2,2),True,(1,1,2,2)),
|
||||
((1,2,3),(2,3,4),False,(1,2,3,4)),
|
||||
((1,2,3),(1,2,3),False,(1,2,3)),
|
||||
])
|
||||
def test_shapeblender(self,a,b,answer):
|
||||
assert util.shapeblender(a,b) == answer
|
||||
def test_shapeblender(self,a,b,ones,answer):
|
||||
assert util.shapeblender(a,b,ones) == answer
|
||||
|
||||
@pytest.mark.parametrize('style',[util.emph,util.deemph,util.warn,util.strikeout])
|
||||
def test_decorate(self,style):
|
||||
|
|
Loading…
Reference in New Issue