ensure correct shape

This commit is contained in:
Martin Diehl 2020-04-21 13:22:55 +02:00
parent b1be4e7ac8
commit 97a5880d76
2 changed files with 25 additions and 0 deletions

View File

@ -267,6 +267,8 @@ class Rotation:
P = -1):
qu = np.array(quaternion,dtype=float)
if qu.shape[:-2:-1] != (4,):
raise ValueError('Invalid shape.')
if P > 0: qu[...,1:4] *= -1 # convert from P=1 to P=-1
if acceptHomomorph:
@ -284,6 +286,8 @@ class Rotation:
degrees = False):
eu = np.array(eulers,dtype=float)
if eu.shape[:-2:-1] != (3,):
raise ValueError('Invalid shape.')
eu = np.radians(eu) if degrees else eu
if np.any(eu < 0.0) or np.any(eu > 2.0*np.pi) or np.any(eu[...,1] > np.pi): # ToDo: No separate check for PHI
@ -298,6 +302,8 @@ class Rotation:
P = -1):
ax = np.array(axis_angle,dtype=float)
if ax.shape[:-2:-1] != (4,):
raise ValueError('Invalid shape.')
if P > 0: ax[...,0:3] *= -1 # convert from P=1 to P=-1
if degrees: ax[..., 3] = np.radians(ax[...,3])
@ -315,6 +321,8 @@ class Rotation:
reciprocal = False):
om = np.array(basis,dtype=float)
if om.shape[:-3:-1] != (3,3):
raise ValueError('Invalid shape.')
if reciprocal:
om = np.linalg.inv(mechanics.transpose(om)/np.pi) # transform reciprocal basis set
@ -342,6 +350,8 @@ class Rotation:
P = -1):
ro = np.array(rodrigues,dtype=float)
if ro.shape[:-2:-1] != (4,):
raise ValueError('Invalid shape.')
if P > 0: ro[...,0:3] *= -1 # convert from P=1 to P=-1
if normalise: ro[...,0:3] /= np.linalg.norm(ro[...,0:3],axis=-1)
@ -357,6 +367,8 @@ class Rotation:
P = -1):
ho = np.array(homochoric,dtype=float)
if ho.shape[:-2:-1] != (3,):
raise ValueError('Invalid shape.')
if P > 0: ho *= -1 # convert from P=1 to P=-1
@ -370,6 +382,8 @@ class Rotation:
P = -1):
cu = np.array(cubochoric,dtype=float)
if cu.shape[:-2:-1] != (3,):
raise ValueError('Invalid shape.')
if np.abs(np.max(cu))>np.pi**(2./3.) * 0.5+1e-9:
raise ValueError('Cubochoric coordinate outside of the cube: {} {} {}.'.format(*cu))

View File

@ -157,6 +157,17 @@ class TestRotation:
print(m,o,rot.asQuaternion())
assert ok and o.max() < np.pi**(2./3.)*0.5+1.e-9
@pytest.mark.parametrize('function',[Rotation.from_quaternion,
Rotation.from_Eulers,
Rotation.from_axis_angle,
Rotation.from_matrix,
Rotation.from_Rodrigues,
Rotation.from_homochoric])
def test_invalid_shape(self,function):
invalid_shape = np.random.random(np.random.randint(8,32,(3)))
with pytest.raises(ValueError):
function(invalid_shape)
@pytest.mark.parametrize('function,invalid',[(Rotation.from_quaternion, np.array([-1,0,0,0])),
(Rotation.from_quaternion, np.array([1,1,1,0])),
(Rotation.from_Eulers, np.array([1,4,0])),