simplified and corrected

- Optional not needed for 'None' argument
- Use TextIO for typehints, TextIOWrapper seems to cause problems
- Test for tuple/list input for public functions
- internal functions that are always called with np.ndarray don't need
  to offer flexibility. They might work, but we don't guarantee
  anything.
This commit is contained in:
Martin Diehl 2021-11-23 17:59:56 +01:00
parent 3410a8d4cb
commit 889ab87914
2 changed files with 25 additions and 24 deletions

View File

@ -3,8 +3,7 @@ import json
import functools import functools
import colorsys import colorsys
from pathlib import Path from pathlib import Path
from typing import Sequence, Union, Optional, List, TextIO from typing import Sequence, Union, List, TextIO
from io import TextIOWrapper
import numpy as np import numpy as np
@ -56,7 +55,7 @@ class Colormap(mpl.colors.ListedColormap):
"""Reverse.""" """Reverse."""
return self.reversed() return self.reversed()
def __repr__(self) -> "Colormap": def __repr__(self) -> str:
"""Show as matplotlib figure.""" """Show as matplotlib figure."""
fig = plt.figure(self.name,figsize=(5,.5)) fig = plt.figure(self.name,figsize=(5,.5))
ax1 = fig.add_axes([0, 0, 1, 1]) ax1 = fig.add_axes([0, 0, 1, 1])
@ -195,8 +194,8 @@ class Colormap(mpl.colors.ListedColormap):
def shade(self, def shade(self,
field: np.ndarray, field: np.ndarray,
bounds: Optional[Sequence[float]] = None, bounds: Sequence[float] = None,
gap: Optional[np.ndarray] = None) -> Image: gap: float = None) -> Image:
""" """
Generate PIL image of 2D field using colormap. Generate PIL image of 2D field using colormap.
@ -262,7 +261,7 @@ class Colormap(mpl.colors.ListedColormap):
return Colormap(np.array(rev.colors),rev.name[:-4] if rev.name.endswith('_r_r') else rev.name) return Colormap(np.array(rev.colors),rev.name[:-4] if rev.name.endswith('_r_r') else rev.name)
def _get_file_handle(self, fname: Union[TextIOWrapper, str, pathlib.Path, None], ext: str) -> TextIO: def _get_file_handle(self, fname: Union[TextIO, str, Path, None], ext: str) -> TextIO:
""" """
Provide file handle. Provide file handle.
@ -270,7 +269,6 @@ class Colormap(mpl.colors.ListedColormap):
---------- ----------
fname : file, str, pathlib.Path, or None fname : file, str, pathlib.Path, or None
Filename or filehandle, will be name of the colormap+extension if None. Filename or filehandle, will be name of the colormap+extension if None.
ext: str ext: str
Extension of the filename. Extension of the filename.
@ -280,17 +278,15 @@ class Colormap(mpl.colors.ListedColormap):
File handle with write access. File handle with write access.
""" """
fname = pathlib.Path(self.name.replace(' ','_'))\ if fname is None:
.with_suffix(('' if ext is None or ext.startswith('.') else '.')+ext) if fname is None else fname return open(self.name.replace(' ','_')+'.'+ext, 'w', newline='\n')
if isinstance(fname, (str,pathlib.Path)): elif isinstance(fname, (str, Path)):
return open(fname, 'w', newline='\n') return open(fname, 'w', newline='\n')
if isinstance(fname, TextIOWrapper): else:
return fname return fname
raise TypeError
def save_paraview(self, fname: Union[TextIO, str, Path] = None):
def save_paraview(self, fname: Optional[Union[TextIOWrapper, str, pathlib.Path]] = None):
""" """
Save as JSON file for use in Paraview. Save as JSON file for use in Paraview.
@ -316,7 +312,7 @@ class Colormap(mpl.colors.ListedColormap):
json.dump(out,self._get_file_handle(fname,'json'),indent=4) json.dump(out,self._get_file_handle(fname,'json'),indent=4)
def save_ASCII(self, fname: Union[TextIOWrapper, str, pathlib.Path] = None): def save_ASCII(self, fname: Union[TextIO, str, Path] = None):
""" """
Save as ASCII file. Save as ASCII file.
@ -332,7 +328,7 @@ class Colormap(mpl.colors.ListedColormap):
t.save(self._get_file_handle(fname,'txt')) t.save(self._get_file_handle(fname,'txt'))
def save_GOM(self, fname: Union[TextIOWrapper, str, pathlib.Path] = None): def save_GOM(self, fname: Union[TextIO, str, Path] = None):
""" """
Save as ASCII file for use in GOM Aramis. Save as ASCII file for use in GOM Aramis.
@ -353,7 +349,7 @@ class Colormap(mpl.colors.ListedColormap):
self._get_file_handle(fname,'legend').write(GOM_str) self._get_file_handle(fname,'legend').write(GOM_str)
def save_gmsh(self, fname: Optional[Union[TextIOWrapper, str, pathlib.Path]] = None): def save_gmsh(self, fname: Union[TextIO, str, Path] = None):
""" """
Save as ASCII file for use in gmsh. Save as ASCII file for use in gmsh.
@ -373,8 +369,8 @@ class Colormap(mpl.colors.ListedColormap):
@staticmethod @staticmethod
def _interpolate_msh(frac, def _interpolate_msh(frac,
low: Sequence[float], low: np.ndarray,
high: Sequence[float]) -> np.ndarray: high: np.ndarray) -> np.ndarray:
""" """
Interpolate in Msh color space. Interpolate in Msh color space.
@ -451,24 +447,24 @@ class Colormap(mpl.colors.ListedColormap):
@staticmethod @staticmethod
def _hsv2rgb(hsv: Sequence[float]) -> np.ndarray: def _hsv2rgb(hsv: np.ndarray) -> np.ndarray:
"""H(ue) S(aturation) V(alue) to R(red) G(reen) B(lue).""" """H(ue) S(aturation) V(alue) to R(red) G(reen) B(lue)."""
return np.array(colorsys.hsv_to_rgb(hsv[0]/360.,hsv[1],hsv[2])) return np.array(colorsys.hsv_to_rgb(hsv[0]/360.,hsv[1],hsv[2]))
@staticmethod @staticmethod
def _rgb2hsv(rgb: Sequence[float]) -> np.ndarray: def _rgb2hsv(rgb: np.ndarray) -> np.ndarray:
"""R(ed) G(reen) B(lue) to H(ue) S(aturation) V(alue).""" """R(ed) G(reen) B(lue) to H(ue) S(aturation) V(alue)."""
h,s,v = colorsys.rgb_to_hsv(rgb[0],rgb[1],rgb[2]) h,s,v = colorsys.rgb_to_hsv(rgb[0],rgb[1],rgb[2])
return np.array([h*360,s,v]) return np.array([h*360,s,v])
@staticmethod @staticmethod
def _hsl2rgb(hsl: Sequence[float]) -> np.ndarray: def _hsl2rgb(hsl: np.ndarray) -> np.ndarray:
"""H(ue) S(aturation) L(uminance) to R(red) G(reen) B(lue).""" """H(ue) S(aturation) L(uminance) to R(red) G(reen) B(lue)."""
return np.array(colorsys.hls_to_rgb(hsl[0]/360.,hsl[2],hsl[1])) return np.array(colorsys.hls_to_rgb(hsl[0]/360.,hsl[2],hsl[1]))
@staticmethod @staticmethod
def _rgb2hsl(rgb: Sequence[float]) -> np.ndarray: def _rgb2hsl(rgb: np.ndarray) -> np.ndarray:
"""R(ed) G(reen) B(lue) to H(ue) S(aturation) L(uminance).""" """R(ed) G(reen) B(lue) to H(ue) S(aturation) L(uminance)."""
h,l,s = colorsys.rgb_to_hls(rgb[0],rgb[1],rgb[2]) h,l,s = colorsys.rgb_to_hls(rgb[0],rgb[1],rgb[2])
return np.array([h*360,s,l]) return np.array([h*360,s,l])
@ -532,7 +528,7 @@ class Colormap(mpl.colors.ListedColormap):
])*(ref_white if ref_white is not None else _ref_white) ])*(ref_white if ref_white is not None else _ref_white)
@staticmethod @staticmethod
def _xyz2lab(xyz: np.ndarray, ref_white: Optional[np.ndarray] = None) -> np.ndarray: def _xyz2lab(xyz: np.ndarray, ref_white: np.ndarray = None) -> np.ndarray:
""" """
CIE Xyz to CIE Lab. CIE Xyz to CIE Lab.

View File

@ -77,6 +77,11 @@ class TestColormap:
# xyz2msh # xyz2msh
assert np.allclose(Colormap._xyz2msh(xyz),msh,atol=1.e-6,rtol=0) assert np.allclose(Colormap._xyz2msh(xyz),msh,atol=1.e-6,rtol=0)
@pytest.mark.parametrize('low,high',[((0,0,0),(1,1,1)),
([0,0,0],[1,1,1]),
(np.array([0,0,0]),np.array([1,1,1]))])
def test_from_range_types(self,low,high):
c = Colormap.from_range(low,high) # noqa
@pytest.mark.parametrize('format',['ASCII','paraview','GOM','gmsh']) @pytest.mark.parametrize('format',['ASCII','paraview','GOM','gmsh'])
@pytest.mark.parametrize('model',['rgb','hsv','hsl','xyz','lab','msh']) @pytest.mark.parametrize('model',['rgb','hsv','hsl','xyz','lab','msh'])