"""Utilities used to generate various figures in the documentation."""
from itertools import product

import numpy as np
from matplotlib import pyplot as plt

from ._dwt import pad

# Public names exported by a star-import of this module.
__all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis',
           'draw_2d_fswavedecn_basis', 'boundary_mode_subplot']


def wavedec_keys(level):
    """Subband keys corresponding to a wavedec decomposition.

    Each key is a run of ``'a'`` characters (one per level already
    decomposed) followed by the band letter.  Every level contributes a
    detail key ending in ``'d'``; only the final level also contributes an
    approximation key ending in ``'a'``.  Keys are returned from the first
    level to the last, e.g. ``['d', 'aa', 'ad']`` for ``level=2``.
    """
    keys = []
    final = level - 1
    for lev in range(level):
        prefix = 'a' * lev
        # The approximation band is kept only where decomposition stops.
        bands = 'ad' if lev == final else 'd'
        keys.extend(prefix + band for band in bands)
    return keys


def wavedec2_keys(level):
    """Subband keys corresponding to a wavedec2 decomposition.

    Each key is a run of ``'a'`` characters (one per level already
    decomposed) followed by the band letter.  Every level contributes the
    three detail keys ending in ``'h'``, ``'v'`` and ``'d'``; the final level
    additionally contributes, ahead of them, the approximation key ending in
    ``'a'``.  For ``level=2`` the result is
    ``['h', 'v', 'd', 'aa', 'ah', 'av', 'ad']``.
    """
    details = ['h', 'v', 'd']
    keys = []
    for depth in range(1, level + 1):
        prefix = 'a' * (depth - 1)
        # Only the deepest level retains its approximation band.
        bands = details if depth < level else ['a'] + details
        keys += [prefix + band for band in bands]
    return keys


def _box(bl, ur):
    """(x, y) coordinates for the 4 lines making up a rectangular box.

    Parameters
    ----------
    bl : sequence
        The bottom left corner of the box; ``bl[0]`` is its x coordinate and
        ``bl[1]`` its y coordinate.
    ur : sequence
        The upper right corner of the box; ``ur[0]`` is its x coordinate and
        ``ur[1]`` its y coordinate.

    Returns
    -------
    coords : 2-tuple of list
        ``(box_x, box_y)``.  Each list holds 8 values: the start and end
        coordinate of the bottom, right, top and left edge, in that order.
    """
    left, right = bl[0], ur[0]
    bottom, top = bl[1], ur[1]
    # Walk the corners once around the box; each consecutive pair of corners
    # (wrapping back to the first) is one edge.
    corners = [(left, bottom), (right, bottom), (right, top), (left, top)]
    box_x, box_y = [], []
    for idx, (x_start, y_start) in enumerate(corners):
        x_end, y_end = corners[(idx + 1) % len(corners)]
        box_x += [x_start, x_end]
        box_y += [y_start, y_end]
    return (box_x, box_y)


def _2d_wp_basis_coords(shape, keys):
    """Coordinates of the lines to be drawn by draw_2d_wp_basis.

    Parameters
    ----------
    shape : sequence
        ``shape[0]`` is the extent used along x and ``shape[1]`` the extent
        used along y.
    keys : iterable of str
        Subband keys.  The character at position ``n`` of a key refers to a
        box of size ``shape // 2**(n + 1)``: ``'h'`` shifts the box along x,
        ``'v'`` shifts it along y and ``'d'`` shifts it along both.  Any
        other character (e.g. ``'a'``) adds no shift.  An empty key
        describes a box covering the full ``shape``.

    Returns
    -------
    coords : list of 2-tuple
        One ``(box_x, box_y)`` pair, as returned by ``_box``, per key.
    centers : dict
        Maps each key to the ``(x, y)`` center of its box, for use in
        labeling.

    Notes
    -----
    The y coordinates are negated, so the boxes extend from 0 towards
    negative y.
    """
    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key in keys:
        offset_x = offset_y = 0
        for n, char in enumerate(key):
            if char in ['h', 'd']:
                offset_x += shape[0] // 2**(n + 1)
            if char in ['v', 'd']:
                offset_y += shape[1] // 2**(n + 1)
        # Size the box from the key length rather than from the loop variable
        # ``n``: for an empty key the loop never runs, so ``n`` would be
        # unbound (or stale from the previous key).
        depth = len(key)
        sx = shape[0] // 2**depth
        sy = shape[1] // 2**depth
        xc, yc = _box((offset_x, -offset_y),
                      (offset_x + sx, -offset_y - sy))
        coords.append((xc, yc))
        centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
    return coords, centers


