!= and == work componentwise
This commit is contained in:
parent
b705be9683
commit
d8b4b7e0f5
|
@ -225,10 +225,21 @@ class Orientation(Rotation):
|
||||||
Orientation to check for equality.
|
Orientation to check for equality.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return super().__eq__(other) \
|
matching_type = all([hasattr(other,attr) and getattr(self,attr) == getattr(other,attr)
|
||||||
and self.family == other.family \
|
for attr in ['family','lattice','parameters']])
|
||||||
and self.lattice == other.lattice \
|
return np.logical_and(super().__eq__(other),matching_type)
|
||||||
and self.parameters == other.parameters
|
|
||||||
|
def __ne__(self,other):
|
||||||
|
"""
|
||||||
|
Not equal to other.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
other : Orientation
|
||||||
|
Orientation to check for equality.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return np.logical_not(self==other)
|
||||||
|
|
||||||
|
|
||||||
def __matmul__(self,other):
|
def __matmul__(self,other):
|
||||||
|
|
|
@ -66,7 +66,7 @@ class Rotation:
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Represent rotation as unit quaternion, rotation matrix, and Bunge-Euler angles."""
|
"""Represent rotation as unit quaternion, rotation matrix, and Bunge-Euler angles."""
|
||||||
if self == Rotation():
|
if self.shape == () and self == Rotation():
|
||||||
return 'Rotation()'
|
return 'Rotation()'
|
||||||
else:
|
else:
|
||||||
return f'Quaternions {self.shape}:\n'+str(self.quaternion) \
|
return f'Quaternions {self.shape}:\n'+str(self.quaternion) \
|
||||||
|
@ -105,10 +105,27 @@ class Rotation:
|
||||||
Rotation to check for equality.
|
Rotation to check for equality.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ambiguous = np.isclose(self.quaternion[...,0],0)
|
s = self.quaternion
|
||||||
return np.prod(self.shape,dtype=int) == np.prod(other.shape,dtype=int) \
|
o = other.quaternion
|
||||||
and ( np.allclose(self.quaternion,other.quaternion) \
|
if self.shape == () == other.shape:
|
||||||
or np.allclose(self.quaternion[ambiguous],-1*other.quaternion[ambiguous]))
|
return np.allclose(s,o) or (np.isclose(s[0],0.0) and np.allclose(s,-1.0*o))
|
||||||
|
else:
|
||||||
|
return np.all(np.isclose(s,o),-1) + np.all(np.isclose(s,-1.0*o),-1) * np.isclose(s[...,0],0.0)
|
||||||
|
|
||||||
|
def __ne__(self,other):
|
||||||
|
"""
|
||||||
|
Not equal to other.
|
||||||
|
|
||||||
|
Equality is determined taking limited floating point precision into
|
||||||
|
account. See numpy.allclose for details.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
other : Rotation
|
||||||
|
Rotation to check for equality.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return np.logical_not(self==other)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -25,13 +25,16 @@ class TestOrientation:
|
||||||
@pytest.mark.parametrize('shape',[None,5,(4,6)])
|
@pytest.mark.parametrize('shape',[None,5,(4,6)])
|
||||||
def test_equal(self,lattice,shape):
|
def test_equal(self,lattice,shape):
|
||||||
R = Rotation.from_random(shape)
|
R = Rotation.from_random(shape)
|
||||||
assert Orientation(R,lattice) == Orientation(R,lattice)
|
assert Orientation(R,lattice) == Orientation(R,lattice) if shape is None else \
|
||||||
|
(Orientation(R,lattice) == Orientation(R,lattice)).all()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('lattice',Orientation.crystal_families)
|
@pytest.mark.parametrize('lattice',Orientation.crystal_families)
|
||||||
@pytest.mark.parametrize('shape',[None,5,(4,6)])
|
@pytest.mark.parametrize('shape',[None,5,(4,6)])
|
||||||
def test_unequal(self,lattice,shape):
|
def test_unequal(self,lattice,shape):
|
||||||
R = Rotation.from_random(shape)
|
R = Rotation.from_random(shape)
|
||||||
assert not(Orientation(R,lattice) != Orientation(R,lattice))
|
assert not ( Orientation(R,lattice) != Orientation(R,lattice) if shape is None else \
|
||||||
|
(Orientation(R,lattice) != Orientation(R,lattice)).any())
|
||||||
|
|
||||||
@pytest.mark.parametrize('a,b',[
|
@pytest.mark.parametrize('a,b',[
|
||||||
(dict(rotation=[1,0,0,0]),
|
(dict(rotation=[1,0,0,0]),
|
||||||
|
@ -403,7 +406,7 @@ class TestOrientation:
|
||||||
def test_relationship_vectorize(self,set_of_quaternions,lattice,model):
|
def test_relationship_vectorize(self,set_of_quaternions,lattice,model):
|
||||||
r = Orientation(rotation=set_of_quaternions[:200].reshape((50,4,4)),lattice=lattice).related(model)
|
r = Orientation(rotation=set_of_quaternions[:200].reshape((50,4,4)),lattice=lattice).related(model)
|
||||||
for i in range(200):
|
for i in range(200):
|
||||||
assert r.reshape((-1,200))[:,i] == Orientation(set_of_quaternions[i],lattice).related(model)
|
assert (r.reshape((-1,200))[:,i] == Orientation(set_of_quaternions[i],lattice).related(model)).all()
|
||||||
|
|
||||||
@pytest.mark.parametrize('model',['Bain','KS','GT','GT_prime','NW','Pitsch'])
|
@pytest.mark.parametrize('model',['Bain','KS','GT','GT_prime','NW','Pitsch'])
|
||||||
@pytest.mark.parametrize('lattice',['cF','cI'])
|
@pytest.mark.parametrize('lattice',['cF','cI'])
|
||||||
|
|
|
@ -783,14 +783,22 @@ class TestRotation:
|
||||||
else:
|
else:
|
||||||
assert r.shape == shape
|
assert r.shape == shape
|
||||||
|
|
||||||
def test_equal(self):
|
@pytest.mark.parametrize('shape',[None,5,(4,6)])
|
||||||
assert Rotation.from_random(rng_seed=1) == Rotation.from_random(rng_seed=1)
|
def test_equal(self,shape):
|
||||||
|
R = Rotation.from_random(shape,rng_seed=1)
|
||||||
|
assert R == R if shape is None else (R == R).all()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('shape',[None,5,(4,6)])
|
||||||
|
def test_unequal(self,shape):
|
||||||
|
R = Rotation.from_random(shape,rng_seed=1)
|
||||||
|
assert not (R != R if shape is None else (R != R).any())
|
||||||
|
|
||||||
|
|
||||||
def test_equal_ambiguous(self):
|
def test_equal_ambiguous(self):
|
||||||
qu = np.random.rand(10,4)
|
qu = np.random.rand(10,4)
|
||||||
qu[:,0] = 0.
|
qu[:,0] = 0.
|
||||||
qu/=np.linalg.norm(qu,axis=1,keepdims=True)
|
qu/=np.linalg.norm(qu,axis=1,keepdims=True)
|
||||||
assert Rotation(qu) == Rotation(-qu)
|
assert (Rotation(qu) == Rotation(-qu)).all()
|
||||||
|
|
||||||
def test_inversion(self):
|
def test_inversion(self):
|
||||||
r = Rotation.from_random()
|
r = Rotation.from_random()
|
||||||
|
@ -807,7 +815,7 @@ class TestRotation:
|
||||||
p = Rotation.from_random(shape=shape)
|
p = Rotation.from_random(shape=shape)
|
||||||
s = r.append(p)
|
s = r.append(p)
|
||||||
print(f'append 2x {shape} --> {s.shape}')
|
print(f'append 2x {shape} --> {s.shape}')
|
||||||
assert s[0,...] == r[0,...] and s[-1,...] == p[-1,...]
|
assert np.logical_and(s[0,...] == r[0,...], s[-1,...] == p[-1,...]).all()
|
||||||
|
|
||||||
@pytest.mark.parametrize('quat,standardized',[
|
@pytest.mark.parametrize('quat,standardized',[
|
||||||
([-1,0,0,0],[1,0,0,0]),
|
([-1,0,0,0],[1,0,0,0]),
|
||||||
|
@ -829,7 +837,7 @@ class TestRotation:
|
||||||
@pytest.mark.parametrize('order',['C','F'])
|
@pytest.mark.parametrize('order',['C','F'])
|
||||||
def test_flatten_reshape(self,shape,order):
|
def test_flatten_reshape(self,shape,order):
|
||||||
r = Rotation.from_random(shape=shape)
|
r = Rotation.from_random(shape=shape)
|
||||||
assert r == r.flatten(order).reshape(shape,order)
|
assert (r == r.flatten(order).reshape(shape,order)).all()
|
||||||
|
|
||||||
@pytest.mark.parametrize('function',[Rotation.from_quaternion,
|
@pytest.mark.parametrize('function',[Rotation.from_quaternion,
|
||||||
Rotation.from_Euler_angles,
|
Rotation.from_Euler_angles,
|
||||||
|
|
Loading…
Reference in New Issue