standard batch multiplication can be done with @

This commit is contained in:
Daniel Otto de Mentock 2024-02-08 10:54:25 +01:00 committed by achalhp
parent dffc567ba0
commit 76912da8bf
2 changed files with 2 additions and 2 deletions

View File

@ -337,7 +337,7 @@ def _polar_decomposition(T: _np.ndarray,
if isinstance(requested, str): requested = [requested]
u, _, vh = _np.linalg.svd(T)
R = _np.einsum('...ij,...jk',u,vh)
R = u @ vh
output = []
if 'R' in requested:

View File

@ -158,7 +158,7 @@ class TestMechanics:
@pytest.mark.parametrize('side',[('left','V'),('right','U')])
def test_polar_decomposition(self,side):
F = np.random.rand(self.n,3,3)
F = np.einsum('...ij,...jk',F,F) # positive determinant
F = F @ F # positive determinant
F_vec = np.reshape(F,(self.n//10,10,3,3))
p = mechanics._polar_decomposition(F_vec,side[1])
for p_,F_ in zip(np.reshape(p,F.shape),F):