replaced typehint in shapeshifter function

This commit is contained in:
Daniel Otto de Mentock 2022-01-21 11:45:14 +01:00
parent c2fa17e903
commit 76ccd4aaaa
1 changed files with 3 additions and 3 deletions

View File

@ -8,7 +8,7 @@ import shlex
import re import re
import fractions import fractions
from functools import reduce from functools import reduce
from typing import Union, Tuple, Sequence, Callable, Dict, List, Any, Literal from typing import Union, Tuple, Sequence, Callable, Dict, List, Any, Literal, Optional
import pathlib import pathlib
import numpy as np import numpy as np
@ -398,7 +398,7 @@ def hybrid_IA(dist: np.ndarray, N: int, rng_seed: Union[int, IntSequence] = None
def shapeshifter(fro: Tuple[int, ...], def shapeshifter(fro: Tuple[int, ...],
to: Tuple[int, ...], to: Tuple[int, ...],
mode: Literal['left','right'] = 'left', mode: Literal['left','right'] = 'left',
keep_ones: bool = False) -> Tuple[int, ...]: keep_ones: bool = False) -> Tuple[Optional[int], ...]:
""" """
Return dimensions that reshape 'fro' to become broadcastable to 'to'. Return dimensions that reshape 'fro' to become broadcastable to 'to'.
@ -454,7 +454,7 @@ def shapeshifter(fro: Tuple[int, ...],
except AssertionError: except AssertionError:
raise ValueError(f'Shapes can not be shifted {fro} --> {to}') raise ValueError(f'Shapes can not be shifted {fro} --> {to}')
grp: Sequence[str] = match.groups() grp: Sequence[str] = match.groups()
fill: Tuple[int, ...] = () fill: Tuple[Optional[int], ...] = ()
for g,d in zip(grp,fro+(None,)): for g,d in zip(grp,fro+(None,)):
fill += (1,)*g.count(',')+(d,) fill += (1,)*g.count(',')+(d,)
return fill[:-1] return fill[:-1]