diff --git a/python/damask/util.py b/python/damask/util.py index d61ca221b..e4b7ddcb5 100644 --- a/python/damask/util.py +++ b/python/damask/util.py @@ -512,7 +512,8 @@ def shapeshifter(fro: _Tuple[int, ...], return tuple(final_shape[::-1] if mode == 'left' else final_shape) def shapeblender(a: _Tuple[int, ...], - b: _Tuple[int, ...]) -> _Tuple[int, ...]: + b: _Tuple[int, ...], + keep_ones: bool = True) -> _Tuple[int, ...]: """ Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'. @@ -522,6 +523,9 @@ def shapeblender(a: _Tuple[int, ...], Shape of first array. b : tuple Shape of second array. + keep_ones : bool, optional + Treat innermost '1's as literal value instead of dimensional placeholder. + Defaults to True. Examples -------- @@ -531,13 +535,30 @@ def shapeblender(a: _Tuple[int, ...], (1,2,3) >>> shapeblender((1,),(2,2,1)) (1,2,2,1) + >>> shapeblender((1,),(2,2,1),False) + (2,2,1) >>> shapeblender((3,2),(3,2)) (3,2) """ - i = min(len(a),len(b)) - while i > 0 and a[-i:] != b[:i]: i -= 1 - return a + b[i:] + def is_broadcastable(a,b): + try: + _np.broadcast_shapes(a,b) + return True + except ValueError: + return False + + a_,_b = a,b + if keep_ones: + i = min(len(a_),len(_b)) + while i > 0 and a_[-i:] != _b[:i]: i -= 1 + return a_ + _b[i:] + else: + a_ += max(0,len(_b)-len(a_))*(1,) + while not is_broadcastable(a_,_b): + a_ = a_ + ((1,) if len(a_)<=len(_b) else ()) + _b = ((1,) if len(_b)