def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs=None, ax=None,
                     label_levels=0):
    """Plot a 2D representation of a WaveletPacket2D basis.

    Parameters
    ----------
    shape : sequence
        ``shape[0]`` and ``shape[1]`` give the extent of the drawing along
        x and y.
    keys : iterable of str
        Subband keys; one box is drawn per key (see
        ``_2d_wp_basis_coords``).
    fmt : str, optional
        Format string passed to ``ax.plot`` for the box outlines.
    plot_kwargs : dict, optional
        Additional keyword arguments passed to ``ax.plot`` for the box
        outlines.  ``None`` (the default) means no extra arguments.
    ax : matplotlib axes, optional
        Axes to draw into.  If None, a new figure and axes are created.
    label_levels : int, optional
        If greater than zero, every box whose key has at most
        ``label_levels`` characters is labeled with its key.

    Returns
    -------
    fig, ax
        The figure and axes that were drawn into.
    """
    coords, centers = _2d_wp_basis_coords(shape, keys)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    # Previously ``plot_kwargs`` was accepted but silently ignored.
    line_kwargs = {} if plot_kwargs is None else plot_kwargs
    for box_x, box_y in coords:
        ax.plot(box_x, box_y, fmt, **line_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, c in centers.items():
            if len(key) <= label_levels:
                ax.text(c[0], c[1], key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax


def _2d_fswavedecn_coords(shape, levels):
    """Coordinates of the lines to be drawn by draw_2d_fswavedecn_basis.

    One box is produced for every ordered pair ``(key0, key1)`` of keys
    returned by ``wavedec_keys(levels)``; ``key0`` positions the box along x
    (using ``shape[0]``) and ``key1`` along y (using ``shape[1]``).  The y
    coordinates are negated.

    Returns
    -------
    coords : list of 2-tuple
        One ``(box_x, box_y)`` pair, as returned by ``_box``, per key pair.
    centers : dict
        Maps each ``(key0, key1)`` to the ``(x, y)`` center of its box.
    """
    def _axis_extent(length, key):
        # Every 'd' at position ``depth`` shifts the box by the size of a
        # band at that depth; the width is halved once per key character.
        offset = sum(length // 2**(depth + 1)
                     for depth, char in enumerate(key) if char == 'd')
        return offset, length // 2**len(key)

    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key0, key1 in product(wavedec_keys(levels), repeat=2):
        off_x, width_x = _axis_extent(shape[0], key0)
        off_y, width_y = _axis_extent(shape[1], key1)
        coords.append(_box((off_x, -off_y),
                           (off_x + width_x, -off_y - width_y)))
        centers[(key0, key1)] = (off_x + width_x / 2,
                                 -off_y - width_y / 2)
    return coords, centers


def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs=None, ax=None,
                             label_levels=0):
    """Plot a 2D representation of a fswavedecn basis.

    Parameters
    ----------
    shape : sequence
        ``shape[0]`` and ``shape[1]`` give the extent of the drawing along
        x and y.
    levels : int
        Number of decomposition levels; passed to ``wavedec_keys`` to build
        the per-axis subband keys (see ``_2d_fswavedecn_coords``).
    fmt : str, optional
        Format string passed to ``ax.plot`` for the box outlines.
    plot_kwargs : dict, optional
        Additional keyword arguments passed to ``ax.plot`` for the box
        outlines.  ``None`` (the default) means no extra arguments.
    ax : matplotlib axes, optional
        Axes to draw into.  If None, a new figure and axes are created.
    label_levels : int, optional
        If greater than zero, every box whose longer per-axis key has at
        most ``label_levels`` characters is labeled with its
        ``(key0, key1)`` pair.

    Returns
    -------
    fig, ax
        The figure and axes that were drawn into.
    """
    coords, centers = _2d_fswavedecn_coords(shape, levels)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    # Previously ``plot_kwargs`` was accepted but silently ignored.
    line_kwargs = {} if plot_kwargs is None else plot_kwargs
    for box_x, box_y in coords:
        ax.plot(box_x, box_y, fmt, **line_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, c in centers.items():
            # ``key`` is a (key0, key1) pair; the deeper axis decides.
            lev = max(len(k) for k in key)
            if lev <= label_levels:
                ax.text(c[0], c[1], key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax


def boundary_mode_subplot(x, mode, ax, symw=True):
    """Plot an illustration of the boundary mode in a subplot axis.

    The signal is padded on both sides by twice its length using ``pad``.
    The padded signal is drawn as black dots, the original samples are
    overdrawn in red, and vertical bars are added.

    Parameters
    ----------
    x : 1D sequence or array
        The signal to extend.
    mode : str
        Signal extension mode; passed to ``pad`` and used as the axis title.
    ax : matplotlib axes
        The axes to draw into.
    symw : bool, optional
        Controls the placement of the vertical bars.  If True, bars are
        placed on the first original sample and repeat every ``len(x) - 1``
        samples.  If False, bars are placed half a sample before the first
        original sample and repeat every ``len(x)`` samples.  For the modes
        ``'smooth'``, ``'constant'`` and ``'zero'`` only two bars are drawn;
        otherwise six.
    """
    # Number of genuine input samples, remembered before any extension below.
    n_orig = len(x)

    # if odd-length, periodization replicates the last sample to make it even
    if mode == 'periodization' and n_orig % 2 == 1:
        x = np.concatenate((x, (x[-1], )))

    n = len(x)
    npad = 2 * n
    t = np.arange(n + 2 * npad)
    xp = pad(x, (npad, npad), mode=mode)

    ax.plot(t, xp, 'k.')
    ax.set_title(mode)

    # plot the original signal in red.  Only a sample replicated above is
    # left out; an even-length input in periodization mode keeps all of its
    # samples (previously its last sample was dropped as well).
    ax.plot(t[npad:npad + n_orig], x[:n_orig], 'r.')

    # add vertical bars indicating points of symmetry or boundary extension
    if symw:
        left = npad
        step = n - 1
    else:
        left = npad - 0.5
        step = n
    if mode in ['smooth', 'constant', 'zero']:
        reps = range(0, 2)
    else:
        reps = range(-2, 4)
    # The bar height does not depend on the bar position; compute it once.
    y_span = [xp.min() - .5, xp.max() + .5]
    for rep in reps:
        pos = left + rep * step
        ax.plot([pos, pos], y_span, 'k-')