"""Miscellaneous helper functionality."""

import sys
import datetime
import os
import subprocess
import shlex
import re
import fractions
from collections import abc
from functools import reduce
from typing import Union, Tuple, Iterable, Callable, Dict, List, Any, Literal, Collection
from pathlib import Path

import numpy as np
import h5py

from . import version
from ._typehints import FloatSequence, NumpyRngSeed

# limit visibility: only the names listed here are exported by a star-import of this module
# NOTE(review): 'execute', 'tbd', and 'ProgressBar' are defined in this file but not listed -- presumably on purpose, confirm
__all__=[
         'srepr',
         'emph', 'deemph', 'warn', 'strikeout',
         'run',
         'natural_sort',
         'show_progress',
         'scale_to_coprime',
         'project_equal_angle', 'project_equal_area',
         'hybrid_IA',
         'execution_stamp',
         'shapeshifter', 'shapeblender',
         'extend_docstring', 'extended_docstring',
         'Bravais_to_Miller', 'Miller_to_Bravais',
         'DREAM3D_base_group', 'DREAM3D_cell_data_group',
         'dict_prune', 'dict_flatten'
        ]

# ANSI escape sequences used to style terminal output.
# Each value switches an attribute on; 'end_color' ('\033[0m') resets all attributes.
# https://svn.blender.org/svnroot/bf-blender/trunk/blender/build_files/scons/tools/bcolors.py
# https://stackoverflow.com/questions/287871
_colors = {
           'header' :   '\033[95m',
           'OK_blue':   '\033[94m',
           'OK_green':  '\033[92m',
           'warning':   '\033[93m',
           'fail':      '\033[91m',
           'end_color': '\033[0m',
           'bold':      '\033[1m',
           'dim':       '\033[2m',
           'underline': '\033[4m',
           'crossout':  '\033[9m'
          }

####################################################################################################
# Functions
####################################################################################################
def srepr(msg,
          glue: str = '\n') -> str:
    r"""
    Convert an item, or a collection of items, to a single string.

    Parameters
    ----------
    msg : object with __repr__ or sequence of objects with __repr__
        Item(s) to convert. Anything that is indexable or iterable
        but has no 'strip' attribute is treated as a collection.
    glue : str, optional
        Separator placed between the items of a collection.
        Defaults to '\n'.

    Returns
    -------
    joined : str
        For a collection, the str() of each item joined by 'glue'.
        A str is returned unchanged; any other object yields its repr().

    """
    # objects with 'strip' (e.g. str) are string-like and must not be split into items
    if not hasattr(msg, 'strip'):
        if hasattr(msg, '__getitem__') or hasattr(msg, '__iter__'):
            return glue.join(map(str, msg))

    if isinstance(msg, str):
        return msg
    return repr(msg)


def emph(msg) -> str:
    """
    Render a message in bold.

    Parameters
    ----------
    msg : object with __repr__ or sequence of objects with __repr__
        Message to format. Converted to a string by 'srepr'.

    Returns
    -------
    formatted : str
        Message enclosed by the 'bold' and 'end_color' escape sequences.

    """
    return f"{_colors['bold']}{srepr(msg)}{_colors['end_color']}"

def deemph(msg) -> str:
    """
    Render a message dimmed.

    Parameters
    ----------
    msg : object with __repr__ or sequence of objects with __repr__
        Message to format. Converted to a string by 'srepr'.

    Returns
    -------
    formatted : str
        Message enclosed by the 'dim' and 'end_color' escape sequences.

    """
    return f"{_colors['dim']}{srepr(msg)}{_colors['end_color']}"

def warn(msg) -> str:
    """
    Render a message as a warning.

    Parameters
    ----------
    msg : object with __repr__ or sequence of objects with __repr__
        Message to format.

    Returns
    -------
    formatted : str
        Emphasized message (see 'emph') enclosed by the 'warning'
        and 'end_color' escape sequences.

    """
    return f"{_colors['warning']}{emph(msg)}{_colors['end_color']}"

def strikeout(msg) -> str:
    """
    Render a message crossed out.

    Parameters
    ----------
    msg : object with __repr__ or iterable of objects with __repr__
        Message to format. Converted to a string by 'srepr'.

    Returns
    -------
    formatted : str
        Message enclosed by the 'crossout' and 'end_color' escape sequences.

    """
    return f"{_colors['crossout']}{srepr(msg)}{_colors['end_color']}"


