distinguish isclose/allclose and __eq__
This commit is contained in:
parent
0ef6e43e62
commit
d5806075d4
|
@ -239,7 +239,9 @@ class Orientation(Rotation):
|
||||||
"""
|
"""
|
||||||
matching_type = all([hasattr(other,attr) and getattr(self,attr) == getattr(other,attr)
|
matching_type = all([hasattr(other,attr) and getattr(self,attr) == getattr(other,attr)
|
||||||
for attr in ['family','lattice','parameters']])
|
for attr in ['family','lattice','parameters']])
|
||||||
return np.logical_and(super().__eq__(other),matching_type)
|
s = self if self.family is None else self.reduced
|
||||||
|
o = other if other.family is None else other.reduced
|
||||||
|
return np.logical_and(super(__class__,s).__eq__(o),matching_type)
|
||||||
|
|
||||||
def __ne__(self,other):
|
def __ne__(self,other):
|
||||||
"""
|
"""
|
||||||
|
@ -254,6 +256,59 @@ class Orientation(Rotation):
|
||||||
return np.logical_not(self==other)
|
return np.logical_not(self==other)
|
||||||
|
|
||||||
|
|
||||||
|
def isclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True):
|
||||||
|
"""
|
||||||
|
Report where values are approximately equal to corresponding ones of other Orientation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
other : Orientation
|
||||||
|
Orientation to compare against.
|
||||||
|
rtol : float, optional
|
||||||
|
Relative tolerance of equality.
|
||||||
|
atol : float, optional
|
||||||
|
Absolute tolerance of equality.
|
||||||
|
equal_nan : bool, optional
|
||||||
|
Consider matching NaN values as equal. Defaults to True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
mask : numpy.ndarray bool
|
||||||
|
Mask indicating where corresponding orientations are close.
|
||||||
|
|
||||||
|
"""
|
||||||
|
matching_type = all([hasattr(other,attr) and getattr(self,attr) == getattr(other,attr)
|
||||||
|
for attr in ['family','lattice','parameters']])
|
||||||
|
s = self if self.family is None else self.reduced
|
||||||
|
o = other if other.family is None else other.reduced
|
||||||
|
return np.logical_and(super(__class__,s).isclose(o),matching_type)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def allclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True):
|
||||||
|
"""
|
||||||
|
Test whether all values are approximately equal to corresponding ones of other Orientation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
other : Orientation
|
||||||
|
Orientation to compare against.
|
||||||
|
rtol : float, optional
|
||||||
|
Relative tolerance of equality.
|
||||||
|
atol : float, optional
|
||||||
|
Absolute tolerance of equality.
|
||||||
|
equal_nan : bool, optional
|
||||||
|
Consider matching NaN values as equal. Defaults to True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
answer : bool
|
||||||
|
Whether all values are close between both orientations.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return np.all(self.isclose(other,rtol,atol,equal_nan))
|
||||||
|
|
||||||
|
|
||||||
def __mul__(self,other):
|
def __mul__(self,other):
|
||||||
"""
|
"""
|
||||||
Compose this orientation with other.
|
Compose this orientation with other.
|
||||||
|
|
|
@ -103,29 +103,20 @@ class Rotation:
|
||||||
"""
|
"""
|
||||||
Equal to other.
|
Equal to other.
|
||||||
|
|
||||||
Equality is determined taking limited floating point precision into account.
|
|
||||||
See numpy.allclose for details.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
other : Rotation
|
other : Rotation
|
||||||
Rotation to check for equality.
|
Rotation to check for equality.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
s = self.quaternion
|
return np.logical_or(np.all(self.quaternion == other.quaternion,axis=-1),
|
||||||
o = other.quaternion
|
np.all(self.quaternion == -1.0*other.quaternion,axis=-1))
|
||||||
if self.shape == () == other.shape:
|
|
||||||
return np.allclose(s,o) or (np.isclose(s[0],0.0) and np.allclose(s,-1.0*o))
|
|
||||||
else:
|
|
||||||
return np.all(np.isclose(s,o),-1) + np.all(np.isclose(s,-1.0*o),-1) * np.isclose(s[...,0],0.0)
|
|
||||||
|
|
||||||
def __ne__(self,other):
|
def __ne__(self,other):
|
||||||
"""
|
"""
|
||||||
Not equal to other.
|
Not equal to other.
|
||||||
|
|
||||||
Equality is determined taking limited floating point precision into
|
|
||||||
account. See numpy.allclose for details.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
other : Rotation
|
other : Rotation
|
||||||
|
@ -135,6 +126,57 @@ class Rotation:
|
||||||
return np.logical_not(self==other)
|
return np.logical_not(self==other)
|
||||||
|
|
||||||
|
|
||||||
|
def isclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True):
|
||||||
|
"""
|
||||||
|
Report where values are approximately equal to corresponding ones of other Rotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
other : Rotation
|
||||||
|
Rotation to compare against.
|
||||||
|
rtol : float, optional
|
||||||
|
Relative tolerance of equality.
|
||||||
|
atol : float, optional
|
||||||
|
Absolute tolerance of equality.
|
||||||
|
equal_nan : bool, optional
|
||||||
|
Consider matching NaN values as equal. Defaults to True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
mask : numpy.ndarray bool
|
||||||
|
Mask indicating where corresponding rotations are close.
|
||||||
|
|
||||||
|
"""
|
||||||
|
s = self.quaternion
|
||||||
|
o = other.quaternion
|
||||||
|
return np.logical_or(np.all(np.isclose(s, o,rtol,atol,equal_nan),axis=-1),
|
||||||
|
np.all(np.isclose(s,-1.0*o,rtol,atol,equal_nan),axis=-1))
|
||||||
|
|
||||||
|
|
||||||
|
def allclose(self,other,rtol=1e-5,atol=1e-8,equal_nan=True):
|
||||||
|
"""
|
||||||
|
Test whether all values are approximately equal to corresponding ones of other Rotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
other : Rotation
|
||||||
|
Rotation to compare against.
|
||||||
|
rtol : float, optional
|
||||||
|
Relative tolerance of equality.
|
||||||
|
atol : float, optional
|
||||||
|
Absolute tolerance of equality.
|
||||||
|
equal_nan : bool, optional
|
||||||
|
Consider matching NaN values as equal. Defaults to True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
answer : bool
|
||||||
|
Whether all values are close between both rotations.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return np.all(self.isclose(other,rtol,atol,equal_nan))
|
||||||
|
|
||||||
|
|
||||||
def __array__(self):
|
def __array__(self):
|
||||||
"""Initializer for numpy."""
|
"""Initializer for numpy."""
|
||||||
return self.quaternion
|
return self.quaternion
|
||||||
|
|
|
@ -140,6 +140,5 @@ class TestConfigMaterial:
|
||||||
if update:
|
if update:
|
||||||
cur.save(ref_path/'measured.material_yaml')
|
cur.save(ref_path/'measured.material_yaml')
|
||||||
for i,m in enumerate(ref['material']):
|
for i,m in enumerate(ref['material']):
|
||||||
assert Rotation(m['constituents'][0]['O']) == \
|
assert Rotation(m['constituents'][0]['O']).isclose(Rotation(cur['material'][i]['constituents'][0]['O']))
|
||||||
Rotation(cur['material'][i]['constituents'][0]['O'])
|
|
||||||
assert cur.is_valid and cur['phase'] == ref['phase'] and cur['homogenization'] == ref['homogenization']
|
assert cur.is_valid and cur['phase'] == ref['phase'] and cur['homogenization'] == ref['homogenization']
|
||||||
|
|
|
@ -222,7 +222,7 @@ class TestOrientation:
|
||||||
blend = util.shapeblender(o.shape,p.shape)
|
blend = util.shapeblender(o.shape,p.shape)
|
||||||
for loc in np.random.randint(0,blend,(10,len(blend))):
|
for loc in np.random.randint(0,blend,(10,len(blend))):
|
||||||
assert o[tuple(loc[:len(o.shape)])].disorientation(p[tuple(loc[-len(p.shape):])]) \
|
assert o[tuple(loc[:len(o.shape)])].disorientation(p[tuple(loc[-len(p.shape):])]) \
|
||||||
== o.disorientation(p)[tuple(loc)]
|
.isclose(o.disorientation(p)[tuple(loc)])
|
||||||
|
|
||||||
@pytest.mark.parametrize('lattice',Orientation.crystal_families)
|
@pytest.mark.parametrize('lattice',Orientation.crystal_families)
|
||||||
def test_disorientation360(self,lattice):
|
def test_disorientation360(self,lattice):
|
||||||
|
|
|
@ -960,7 +960,7 @@ class TestRotation:
|
||||||
if axis_angle[3] > np.pi:
|
if axis_angle[3] > np.pi:
|
||||||
axis_angle[3] -= 2.*np.pi
|
axis_angle[3] -= 2.*np.pi
|
||||||
axis_angle *= -1
|
axis_angle *= -1
|
||||||
assert R**pwr == Rotation.from_axis_angle(axis_angle)
|
assert (R**pwr).isclose(Rotation.from_axis_angle(axis_angle))
|
||||||
|
|
||||||
def test_rotate_inverse(self):
|
def test_rotate_inverse(self):
|
||||||
R = Rotation.from_random()
|
R = Rotation.from_random()
|
||||||
|
@ -1027,7 +1027,7 @@ class TestRotation:
|
||||||
|
|
||||||
def test_invariant(self):
|
def test_invariant(self):
|
||||||
R = Rotation.from_random()
|
R = Rotation.from_random()
|
||||||
assert R/R == R*R**(-1) == Rotation()
|
assert (R/R).isclose(R*R**(-1)) and (R/R).isclose(Rotation())
|
||||||
|
|
||||||
@pytest.mark.parametrize('item',[np.ones(3),np.ones((3,3)), np.ones((3,3,3,3))])
|
@pytest.mark.parametrize('item',[np.ones(3),np.ones((3,3)), np.ones((3,3,3,3))])
|
||||||
def test_apply(self,item):
|
def test_apply(self,item):
|
||||||
|
|
Loading…
Reference in New Issue