!= and == work componentwise

This commit is contained in:
Martin Diehl 2021-01-03 21:49:01 +01:00
parent b705be9683
commit d8b4b7e0f5
4 changed files with 56 additions and 17 deletions

View File

@ -225,10 +225,21 @@ class Orientation(Rotation):
Orientation to check for equality.
"""
return super().__eq__(other) \
and self.family == other.family \
and self.lattice == other.lattice \
and self.parameters == other.parameters
matching_type = all([hasattr(other,attr) and getattr(self,attr) == getattr(other,attr)
for attr in ['family','lattice','parameters']])
return np.logical_and(super().__eq__(other),matching_type)
def __ne__(self,other):
"""
Not equal to other.
Parameters
----------
other : Orientation
Orientation to check for equality.
"""
return np.logical_not(self==other)
def __matmul__(self,other):

View File

@ -66,7 +66,7 @@ class Rotation:
def __repr__(self):
"""Represent rotation as unit quaternion, rotation matrix, and Bunge-Euler angles."""
if self == Rotation():
if self.shape == () and self == Rotation():
return 'Rotation()'
else:
return f'Quaternions {self.shape}:\n'+str(self.quaternion) \
@ -105,10 +105,27 @@ class Rotation:
Rotation to check for equality.
"""
ambiguous = np.isclose(self.quaternion[...,0],0)
return np.prod(self.shape,dtype=int) == np.prod(other.shape,dtype=int) \
and ( np.allclose(self.quaternion,other.quaternion) \
or np.allclose(self.quaternion[ambiguous],-1*other.quaternion[ambiguous]))
s = self.quaternion
o = other.quaternion
if self.shape == () == other.shape:
return np.allclose(s,o) or (np.isclose(s[0],0.0) and np.allclose(s,-1.0*o))
else:
return np.all(np.isclose(s,o),-1) + np.all(np.isclose(s,-1.0*o),-1) * np.isclose(s[...,0],0.0)
def __ne__(self,other):
"""
Not equal to other.
Equality is determined taking limited floating point precision into
account. See numpy.allclose for details.
Parameters
----------
other : Rotation
Rotation to check for equality.
"""
return np.logical_not(self==other)
@property

View File

@ -25,13 +25,16 @@ class TestOrientation:
@pytest.mark.parametrize('shape',[None,5,(4,6)])
def test_equal(self,lattice,shape):
R = Rotation.from_random(shape)
assert Orientation(R,lattice) == Orientation(R,lattice)
assert Orientation(R,lattice) == Orientation(R,lattice) if shape is None else \
(Orientation(R,lattice) == Orientation(R,lattice)).all()
@pytest.mark.parametrize('lattice',Orientation.crystal_families)
@pytest.mark.parametrize('shape',[None,5,(4,6)])
def test_unequal(self,lattice,shape):
R = Rotation.from_random(shape)
assert not(Orientation(R,lattice) != Orientation(R,lattice))
assert not ( Orientation(R,lattice) != Orientation(R,lattice) if shape is None else \
(Orientation(R,lattice) != Orientation(R,lattice)).any())
@pytest.mark.parametrize('a,b',[
(dict(rotation=[1,0,0,0]),
@ -403,7 +406,7 @@ class TestOrientation:
def test_relationship_vectorize(self,set_of_quaternions,lattice,model):
r = Orientation(rotation=set_of_quaternions[:200].reshape((50,4,4)),lattice=lattice).related(model)
for i in range(200):
assert r.reshape((-1,200))[:,i] == Orientation(set_of_quaternions[i],lattice).related(model)
assert (r.reshape((-1,200))[:,i] == Orientation(set_of_quaternions[i],lattice).related(model)).all()
@pytest.mark.parametrize('model',['Bain','KS','GT','GT_prime','NW','Pitsch'])
@pytest.mark.parametrize('lattice',['cF','cI'])

View File

@ -783,14 +783,22 @@ class TestRotation:
else:
assert r.shape == shape
def test_equal(self):
assert Rotation.from_random(rng_seed=1) == Rotation.from_random(rng_seed=1)
@pytest.mark.parametrize('shape',[None,5,(4,6)])
def test_equal(self,shape):
R = Rotation.from_random(shape,rng_seed=1)
assert R == R if shape is None else (R == R).all()
@pytest.mark.parametrize('shape',[None,5,(4,6)])
def test_unequal(self,shape):
R = Rotation.from_random(shape,rng_seed=1)
assert not (R != R if shape is None else (R != R).any())
def test_equal_ambiguous(self):
qu = np.random.rand(10,4)
qu[:,0] = 0.
qu/=np.linalg.norm(qu,axis=1,keepdims=True)
assert Rotation(qu) == Rotation(-qu)
assert (Rotation(qu) == Rotation(-qu)).all()
def test_inversion(self):
r = Rotation.from_random()
@ -807,7 +815,7 @@ class TestRotation:
p = Rotation.from_random(shape=shape)
s = r.append(p)
print(f'append 2x {shape} --> {s.shape}')
assert s[0,...] == r[0,...] and s[-1,...] == p[-1,...]
assert np.logical_and(s[0,...] == r[0,...], s[-1,...] == p[-1,...]).all()
@pytest.mark.parametrize('quat,standardized',[
([-1,0,0,0],[1,0,0,0]),
@ -829,7 +837,7 @@ class TestRotation:
@pytest.mark.parametrize('order',['C','F'])
def test_flatten_reshape(self,shape,order):
r = Rotation.from_random(shape=shape)
assert r == r.flatten(order).reshape(shape,order)
assert (r == r.flatten(order).reshape(shape,order)).all()
@pytest.mark.parametrize('function',[Rotation.from_quaternion,
Rotation.from_Euler_angles,