def run(cmd: str,
        wd: str = './',
        env: Dict[str, str] = None,
        timeout: int = None) -> Tuple[str, str]:
    """
    Execute a command as a subprocess and collect its output.

    The command string is split with shlex and executed without a shell.
    A notice about the command and its working directory is printed first.

    Parameters
    ----------
    cmd : str
        Command to be executed.
    wd : str, optional
        Working directory of process. Defaults to './'.
    env : dict, optional
        Environment for execution.
        Defaults to the environment of the current process.
    timeout : integer, optional
        Timeout in seconds.

    Returns
    -------
    stdout, stderr : (str, str)
        Output of the executed command, decoded as UTF-8.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero return code.
        Its stdout and stderr are printed before raising.

    """
    print(f"running '{cmd}' in '{wd}'")

    environment = os.environ if env is None else env
    process = subprocess.run(shlex.split(cmd),
                             cwd = wd,
                             env = environment,
                             timeout = timeout,
                             stdout = subprocess.PIPE,
                             stderr = subprocess.PIPE,
                             encoding = 'utf-8')

    if process.returncode == 0:
        return process.stdout, process.stderr

    # failure: show what the command produced, then report the return code
    print(process.stdout)
    print(process.stderr)
    raise RuntimeError(f"'{cmd}' failed with returncode {process.returncode}")


# alias: 'execute' is bound to the same function object as 'run'
# NOTE(review): presumably kept for backward compatibility -- confirm before removing
execute = run


def natural_sort(key: str) -> List[Union[int, str]]:
    """
    Natural sort.

    For use as 'key' function in python's 'sorted'.

    Parameters
    ----------
    key : str
        String to split into sortable tokens.

    Returns
    -------
    tokens : list of int and str
        The string split at runs of the digits 0-9, with each token
        that consists only of decimal characters converted to int.

    References
    ----------
    https://en.wikipedia.org/wiki/Natural_sort_order

    """
    # 'isdecimal' (not 'isdigit'): characters such as superscript digits are
    # "digits" but are rejected by int(), which would raise a ValueError.
    return [int(token) if token.isdecimal() else token
            for token in re.split(r'([0-9]+)', key)]


def show_progress(iterable: Iterable,
                  N_iter: int = None,
                  prefix: str = '',
                  bar_length: int = 50) -> Any:
    """
    Iterate over items while reporting progress as a status bar.

    Use similar like enumerate. This is a generator: argument
    validation happens when the first item is requested.

    Parameters
    ----------
    iterable : iterable
        Iterable to be decorated.
    N_iter : int, optional
        Total number of iterations. Required if iterable is not a sequence
        and must not be given if it is a sequence.
    prefix : str, optional
        Prefix string.
    bar_length : int, optional
        Length of progress bar in characters. Defaults to 50.

    Yields
    ------
    item
        The items of 'iterable', in order.

    Raises
    ------
    ValueError
        If N_iter is given for a sequence or missing for a non-sequence.

    """
    is_sequence = isinstance(iterable,abc.Sequence)
    if is_sequence and N_iter is not None:
        raise ValueError('N_iter given for sequence')
    if not is_sequence and N_iter is None:
        raise ValueError('N_iter not given')

    N = len(iterable) if is_sequence else N_iter

    if N > 1:
        status = ProgressBar(N,prefix,bar_length)
        for i,item in enumerate(iterable):
            yield item
            status.update(i)                                                                        # report after the item has been consumed
    else:
        # nothing to report for at most one iteration
        for item in iterable:
            yield item


