simplified and corrected

- Optional not needed for 'None' argument
- Use TextIO for typehints, TextIOWrapper seems to cause problems
- Test for tuple/list input for public functions
- internal functions that are always called with np.ndarray don't need
  to offer flexibility. They might work, but we don't guarantee
  anything.
This commit is contained in:
Martin Diehl 2021-11-23 17:59:56 +01:00
parent 3410a8d4cb
commit 889ab87914
2 changed files with 25 additions and 24 deletions

View File

@ -3,8 +3,7 @@ import json
import functools
import colorsys
from pathlib import Path
from typing import Sequence, Union, Optional, List, TextIO
from io import TextIOWrapper
from typing import Sequence, Union, List, TextIO
import numpy as np
@ -56,7 +55,7 @@ class Colormap(mpl.colors.ListedColormap):
"""Reverse."""
return self.reversed()
def __repr__(self) -> "Colormap":
def __repr__(self) -> str:
"""Show as matplotlib figure."""
fig = plt.figure(self.name,figsize=(5,.5))
ax1 = fig.add_axes([0, 0, 1, 1])
@ -195,8 +194,8 @@ class Colormap(mpl.colors.ListedColormap):
def shade(self,
field: np.ndarray,
bounds: Optional[Sequence[float]] = None,
gap: Optional[np.ndarray] = None) -> Image:
bounds: Sequence[float] = None,
gap: float = None) -> Image:
"""
Generate PIL image of 2D field using colormap.
@ -262,7 +261,7 @@ class Colormap(mpl.colors.ListedColormap):
return Colormap(np.array(rev.colors),rev.name[:-4] if rev.name.endswith('_r_r') else rev.name)
def _get_file_handle(self, fname: Union[TextIOWrapper, str, pathlib.Path, None], ext: str) -> TextIO:
def _get_file_handle(self, fname: Union[TextIO, str, Path, None], ext: str) -> TextIO:
"""
Provide file handle.
@ -270,7 +269,6 @@ class Colormap(mpl.colors.ListedColormap):
----------
fname : file, str, pathlib.Path, or None
Filename or filehandle, will be name of the colormap+extension if None.
ext: str
Extension of the filename.
@ -280,17 +278,15 @@ class Colormap(mpl.colors.ListedColormap):
File handle with write access.
"""
fname = pathlib.Path(self.name.replace(' ','_'))\
.with_suffix(('' if ext is None or ext.startswith('.') else '.')+ext) if fname is None else fname
if isinstance(fname, (str,pathlib.Path)):
if fname is None:
return open(self.name.replace(' ','_')+'.'+ext, 'w', newline='\n')
elif isinstance(fname, (str, Path)):
return open(fname, 'w', newline='\n')
if isinstance(fname, TextIOWrapper):
else:
return fname
raise TypeError
def save_paraview(self, fname: Optional[Union[TextIOWrapper, str, pathlib.Path]] = None):
def save_paraview(self, fname: Union[TextIO, str, Path] = None):
"""
Save as JSON file for use in Paraview.
@ -316,7 +312,7 @@ class Colormap(mpl.colors.ListedColormap):
json.dump(out,self._get_file_handle(fname,'json'),indent=4)
def save_ASCII(self, fname: Union[TextIOWrapper, str, pathlib.Path] = None):
def save_ASCII(self, fname: Union[TextIO, str, Path] = None):
"""
Save as ASCII file.
@ -332,7 +328,7 @@ class Colormap(mpl.colors.ListedColormap):
t.save(self._get_file_handle(fname,'txt'))
def save_GOM(self, fname: Union[TextIOWrapper, str, pathlib.Path] = None):
def save_GOM(self, fname: Union[TextIO, str, Path] = None):
"""
Save as ASCII file for use in GOM Aramis.
@ -353,7 +349,7 @@ class Colormap(mpl.colors.ListedColormap):
self._get_file_handle(fname,'legend').write(GOM_str)
def save_gmsh(self, fname: Optional[Union[TextIOWrapper, str, pathlib.Path]] = None):
def save_gmsh(self, fname: Union[TextIO, str, Path] = None):
"""
Save as ASCII file for use in gmsh.
@ -373,8 +369,8 @@ class Colormap(mpl.colors.ListedColormap):
@staticmethod
def _interpolate_msh(frac,
low: Sequence[float],
high: Sequence[float]) -> np.ndarray:
low: np.ndarray,
high: np.ndarray) -> np.ndarray:
"""
Interpolate in Msh color space.
@ -451,24 +447,24 @@ class Colormap(mpl.colors.ListedColormap):
@staticmethod
def _hsv2rgb(hsv: Sequence[float]) -> np.ndarray:
def _hsv2rgb(hsv: np.ndarray) -> np.ndarray:
"""H(ue) S(aturation) V(alue) to R(red) G(reen) B(lue)."""
return np.array(colorsys.hsv_to_rgb(hsv[0]/360.,hsv[1],hsv[2]))
@staticmethod
def _rgb2hsv(rgb: Sequence[float]) -> np.ndarray:
def _rgb2hsv(rgb: np.ndarray) -> np.ndarray:
"""R(ed) G(reen) B(lue) to H(ue) S(aturation) V(alue)."""
h,s,v = colorsys.rgb_to_hsv(rgb[0],rgb[1],rgb[2])
return np.array([h*360,s,v])
@staticmethod
def _hsl2rgb(hsl: Sequence[float]) -> np.ndarray:
def _hsl2rgb(hsl: np.ndarray) -> np.ndarray:
"""H(ue) S(aturation) L(uminance) to R(red) G(reen) B(lue)."""
return np.array(colorsys.hls_to_rgb(hsl[0]/360.,hsl[2],hsl[1]))
@staticmethod
def _rgb2hsl(rgb: Sequence[float]) -> np.ndarray:
def _rgb2hsl(rgb: np.ndarray) -> np.ndarray:
"""R(ed) G(reen) B(lue) to H(ue) S(aturation) L(uminance)."""
h,l,s = colorsys.rgb_to_hls(rgb[0],rgb[1],rgb[2])
return np.array([h*360,s,l])
@ -532,7 +528,7 @@ class Colormap(mpl.colors.ListedColormap):
])*(ref_white if ref_white is not None else _ref_white)
@staticmethod
def _xyz2lab(xyz: np.ndarray, ref_white: Optional[np.ndarray] = None) -> np.ndarray:
def _xyz2lab(xyz: np.ndarray, ref_white: np.ndarray = None) -> np.ndarray:
"""
CIE Xyz to CIE Lab.

View File

@ -77,6 +77,11 @@ class TestColormap:
# xyz2msh
assert np.allclose(Colormap._xyz2msh(xyz),msh,atol=1.e-6,rtol=0)
@pytest.mark.parametrize('low,high',[((0,0,0),(1,1,1)),
([0,0,0],[1,1,1]),
(np.array([0,0,0]),np.array([1,1,1]))])
def test_from_range_types(self,low,high):
c = Colormap.from_range(low,high) # noqa
@pytest.mark.parametrize('format',['ASCII','paraview','GOM','gmsh'])
@pytest.mark.parametrize('model',['rgb','hsv','hsl','xyz','lab','msh'])