enable type hints for Colormap

custom constructor ensures that type of color is always an np.ndarray
This commit is contained in:
Martin Diehl 2024-02-23 00:15:23 +01:00 committed by achalhp
parent 74413859ab
commit d6b1423c20
1 changed files with 24 additions and 3 deletions

View File

@ -47,6 +47,23 @@ class Colormap(mpl.colors.ListedColormap):
""" """
def __init__(self,
colors: np.ndarray, name: str):
"""
New colormap.
Parameters
----------
colors : numpy.ndarray, shape (:,3) or (:,4)
Color specifications as RGB(A) values.
name : str
String to identify the colormap.
"""
super().__init__(colors,name)
self.colors: np.ndarray = np.asarray(colors)
def __eq__(self, def __eq__(self,
other: object) -> bool: other: object) -> bool:
""" """
@ -57,8 +74,8 @@ class Colormap(mpl.colors.ListedColormap):
""" """
if not isinstance(other, Colormap): if not isinstance(other, Colormap):
return NotImplemented return NotImplemented
return len(self.colors) == len(other.colors) \ return np.array_equal(self.colors,other.colors)
and bool(np.all(self.colors == other.colors))
def __add__(self, def __add__(self,
other: 'Colormap') -> 'Colormap': other: 'Colormap') -> 'Colormap':
@ -71,6 +88,7 @@ class Colormap(mpl.colors.ListedColormap):
return Colormap(np.vstack((self.colors,other.colors)), return Colormap(np.vstack((self.colors,other.colors)),
f'{self.name}+{other.name}') f'{self.name}+{other.name}')
def __iadd__(self, def __iadd__(self,
other: 'Colormap') -> 'Colormap': other: 'Colormap') -> 'Colormap':
""" """
@ -81,6 +99,7 @@ class Colormap(mpl.colors.ListedColormap):
""" """
return self.__add__(other) return self.__add__(other)
def __mul__(self, def __mul__(self,
factor: int) -> 'Colormap': factor: int) -> 'Colormap':
""" """
@ -101,6 +120,7 @@ class Colormap(mpl.colors.ListedColormap):
""" """
return self.__mul__(factor) return self.__mul__(factor)
def __invert__(self) -> 'Colormap': def __invert__(self) -> 'Colormap':
""" """
Return ~self. Return ~self.
@ -110,6 +130,7 @@ class Colormap(mpl.colors.ListedColormap):
""" """
return self.reversed() return self.reversed()
def __repr__(self) -> str: def __repr__(self) -> str:
""" """
Return repr(self). Return repr(self).
@ -118,7 +139,7 @@ class Colormap(mpl.colors.ListedColormap):
""" """
fig = plt.figure(self.name,figsize=(5,.5)) fig = plt.figure(self.name,figsize=(5,.5))
ax1 = fig.add_axes([0, 0, 1, 1]) ax1 = fig.add_axes((0, 0, 1, 1))
ax1.set_axis_off() ax1.set_axis_off()
ax1.imshow(np.linspace(0,1,self.N).reshape(1,-1), ax1.imshow(np.linspace(0,1,self.N).reshape(1,-1),
aspect='auto', cmap=self, interpolation='nearest') aspect='auto', cmap=self, interpolation='nearest')