def scale_to_coprime(v: FloatSequence) -> np.ndarray:
    """
    Scale vector to co-prime (relatively prime) integers.

    Parameters
    ----------
    v : sequence of float, len (:)
        Vector to scale.

    Returns
    -------
    m : numpy.ndarray, shape (:)
        Vector scaled to co-prime numbers.

    Raises
    ------
    ValueError
        If the integer result is not proportional to the input.

    """
    MAX_DENOMINATOR = 1000000

    def square_denominator(x) -> int:
        """Denominator of the (rationally approximated) square of a number."""
        squared = fractions.Fraction(x ** 2).limit_denominator(MAX_DENOMINATOR)
        return int(squared.denominator)

    def lcm(a,b):
        """Least common multiple."""
        try:
            return np.lcm(a,b)                                                                      # numpy > 1.18
        except AttributeError:
            return a * b // np.gcd(a, b)

    v_ = np.array(v)

    # scale by the square root of the common denominator of all squared components
    common_denominator = reduce(lcm, (square_denominator(x) for x in v_))
    m = (v_ * common_denominator**0.5).astype(int)
    m = m//reduce(np.gcd,m)

    # all valid (non-NaN) ratios v/m must agree with the ratio at the largest component
    reference = np.argmax(abs(v_))
    with np.errstate(invalid='ignore'):
        ratios = np.ma.masked_invalid(v_/m)
        if not np.allclose(ratios,v_[reference]/m[reference]):
            raise ValueError(f'invalid result "{m}" for input "{v_}"')

    return m


def project_equal_angle(vector: np.ndarray,
                        direction: Literal['x', 'y', 'z'] = 'z',
                        normalize: bool = True,
                        keepdims: bool = False) -> np.ndarray:
    """
    Apply equal-angle projection to vector.

    The two components perpendicular to the projection direction
    are divided by (1 + |component along the projection direction|).

    Parameters
    ----------
    vector : numpy.ndarray, shape (...,3)
        Vector coordinates to be projected.
    direction : {'x', 'y', 'z'}
        Projection direction. Defaults to 'z'.
    normalize : bool
        Ensure unit length of input vector. Defaults to True.
    keepdims : bool
        Maintain three-dimensional output coordinates.
        Defaults to False.

    Returns
    -------
    coordinates : numpy.ndarray, shape (...,2 | 3)
        Projected coordinates.

    Notes
    -----
    Two-dimensional output uses right-handed frame spanned by
    the next and next-next axis relative to the projection direction,
    e.g. x-y when projecting along z and z-x when projecting along y.
    Three-dimensional output has a zero along the projection direction.

    Examples
    --------
    >>> import damask
    >>> import numpy as np
    >>> damask.util.project_equal_angle(np.ones(3))
    array([0.3660254, 0.3660254])
    >>> damask.util.project_equal_angle(np.ones(3),direction='x',normalize=False,keepdims=True)
    array([0. , 0.5, 0.5])
    >>> damask.util.project_equal_angle([0,1,1],direction='y',normalize=True,keepdims=False)
    array([0.41421356, 0.        ])

    """
    shift = 'zyx'.index(direction)
    if normalize:
        vector = vector/np.linalg.norm(vector,axis=-1,keepdims=True)

    # cyclically shift components such that the projection direction comes last
    v = np.roll(vector,shift,axis=-1)
    out_of_plane = v[...,2:3]
    in_plane = v[...,:2]/(1.0+np.abs(out_of_plane))
    projected = np.block([in_plane,np.zeros_like(out_of_plane)])

    if keepdims:
        return np.roll(projected,-shift,axis=-1)                                                    # undo the cyclic shift
    return projected[...,:2]

def project_equal_area(vector: np.ndarray,
                       direction: Literal['x', 'y', 'z'] = 'z',
                       normalize: bool = True,
                       keepdims: bool = False) -> np.ndarray:
    """
    Apply equal-area projection to vector.

    The two components perpendicular to the projection direction
    are divided by sqrt(1 + |component along the projection direction|).

    Parameters
    ----------
    vector : numpy.ndarray, shape (...,3)
        Vector coordinates to be projected.
    direction : {'x', 'y', 'z'}
        Projection direction. Defaults to 'z'.
    normalize : bool
        Ensure unit length of input vector. Defaults to True.
    keepdims : bool
        Maintain three-dimensional output coordinates.
        Defaults to False.

    Returns
    -------
    coordinates : numpy.ndarray, shape (...,2 | 3)
        Projected coordinates.

    Notes
    -----
    Two-dimensional output uses right-handed frame spanned by
    the next and next-next axis relative to the projection direction,
    e.g. x-y when projecting along z and z-x when projecting along y.
    Three-dimensional output has a zero along the projection direction.

    Examples
    --------
    >>> import damask
    >>> import numpy as np
    >>> damask.util.project_equal_area(np.ones(3))
    array([0.45970084, 0.45970084])
    >>> damask.util.project_equal_area(np.ones(3),direction='x',normalize=False,keepdims=True)
    array([0.        , 0.70710678, 0.70710678])
    >>> damask.util.project_equal_area([0,1,1],direction='y',normalize=True,keepdims=False)
    array([0.5411961, 0.       ])

    """
    shift = 'zyx'.index(direction)
    if normalize:
        vector = vector/np.linalg.norm(vector,axis=-1,keepdims=True)

    # cyclically shift components such that the projection direction comes last
    v = np.roll(vector,shift,axis=-1)
    out_of_plane = v[...,2:3]
    in_plane = v[...,:2]/np.sqrt(1.0+np.abs(out_of_plane))
    projected = np.block([in_plane,np.zeros_like(out_of_plane)])

    if keepdims:
        return np.roll(projected,-shift,axis=-1)                                                    # undo the cyclic shift
    return projected[...,:2]

