From d8b4b7e0f596986f1d3aaeb72d757037a41af009 Mon Sep 17 00:00:00 2001 From: Martin Diehl Date: Sun, 3 Jan 2021 21:49:01 +0100 Subject: [PATCH] != and == work componentwise --- python/damask/_orientation.py | 19 +++++++++++++++---- python/damask/_rotation.py | 27 ++++++++++++++++++++++----- python/tests/test_Orientation.py | 9 ++++++--- python/tests/test_Rotation.py | 18 +++++++++++++----- 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/python/damask/_orientation.py b/python/damask/_orientation.py index 4bd8a1e96..d5be5a751 100644 --- a/python/damask/_orientation.py +++ b/python/damask/_orientation.py @@ -225,10 +225,21 @@ class Orientation(Rotation): Orientation to check for equality. """ - return super().__eq__(other) \ - and self.family == other.family \ - and self.lattice == other.lattice \ - and self.parameters == other.parameters + matching_type = all([hasattr(other,attr) and getattr(self,attr) == getattr(other,attr) + for attr in ['family','lattice','parameters']]) + return np.logical_and(super().__eq__(other),matching_type) + + def __ne__(self,other): + """ + Not equal to other. + + Parameters + ---------- + other : Orientation + Orientation to check for equality. + + """ + return np.logical_not(self==other) def __matmul__(self,other): diff --git a/python/damask/_rotation.py b/python/damask/_rotation.py index 9fb83af7b..50b7a3678 100644 --- a/python/damask/_rotation.py +++ b/python/damask/_rotation.py @@ -66,7 +66,7 @@ class Rotation: def __repr__(self): """Represent rotation as unit quaternion, rotation matrix, and Bunge-Euler angles.""" - if self == Rotation(): + if self.shape == () and self == Rotation(): return 'Rotation()' else: return f'Quaternions {self.shape}:\n'+str(self.quaternion) \ @@ -105,10 +105,27 @@ class Rotation: Rotation to check for equality. """ - ambiguous = np.isclose(self.quaternion[...,0],0) - return np.prod(self.shape,dtype=int) == np.prod(other.shape,dtype=int) \ - and ( np.allclose(self.quaternion,other.quaternion) \ - or np.allclose(self.quaternion[ambiguous],-1*other.quaternion[ambiguous])) + s = self.quaternion + o = other.quaternion + if self.shape == () == other.shape: + return np.allclose(s,o) or (np.isclose(s[0],0.0) and np.allclose(s,-1.0*o)) + else: + return np.all(np.isclose(s,o),-1) + np.all(np.isclose(s,-1.0*o),-1) * np.isclose(s[...,0],0.0) + + def __ne__(self,other): + """ + Not equal to other. + + Equality is determined taking limited floating point precision into + account. See numpy.allclose for details. + + Parameters + ---------- + other : Rotation + Rotation to check for equality. + + """ + return np.logical_not(self==other) @property diff --git a/python/tests/test_Orientation.py b/python/tests/test_Orientation.py index 5ab0361a8..436b73c04 100644 --- a/python/tests/test_Orientation.py +++ b/python/tests/test_Orientation.py @@ -25,13 +25,16 @@ class TestOrientation: @pytest.mark.parametrize('shape',[None,5,(4,6)]) def test_equal(self,lattice,shape): R = Rotation.from_random(shape) - assert Orientation(R,lattice) == Orientation(R,lattice) + assert Orientation(R,lattice) == Orientation(R,lattice) if shape is None else \ + (Orientation(R,lattice) == Orientation(R,lattice)).all() + @pytest.mark.parametrize('lattice',Orientation.crystal_families) @pytest.mark.parametrize('shape',[None,5,(4,6)]) def test_unequal(self,lattice,shape): R = Rotation.from_random(shape) - assert not(Orientation(R,lattice) != Orientation(R,lattice)) + assert not ( Orientation(R,lattice) != Orientation(R,lattice) if shape is None else \ + (Orientation(R,lattice) != Orientation(R,lattice)).any()) @pytest.mark.parametrize('a,b',[ (dict(rotation=[1,0,0,0]), @@ -403,7 +406,7 @@ class TestOrientation: def test_relationship_vectorize(self,set_of_quaternions,lattice,model): r = Orientation(rotation=set_of_quaternions[:200].reshape((50,4,4)),lattice=lattice).related(model) for i in range(200): - assert r.reshape((-1,200))[:,i] == Orientation(set_of_quaternions[i],lattice).related(model) + assert (r.reshape((-1,200))[:,i] == Orientation(set_of_quaternions[i],lattice).related(model)).all() @pytest.mark.parametrize('model',['Bain','KS','GT','GT_prime','NW','Pitsch']) @pytest.mark.parametrize('lattice',['cF','cI']) diff --git a/python/tests/test_Rotation.py b/python/tests/test_Rotation.py index 5aed0bea2..014efda99 100644 --- a/python/tests/test_Rotation.py +++ b/python/tests/test_Rotation.py @@ -783,14 +783,22 @@ class TestRotation: else: assert r.shape == shape - def test_equal(self): - assert Rotation.from_random(rng_seed=1) == Rotation.from_random(rng_seed=1) + @pytest.mark.parametrize('shape',[None,5,(4,6)]) + def test_equal(self,shape): + R = Rotation.from_random(shape,rng_seed=1) + assert R == R if shape is None else (R == R).all() + + @pytest.mark.parametrize('shape',[None,5,(4,6)]) + def test_unequal(self,shape): + R = Rotation.from_random(shape,rng_seed=1) + assert not (R != R if shape is None else (R != R).any()) + def test_equal_ambiguous(self): qu = np.random.rand(10,4) qu[:,0] = 0. qu/=np.linalg.norm(qu,axis=1,keepdims=True) - assert Rotation(qu) == Rotation(-qu) + assert (Rotation(qu) == Rotation(-qu)).all() def test_inversion(self): r = Rotation.from_random() @@ -807,7 +815,7 @@ class TestRotation: p = Rotation.from_random(shape=shape) s = r.append(p) print(f'append 2x {shape} --> {s.shape}') - assert s[0,...] == r[0,...] and s[-1,...] == p[-1,...] + assert np.logical_and(s[0,...] == r[0,...], s[-1,...] == p[-1,...]).all() @pytest.mark.parametrize('quat,standardized',[ ([-1,0,0,0],[1,0,0,0]), @@ -829,7 +837,7 @@ class TestRotation: @pytest.mark.parametrize('order',['C','F']) def test_flatten_reshape(self,shape,order): r = Rotation.from_random(shape=shape) - assert r == r.flatten(order).reshape(shape,order) + assert (r == r.flatten(order).reshape(shape,order)).all() @pytest.mark.parametrize('function',[Rotation.from_quaternion, Rotation.from_Euler_angles,