simplified

This commit is contained in:
Martin Diehl 2020-04-09 12:52:12 +02:00
parent cbfde73a29
commit b025c1838e
1 changed files with 5 additions and 5 deletions

View File

@ -528,7 +528,7 @@ class Rotation:
with np.errstate(invalid='ignore',divide='ignore'):
s = np.sign(qu[...,0:1])/np.sqrt(qu[...,1:2]**2+qu[...,2:3]**2+qu[...,3:4]**2)
omega = 2.0 * np.arccos(np.clip(qu[...,0:1],-1.0,1.0))
ax = np.where(np.expand_dims(np.sum(np.abs(qu[:,1:4])**2,axis=-1) < 1.0e-6,-1),
ax = np.where(np.sum(np.abs(qu[:,1:4])**2,axis=-1,keepdims=True) < 1.0e-6,
[0.0, 0.0, 1.0, 0.0], np.block([qu[...,1:4]*s,omega]))
ax = np.where(qu[...,0:1] < 1.0e-6,
np.block([qu[...,1:4],np.ones(qu.shape[:-1]+(1,))*np.pi]),ax) # TODO: Where not needed
@ -541,12 +541,12 @@ class Rotation:
if iszero(qu[0]):
ro = np.array([qu[1], qu[2], qu[3], np.inf])
else:
s = np.linalg.norm([qu[1],qu[2],qu[3]])
s = np.linalg.norm(qu[1:4])
ro = np.array([0.0,0.0,P,0.0] if iszero(s) else \
[ qu[1]/s, qu[2]/s, qu[3]/s, np.tan(np.arccos(np.clip(qu[0],-1.0,1.0)))])
else:
with np.errstate(invalid='ignore',divide='ignore'):
s = np.expand_dims(np.linalg.norm(qu[...,1:4],axis=1),-1)
s = np.linalg.norm(qu[...,1:4],axis=-1,keepdims=True)
ro = np.where(np.abs(s) < 1.0e-12,
[0.0,0.0,P,0.0],
np.block([qu[...,1:2]/s,qu[...,2:3]/s,qu[...,3:4]/s,
@ -573,7 +573,7 @@ class Rotation:
omega = 2.0 * np.arccos(np.clip(qu[...,0:1],-1.0,1.0))
ho = np.where(np.abs(omega) < 1.0e-12,
np.zeros(3),
qu[...,1:4]/np.linalg.norm(qu[...,1:4],axis=1).reshape(qu.shape[:-1]+(1,)) \
qu[...,1:4]/np.linalg.norm(qu[...,1:4],axis=-1,keepdims=True) \
* (0.75*(omega - np.sin(omega)))**(1./3.))
return ho
@ -738,7 +738,7 @@ class Rotation:
t = np.tan(eu[...,1:2]*0.5)
sigma = 0.5*(eu[...,0:1]+eu[...,2:3])
delta = 0.5*(eu[...,0:1]-eu[...,2:3])
tau = np.linalg.norm(np.block([t,np.sin(sigma)]),axis=-1).reshape(-1,1)
tau = np.linalg.norm(np.block([t,np.sin(sigma)]),axis=-1,keepdims=True)
alpha = np.where(np.abs(np.cos(sigma))<1.e-12,np.pi,2.0*np.arctan(tau/np.cos(sigma)))
with np.errstate(invalid='ignore',divide='ignore'):
ax = np.where(np.broadcast_to(np.abs(alpha)<1.0e-12,eu.shape[:-1]+(4,)),