Modified shapeshifter function
This commit is contained in:
parent
e13196b770
commit
400323a9aa
|
@ -478,33 +478,26 @@ def shapeshifter(fro: _Tuple[int, ...],
|
|||
>>> (a * np.broadcast_to(b_extended,a.shape)).shape
|
||||
(3,4,2)
|
||||
|
||||
|
||||
"""
|
||||
if len(fro) == 0 and len(to) == 0: return ()
|
||||
|
||||
beg = dict(left ='(^.*\\b)',
|
||||
right='(^.*?\\b)')
|
||||
sep = dict(left ='(.*\\b)',
|
||||
right='(.*?\\b)')
|
||||
end = dict(left ='(.*?$)',
|
||||
right='(.*$)')
|
||||
fro = (1,) if len(fro) == 0 else fro
|
||||
to = (1,) if len(to) == 0 else to
|
||||
try:
|
||||
match = _re.match(beg[mode]
|
||||
+f',{sep[mode]}'.join(map(lambda x: f'{x}'
|
||||
if x>1 or (keep_ones and len(fro)>1) else
|
||||
'\\d+',fro))
|
||||
+f',{end[mode]}',','.join(map(str,to))+',')
|
||||
assert match
|
||||
grp = match.groups()
|
||||
except AssertionError:
|
||||
raise ValueError(f'shapes cannot be shifted {fro} --> {to}')
|
||||
fill: _Any = ()
|
||||
for g,d in zip(grp,fro+(None,)):
|
||||
fill += (1,)*g.count(',')+(d,)
|
||||
return fill[:-1]
|
||||
if len(fro) == 0 and len(to) == 0: return tuple()
|
||||
_fro = [1] if len(fro) == 0 else list(fro)[::-1 if mode=='left' else 1]
|
||||
_to = [1] if len(to) == 0 else list(to) [::-1 if mode=='left' else 1]
|
||||
|
||||
final_shape: _List[int] = []
|
||||
index = 0
|
||||
for i,item in enumerate(_to):
|
||||
if item==_fro[index]:
|
||||
final_shape.append(item)
|
||||
index+=1
|
||||
else:
|
||||
final_shape.append(1)
|
||||
if _fro[index]==1 and not keep_ones:
|
||||
index+=1
|
||||
if index==len(_fro):
|
||||
final_shape = final_shape+[1]*(len(_to)-i-1)
|
||||
break
|
||||
if index!=len(_fro): raise ValueError(f'shapes cannot be shifted {fro} --> {to}')
|
||||
return tuple(final_shape[::-1] if mode=='left' else final_shape)
|
||||
|
||||
def shapeblender(a: _Tuple[int, ...],
|
||||
b: _Tuple[int, ...]) -> _Tuple[int, ...]:
|
||||
|
|
Loading…
Reference in New Issue