def execution_stamp(class_name: str,
                    function_name: str = None) -> str:
    """
    Timestamp the execution of a (function within a) class.

    Parameters
    ----------
    class_name : str
        Name of the class.
    function_name : str, optional
        Name of the function within the class.

    Returns
    -------
    stamp : str
        'damask.<class>[.<function>] v<version> (<local time with UTC offset>)'.

    """
    timestamp = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S%z')
    suffix = f'.{function_name}' if function_name is not None else ''
    return f'damask.{class_name}{suffix} v{version} ({timestamp})'


def hybrid_IA(dist: np.ndarray,
              N: int,
              rng_seed: NumpyRngSeed = None) -> np.ndarray:
    """
    Hybrid integer approximation.

    Parameters
    ----------
    dist : numpy.ndarray
        Distribution to be approximated
    N : int
        Number of samples to draw.
    rng_seed : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator}, optional
        A seed to initialize the BitGenerator. Defaults to None.
        If None, then fresh, unpredictable entropy will be pulled from the OS.

    Returns
    -------
    samples : numpy.ndarray
        Indices into 'dist'. Each index is repeated according to the
        integer approximation, then at most N of them are picked at random.

    """
    N_opt_samples = max(np.count_nonzero(dist),N)                                                   # random subsampling if too little samples requested
    N_inv_samples = 0

    scale_ = 0.0
    scale = float(N_opt_samples)
    inc_factor = 1.0
    # search for a scale at which the rounded, scaled distribution sums to N_opt_samples
    while (not np.isclose(scale, scale_)) and (N_inv_samples != N_opt_samples):
        repeats = np.rint(scale*dist).astype(int)
        N_inv_samples = np.sum(repeats)
        if N_inv_samples < N_opt_samples:
            # too few samples: step upwards, doubling the step factor each time
            scale_,scale,inc_factor = scale, scale+inc_factor*0.5*(scale - scale_), inc_factor*2.0
        else:
            # enough samples: bisect between previous and current scale, reset the step factor
            scale,inc_factor = 0.5*(scale_ + scale), 1.0

    indices = np.repeat(np.arange(len(dist)),repeats)
    selection = np.random.default_rng(rng_seed).permutation(N_inv_samples)[:N]
    return indices[selection]


