more flexible shapeblending
This commit is contained in:
parent
0604e510ec
commit
b3f98ab877
|
@ -512,7 +512,8 @@ def shapeshifter(fro: _Tuple[int, ...],
|
||||||
return tuple(final_shape[::-1] if mode == 'left' else final_shape)
|
return tuple(final_shape[::-1] if mode == 'left' else final_shape)
|
||||||
|
|
||||||
def shapeblender(a: _Tuple[int, ...],
|
def shapeblender(a: _Tuple[int, ...],
|
||||||
b: _Tuple[int, ...]) -> _Tuple[int, ...]:
|
b: _Tuple[int, ...],
|
||||||
|
keep_ones: bool = True) -> _Tuple[int, ...]:
|
||||||
"""
|
"""
|
||||||
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
|
Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.
|
||||||
|
|
||||||
|
@ -522,6 +523,9 @@ def shapeblender(a: _Tuple[int, ...],
|
||||||
Shape of first array.
|
Shape of first array.
|
||||||
b : tuple
|
b : tuple
|
||||||
Shape of second array.
|
Shape of second array.
|
||||||
|
keep_ones : bool, optional
|
||||||
|
Treat innermost '1's as literal value instead of dimensional placeholder.
|
||||||
|
Defaults to True.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
|
@ -531,13 +535,30 @@ def shapeblender(a: _Tuple[int, ...],
|
||||||
(1,2,3)
|
(1,2,3)
|
||||||
>>> shapeblender((1,),(2,2,1))
|
>>> shapeblender((1,),(2,2,1))
|
||||||
(1,2,2,1)
|
(1,2,2,1)
|
||||||
|
>>> shapeblender((1,),(2,2,1),False)
|
||||||
|
(2,2,1)
|
||||||
>>> shapeblender((3,2),(3,2))
|
>>> shapeblender((3,2),(3,2))
|
||||||
(3,2)
|
(3,2)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
i = min(len(a),len(b))
|
def is_broadcastable(a,b):
|
||||||
while i > 0 and a[-i:] != b[:i]: i -= 1
|
try:
|
||||||
return a + b[i:]
|
_np.broadcast_shapes(a,b)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
a_,_b = a,b
|
||||||
|
if keep_ones:
|
||||||
|
i = min(len(a_),len(_b))
|
||||||
|
while i > 0 and a_[-i:] != _b[:i]: i -= 1
|
||||||
|
return a_ + _b[i:]
|
||||||
|
else:
|
||||||
|
a_ += max(0,len(_b)-len(a_))*(1,)
|
||||||
|
while not is_broadcastable(a_,_b):
|
||||||
|
a_ = a_ + ((1,) if len(a_)<=len(_b) else ())
|
||||||
|
_b = ((1,) if len(_b)<len(a_) else ()) + _b
|
||||||
|
return _np.broadcast_shapes(a_,_b)
|
||||||
|
|
||||||
|
|
||||||
def _docstringer(docstring: _Union[str, _Callable],
|
def _docstringer(docstring: _Union[str, _Callable],
|
||||||
|
|
|
@ -128,18 +128,22 @@ class TestUtil:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
util.shapeshifter(fro,to,mode)
|
util.shapeshifter(fro,to,mode)
|
||||||
|
|
||||||
@pytest.mark.parametrize('a,b,answer',
|
@pytest.mark.parametrize('a,b,ones,answer',
|
||||||
[
|
[
|
||||||
((),(1,),(1,)),
|
((),(1,),True,(1,)),
|
||||||
((1,),(),(1,)),
|
((1,),(),False,(1,)),
|
||||||
((1,),(7,),(1,7)),
|
((1,1),(7,),False,(1,7)),
|
||||||
((2,),(2,2),(2,2)),
|
((1,),(7,),False,(7,)),
|
||||||
((1,2),(2,2),(1,2,2)),
|
((1,),(7,),True,(1,7)),
|
||||||
((1,2,3),(2,3,4),(1,2,3,4)),
|
((2,),(2,2),False,(2,2)),
|
||||||
((1,2,3),(1,2,3),(1,2,3)),
|
((1,2),(2,2),False,(2,2)),
|
||||||
|
((1,1,2),(2,2),False,(1,2,2)),
|
||||||
|
((1,1,2),(2,2),True,(1,1,2,2)),
|
||||||
|
((1,2,3),(2,3,4),False,(1,2,3,4)),
|
||||||
|
((1,2,3),(1,2,3),False,(1,2,3)),
|
||||||
])
|
])
|
||||||
def test_shapeblender(self,a,b,answer):
|
def test_shapeblender(self,a,b,ones,answer):
|
||||||
assert util.shapeblender(a,b) == answer
|
assert util.shapeblender(a,b,ones) == answer
|
||||||
|
|
||||||
@pytest.mark.parametrize('style',[util.emph,util.deemph,util.warn,util.strikeout])
|
@pytest.mark.parametrize('style',[util.emph,util.deemph,util.warn,util.strikeout])
|
||||||
def test_decorate(self,style):
|
def test_decorate(self,style):
|
||||||
|
|
Loading…
Reference in New Issue