def shapeshifter(fro: Tuple[int, ...],
                 to: Tuple[int, ...],
                 mode: Literal['left','right'] = 'left',
                 keep_ones: bool = False) -> Tuple[int, ...]:
    """
    Return dimensions that reshape 'fro' to become broadcastable to 'to'.

    Parameters
    ----------
    fro : tuple
        Original shape of array.
    to : tuple
        Target shape of array after broadcasting.
        len(to) cannot be less than len(fro).
    mode : {'left', 'right'}, optional
        Indicates whether new axes are preferably added to
        either left or right of the original shape.
        Defaults to 'left'.
    keep_ones : bool, optional
        Treat '1' in fro as literal value instead of dimensional placeholder.
        Defaults to False.

    Returns
    -------
    new_dims : tuple
        Dimensions for reshape.

    Raises
    ------
    ValueError
        If 'fro' cannot be aligned with 'to'.

    Example
    -------
    >>> import numpy as np
    >>> from damask import util
    >>> a = np.ones((3,4,2))
    >>> b = np.ones(4)
    >>> b_extended = b.reshape(util.shapeshifter(b.shape,a.shape))
    >>> (a * np.broadcast_to(b_extended,a.shape)).shape
    (3, 4, 2)

    """
    if len(fro) == 0 and len(to) == 0: return ()

    # capture groups collect the entries of 'to' before, between, and after those of 'fro';
    # (non-)greedy matching decides on which side the unmatched entries preferably end up
    beg = dict(left =r'(^.*\b)',
               right=r'(^.*?\b)')
    sep = dict(left =r'(.*\b)',
               right=r'(.*?\b)')
    end = dict(left =r'(.*?$)',
               right=r'(.*$)')
    fro = (1,) if len(fro) == 0 else fro
    to  = (1,) if len(to) == 0 else to

    # entries of 'fro' match literally, except placeholder entries (<= 1), which match any number
    tokens = [f'{x}' if x>1 or (keep_ones and len(fro)>1) else r'\d+' for x in fro]
    pattern = beg[mode] + f',{sep[mode]}'.join(tokens) + f',{end[mode]}'

    match = re.match(pattern,','.join(map(str,to))+',')
    if match is None:                                                                               # explicit check instead of 'assert', which is stripped under 'python -O'
        raise ValueError(f'shapes cannot be shifted {fro} --> {to}')

    # every comma in a captured group is an entry of 'to' that gets a new axis of length 1
    fill: Any = ()
    for g,d in zip(match.groups(),fro+(None,)):
        fill += (1,)*g.count(',')+(d,)
    return fill[:-1]                                                                                # drop the 'None' sentinel


def shapeblender(a: Tuple[int, ...],
                 b: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Return a shape that overlaps the rightmost entries of 'a' with the leftmost of 'b'.

    The longest possible overlap is used. Without any overlap,
    the result is the concatenation of both shapes.

    Parameters
    ----------
    a : tuple
        Shape of first array.
    b : tuple
        Shape of second array.

    Returns
    -------
    blended : tuple
        'a' followed by the entries of 'b' that are not part of the overlap.

    Examples
    --------
    >>> shapeblender((4,4,3),(3,2,1))
    (4, 4, 3, 2, 1)
    >>> shapeblender((1,2),(1,2,3))
    (1, 2, 3)
    >>> shapeblender((1,),(2,2,1))
    (1, 2, 2, 1)
    >>> shapeblender((3,2),(3,2))
    (3, 2)

    """
    # try overlaps from longest to shortest
    for overlap in range(min(len(a),len(b)),0,-1):
        if a[-overlap:] == b[:overlap]:
            return a + b[overlap:]
    return a + b


def extend_docstring(extra_docstring: str) -> Callable:
    """
    Decorator: Append to function's docstring.

    Parameters
    ----------
    extra_docstring : str
       Docstring to append.

    Returns
    -------
    decorator : function
        Decorator that extends the docstring of the decorated
        function in place and returns that same function.

    """
    def _decorator(func):
        # '__doc__' is None for functions without docstring (e.g. under 'python -OO')
        func.__doc__ = (func.__doc__ or '') + extra_docstring
        return func
    return _decorator


def extended_docstring(f: Callable,
                       extra_docstring: str) -> Callable:
    """
    Decorator: Combine another function's docstring with a given docstring.

    Parameters
    ----------
    f : function
       Function of which the docstring is taken.
    extra_docstring : str
       Docstring to append.

    Returns
    -------
    decorator : function
        Decorator that replaces the docstring of the decorated
        function and returns that same function.

    """
    def _decorator(func):
        # '__doc__' is None for functions without docstring (e.g. under 'python -OO')
        func.__doc__ = (f.__doc__ or '') + extra_docstring
        return func
    return _decorator


def DREAM3D_base_group(fname: Union[str, Path]) -> str:
    """
    Determine the base group of a DREAM.3D file.

    The base group is defined as the group (folder) that contains
    a 'SPACING' dataset in a '_SIMPL_GEOMETRY' group.

    Parameters
    ----------
    fname : str or pathlib.Path
        Filename of the DREAM.3D (HDF5) file.

    Returns
    -------
    path : str
        Path to the base group.

    Raises
    ------
    ValueError
        If no path containing '_SIMPL_GEOMETRY/SPACING' is found.

    """
    def _parent_of_geometry(path):
        """Strip the last two path components if path points to the spacing dataset."""
        if '_SIMPL_GEOMETRY/SPACING' in path:
            return path.rsplit('/',2)[0]
        return None

    with h5py.File(fname,'r') as f:
        base_group = f.visit(_parent_of_geometry)                                                   # visiting stops at first result that is not None

    if base_group is None:
        raise ValueError(f'could not determine base group in file "{fname}"')

    return base_group

def DREAM3D_cell_data_group(fname: Union[str, Path]) -> str:
    """
    Determine the cell data group of a DREAM.3D file.

    The cell data group is defined as the group (folder) in the base
    group that contains a dataset whose shape, except for its last
    dimension, matches the reversed entries of '_SIMPL_GEOMETRY/DIMENSIONS'.

    Parameters
    ----------
    fname : str or pathlib.Path
        Filename of the DREAM.3D (HDF5) file.

    Returns
    -------
    path : str
        Path to the cell data group, relative to the base group.

    Raises
    ------
    ValueError
        If no matching dataset is found.

    """
    base_group = DREAM3D_base_group(fname)
    with h5py.File(fname,'r') as f:
        dimensions = f[f'{base_group}/_SIMPL_GEOMETRY/DIMENSIONS'][()]
        cells = tuple(dimensions[::-1])

        def _group_of_cell_data(path,obj):
            """First path component if obj is a dataset of matching shape."""
            if isinstance(obj,h5py._hl.dataset.Dataset) and np.shape(obj)[:-1] == cells:
                return path.split('/')[0]
            return None

        cell_data_group = f[base_group].visititems(_group_of_cell_data)                             # visiting stops at first result that is not None

    if cell_data_group is None:
        raise ValueError(f'could not determine cell-data group in file "{fname}/{base_group}"')

    return cell_data_group


def Bravais_to_Miller(*,
                      uvtw: np.ndarray = None,
                      hkil: np.ndarray = None) -> np.ndarray:
    """
    Transform 4 Miller–Bravais indices to 3 Miller indices of crystal direction [uvw] or plane normal (hkl).

    Exactly one of the two keyword arguments has to be given.

    Parameters
    ----------
    uvtw|hkil : numpy.ndarray, shape (...,4)
        Miller–Bravais indices of crystallographic direction [uvtw] or plane normal (hkil).

    Returns
    -------
    uvw|hkl : numpy.ndarray, shape (...,3)
        Miller indices of [uvw] direction or (hkl) plane normal.

    Raises
    ------
    KeyError
        If both or none of 'uvtw' and 'hkil' are given.

    """
    if (uvtw is None) == (hkil is None):
        raise KeyError('specify either "uvtw" or "hkil"')

    if hkil is None:
        # direction: [u-t, v-t, w]
        axis  = np.array(uvtw)
        basis = np.array([[1,0,-1,0],
                          [0,1,-1,0],
                          [0,0, 0,1]])
    else:
        # plane normal: (h, k, l), i.e. the third index is dropped
        axis  = np.array(hkil)
        basis = np.array([[1,0,0,0],
                          [0,1,0,0],
                          [0,0,0,1]])

    return np.einsum('il,...l',basis,axis)


def Miller_to_Bravais(*,
                      uvw: np.ndarray = None,
                      hkl: np.ndarray = None) -> np.ndarray:
    """
    Transform 3 Miller indices to 4 Miller–Bravais indices of crystal direction [uvtw] or plane normal (hkil).

    Exactly one of the two keyword arguments has to be given.

    Parameters
    ----------
    uvw|hkl : numpy.ndarray, shape (...,3)
        Miller indices of crystallographic direction [uvw] or plane normal (hkl).

    Returns
    -------
    uvtw|hkil : numpy.ndarray, shape (...,4)
        Miller–Bravais indices of [uvtw] direction or (hkil) plane normal.

    Raises
    ------
    KeyError
        If both or none of 'uvw' and 'hkl' are given.

    """
    if (uvw is None) == (hkl is None):
        raise KeyError('specify either "uvw" or "hkl"')

    if hkl is None:
        # direction: [(2u-v)/3, (2v-u)/3, -(u+v)/3, w]
        axis  = np.array(uvw)
        basis = np.array([[ 2,-1, 0],
                          [-1, 2, 0],
                          [-1,-1, 0],
                          [ 0, 0, 3]])/3
    else:
        # plane normal: (h, k, -(h+k), l)
        axis  = np.array(hkl)
        basis = np.array([[ 1, 0, 0],
                          [ 0, 1, 0],
                          [-1,-1, 0],
                          [ 0, 0, 1]])

    return np.einsum('il,...l',basis,axis)


def dict_prune(d: Dict) -> Dict:
    """
    Recursively remove empty dictionaries.

    The input is not modified. Values that are not dictionaries are always kept.

    Parameters
    ----------
    d : dict
        Dictionary to prune.

    Returns
    -------
    pruned : dict
        Pruned dictionary.

    """
    # https://stackoverflow.com/questions/48151953
    pruned = {}
    for key,value in d.items():
        if not isinstance(value,dict):
            pruned[key] = value
            continue

        # a dictionary is kept only if something is left after pruning it
        remainder = dict_prune(value)
        if remainder != {}:
            pruned[key] = remainder

    return pruned


def dict_flatten(d: Dict) -> Dict:
    """
    Recursively remove keys of single-entry dictionaries.

    A dictionary with exactly one entry is replaced by the (flattened)
    value of that entry. The result is therefore not a dictionary
    if that value is not a dictionary.

    Parameters
    ----------
    d : dict
        Dictionary to flatten.

    Returns
    -------
    flattened : dict
        Flattened dictionary.

    """
    if isinstance(d,dict) and len(d) == 1:
        (only_key,) = d.keys()
        entry = d[only_key]
        if isinstance(entry,dict):
            return dict_flatten(entry.copy())
        return entry

    flattened = {}
    for key,value in d.items():
        flattened[key] = dict_flatten(value) if isinstance(value,dict) else value
    return flattened

def tbd(arg) -> List:
    """
    Transform argument to list.

    Parameters
    ----------
    arg : object
        None, a numpy.ndarray or other collection, or a single item.

    Returns
    -------
    items : list
        Empty list for None, the items of an array or collection,
        otherwise a list with 'arg' as its only item.

    Notes
    -----
    A str is a collection and hence is split into its characters.

    """
    if arg is None:
        return []
    if not isinstance(arg,(np.ndarray,Collection)):
        return [arg]
    return list(arg)

####################################################################################################
# Classes
####################################################################################################
class ProgressBar:
    """
    Report progress of an iteration as a status bar.

    The bar is written to stderr.
    Works for 0-based loops, ETA is estimated by linear extrapolation.
    """

    def __init__(self,
                 total: int,
                 prefix: str,
                 bar_length: int):
        """
        Set current time as basis for ETA estimation.

        Writes an empty bar to stderr.

        Parameters
        ----------
        total : int
            Total # of iterations.
        prefix : str
            Prefix string.
        bar_length : int
            Character length of bar.

        """
        self.total = total
        self.prefix = prefix
        self.bar_length = bar_length
        self.time_start = self.time_last_update = datetime.datetime.now()
        self.fraction_last = 0.0

        empty_bar = '░'*self.bar_length
        sys.stderr.write(f"{self.prefix} {empty_bar}   0% ETA n/a")
        sys.stderr.flush()

    def update(self,
               iteration: int) -> None:
        """
        Report that an iteration has finished.

        The bar is redrawn if it grows by at least one character
        or if more than 10 seconds passed since the last redraw.

        Parameters
        ----------
        iteration : int
            0-based index of the finished iteration.

        """
        done = iteration+1
        fraction = done / self.total

        filled_length = int(self.bar_length * fraction)
        bar_grows = filled_length > int(self.bar_length * self.fraction_last)

        if bar_grows or datetime.datetime.now() - self.time_last_update > datetime.timedelta(seconds=10):
            self.time_last_update = datetime.datetime.now()
            bar = '█' * filled_length + '░' * (self.bar_length - filled_length)
            elapsed_time = datetime.datetime.now() - self.time_start
            remaining_time = elapsed_time * (self.total - done) / done                              # linear extrapolation
            remaining_time -= datetime.timedelta(microseconds=remaining_time.microseconds)          # remove μs
            sys.stderr.write(f'\r{self.prefix} {bar} {fraction:>4.0%} ETA {remaining_time}')
            sys.stderr.flush()

        self.fraction_last = fraction

        # finish the line after the last iteration
        if done == self.total:
            sys.stderr.write('\n')
            sys.stderr.flush()