"""
Class for monitoring performance on validation data.

TODO:
    - [ ] Implement algorithm from dlib
    http://blog.dlib.net/2018/02/automatic-learning-rate-scheduling-that.html
"""
from netharn import util
import itertools as it
import numpy as np
import ubelt as ub

__all__ = ['Monitor']


def demodata_monitor():
    """
    Builds a Monitor that has been fed 300 epochs of synthetic measurements.

    The simulated 'loss' trends downwards and the simulated 'miou' trends
    upwards; both are noisy integer ramps drawn from a fixed-seed random
    number generator, so the result is deterministic.

    Returns:
        Monitor: a monitor that minimizes 'loss', maximizes 'miou', and has
            been updated once per simulated epoch.
    """
    rng = np.random.RandomState(0)
    n = 300

    def _noisy_ramp():
        # The sorted draw must happen before the jitter draw so the random
        # stream is consumed in a fixed order.
        ramp = np.sort(rng.randint(10, n, size=n))
        jitter = rng.randint(0, 20, size=n)
        return ramp + jitter - 10

    # The loss ramp is drawn first and then reversed so that it decreases.
    losses = _noisy_ramp()[::-1]
    mious = _noisy_ramp()

    monitor = Monitor(minimize=['loss'], maximize=['miou'], smoothing=.6)
    for epoch, measured in enumerate(zip(losses, mious)):
        loss, miou = measured
        monitor.update(epoch, {'loss': loss, 'miou': miou})
    return monitor


class Monitor(object):
    """
    Monitors an instance of FitHarn as it trains. Makes sure that measurements
    of quality (e.g. loss, accuracy, AUC, mAP, etc...) on the validation
    dataset continues to go do (or at least isn't increasing), and stops
    training early if certain conditions are met.

    Attributes:
        minimize (List[str]): measures where a lower is better
        maximize (List[str]): measures where a higher is better
        smoothing (float): smoothness factor for the moving averages
        max_epoch (int, default=1000): number of epochs to stop after
        patience (int, default=None): if specified, the number of epochs
            to wait before quiting if the quality metrics are not improving.

    Example:
        >>> # simulate loss going down and then overfitting
        >>> from netharn.monitor import *
        >>> rng = np.random.RandomState(0)
        >>> n = 300
        >>> losses = (sorted(rng.randint(10, n, size=n)) + rng.randint(0, 20, size=n) - 10)[::-1]
        >>> mious = (sorted(rng.randint(10, n, size=n)) + rng.randint(0, 20, size=n) - 10)
        >>> monitor = Monitor(minimize=['loss'], maximize=['miou'], smoothing=.6)
        >>> for epoch, (loss, miou) in enumerate(zip(losses, mious)):
        >>>     monitor.update(epoch, {'loss': loss, 'miou': miou})
        >>> # xdoctest: +REQUIRES(--show)
        >>> monitor.show()
    """

    def __init__(monitor, minimize=['loss'], maximize=[], smoothing=.6,
                 patience=None, max_epoch=1000):

        # Internal attributes
        monitor._ewma = util.ExpMovingAve(alpha=1 - smoothing)
        monitor._raw_metrics = []
        monitor._smooth_metrics = []
        monitor._epochs = []
        monitor._is_good = []

        # Bookkeeping
        monitor._best_raw_metrics = None
        monitor._best_smooth_metrics = None
        monitor._best_epoch = None
        monitor._n_bad_epochs = 0

        # Keep track of which metrics we want to maximize / minimize
        monitor.minimize = minimize
        monitor.maximize = maximize

        # early stopping
        monitor.patience = patience
        monitor.max_epoch = max_epoch

    @classmethod
    def coerce(cls, config, **kw):
        """
        Accepts keywords 'max_epoch' and 'patience'
        """
        from netharn.api import _update_defaults
        config = _update_defaults(config, kw)
        max_epoch = config.get('max_epoch', 100)
        return (cls, {
            'minimize': ['loss'],
            'patience': config.get('patience', max_epoch),
            'max_epoch': max_epoch,
        })

    def show(monitor):
        """
        Draws the monitored metrics using matplotlib
        """
        import matplotlib.pyplot as plt
        import kwplot
        import pandas as pd
        smooth_ydatas = pd.DataFrame.from_dict(monitor._smooth_metrics).to_dict('list')
        raw_ydatas = pd.DataFrame.from_dict(monitor._raw_metrics).to_dict('list')
        keys = monitor.minimize + monitor.maximize
        pnum_ = kwplot.PlotNums(nSubplots=len(keys))
        for i, key in enumerate(keys):
            kwplot.multi_plot(
                monitor._epochs, {'raw ' + key: raw_ydatas[key],
                                  'smooth ' + key: smooth_ydatas[key]},
                xlabel='epoch', ylabel=key, pnum=pnum_[i], fnum=1,
                # markers={'raw ' + key: '-', 'smooth ' + key: '--'},
                # colors={'raw ' + key: 'b', 'smooth ' + key: 'b'},
            )

            # star all the good epochs
            flags = np.array(monitor._is_good)
            if np.any(flags):
                plt.plot(list(ub.compress(monitor._epochs, flags)),
                         list(ub.compress(smooth_ydatas[key], flags)), 'b*')

    def __getstate__(monitor):
        state = monitor.__dict__.copy()
        _ewma = state.pop('_ewma')
        state['ewma_state'] = _ewma.__dict__
        return state

    def __setstate__(monitor, state):
        ewma_state = state.pop('ewma_state', None)
        if ewma_state is not None:
            monitor._ewma = util.ExpMovingAve()
            monitor._ewma.__dict__.update(ewma_state)
        monitor.__dict__.update(**state)

    def state_dict(monitor):
        """
        pytorch-like API. Alias for __getstate__
        """
        return monitor.__getstate__()

    def load_state_dict(monitor, state):
        """
        pytorch-like API. Alias for __setstate__

        Args:
            state (Dict):
        """
        return monitor.__setstate__(state)

    def update(monitor, epoch, _raw_metrics):
        """
        Informs the monitor about quality measurements for a particular epoch.

        Args:
            epoch (int):
                Current epoch number

            _raw_metrics (Dict[str, float]):
                Scalar values for each quality metric that was measured on this
                epoch.

        Returns:
            bool: improved:
                True if the model has quality of the validation metrics have
                improved.

        """
        monitor._epochs.append(epoch)
        monitor._raw_metrics.append(_raw_metrics)
        monitor._ewma.update(_raw_metrics)
        # monitor.other_data.append(other)

        _smooth_metrics = monitor._ewma.average()
        monitor._smooth_metrics.append(_smooth_metrics.copy())

        improved_keys = monitor._improved(_smooth_metrics, monitor._best_smooth_metrics)
        if improved_keys:
            if monitor._best_smooth_metrics is None:
                monitor._best_smooth_metrics = _smooth_metrics.copy()
                monitor._best_raw_metrics = _raw_metrics.copy()
            else:
                for key in improved_keys:
                    monitor._best_smooth_metrics[key] = _smooth_metrics[key]
                    monitor._best_raw_metrics[key] = _raw_metrics[key]
            monitor._best_epoch = epoch
            monitor._n_bad_epochs = 0
        else:
            monitor._n_bad_epochs += 1

        improved = len(improved_keys) > 0
        monitor._is_good.append(improved)
        return improved

    def _improved(monitor, metrics, best_metrics):
        """
        If any of the metrics we care about is improving then we are happy

        Returns:
            List[str]: list of the quality metrics that have improved

        Example:
            >>> from netharn.monitor import *
            >>> monitor = Monitor(['loss'], ['acc'])
            >>> metrics = {'loss': 5, 'acc': .99}
            >>> best_metrics = {'loss': 4, 'acc': .98}
            >>> monitor._improved(metrics, best_metrics)
            ['acc']
        """
        keys = monitor.maximize + monitor.minimize

        def _as_minimization(metrics):
            # convert to a minimization problem
            sign = np.array(([-1] * len(monitor.maximize)) +
                            ([1] * len(monitor.minimize)))
            chosen = np.array(list(ub.take(metrics, keys)))
            return chosen, sign

        current, sign1 = _as_minimization(metrics)

        if not best_metrics:
            return keys

        best, sign2 = _as_minimization(best_metrics)

        # TODO: also need to see if anything got significantly worse

        # only use threshold rel mode
        monitor.rel_threshold = 1e-6
        rel_epsilon = 1.0 - monitor.rel_threshold
        improved_flags = (sign1 * current) < (rel_epsilon * sign2 * best)
        # * rel_epsilon

        improved_keys = list(ub.compress(keys, improved_flags))
        return improved_keys

    def is_done(monitor):
        """
        Returns True if the termination criterion is satisfied

        Returns:
            bool: if training should be stopped

        Example:
            >>> from netharn.monitor import *
            >>> Monitor().is_done()
            False
            >>> Monitor(patience=0).is_done()
            True
        """
        if monitor.patience is None:
            return False
        return monitor._n_bad_epochs >= monitor.patience

    def message(monitor, ansi=True):
        """
        A status message with optional ANSI coloration

        Args:
            ansi (bool, default=True): if False disables ANSI coloration

        Returns:
            str: message for logging

        Example:
            >>> from netharn.monitor import *
            >>> monitor = Monitor()
            >>> print(monitor.message(ansi=False))
            vloss is unevaluated
            >>> monitor.update(0, {'loss': 1.0})
            >>> print(monitor.message(ansi=False))
            vloss: 1.0000 (n_bad=00, best=1.0000)
            >>> monitor.update(0, {'loss': 2.0})
            >>> print(monitor.message(ansi=False))
            vloss: 1.4000 (n_bad=01, best=1.0000)
            >>> monitor.update(0, {'loss': 0.1})
            >>> print(monitor.message(ansi=False))
            vloss: 0.8800 (n_bad=00, best=0.8800)
        """
        if not monitor._epochs:
            message = 'vloss is unevaluated'
            if ansi:
                message = ub.color_text(message, 'blue')
        else:
            prev_loss = monitor._smooth_metrics[-1]['loss']
            best_loss = monitor._best_smooth_metrics['loss']

            message = 'vloss: {:.4f} (n_bad={:02d}, best={:.4f})'.format(
                prev_loss, monitor._n_bad_epochs, best_loss,
            )
            if monitor.patience is None:
                patience = monitor.max_epoch
            else:
                patience = monitor.patience
            if ansi:
                if monitor._n_bad_epochs <= int(patience * .25):
                    message = ub.color_text(message, 'green')
                elif monitor._n_bad_epochs >= int(patience * .75):
                    message = ub.color_text(message, 'red')
                else:
                    message = ub.color_text(message, 'yellow')
        return message

    def best_epochs(monitor, num=None, smooth=True):
        """
        Returns the best `num` epochs for every metric.

        Args:
            num (int, default=None):
                Number of top epochs to return. If not specified then all are
                returned.

            smooth (bool, default=True):
                Uses smoothed metrics if True otherwise uses the raw metrics.

        Returns:
            Dict[str, ndarray]: epoch numbers for all of the best epochs

        Example:
            >>> monitor = demodata_monitor()
            >>> metric_ranks = monitor.best_epochs(5)
            >>> print(ub.repr2(metric_ranks, with_dtype=False, nl=1))
            {
                'loss': np.array([297, 299, 298, 296, 295]),
                'miou': np.array([299, 298, 297, 296, 295]),
            }
        """
        metric_ranks = {}
        for key in it.chain(monitor.minimize, monitor.maximize):
            metric_ranks[key] = monitor._rank(key, smooth=smooth)[:num]
        return metric_ranks

    def _rank(monitor, key, smooth=True):
        """
        Ranks the best epochs from best to worst for each metric

        Example:
            >>> monitor = demodata_monitor()
            >>> ranked_epochs = monitor._rank('loss', smooth=False)
            >>> ranked_epochs = monitor._rank('miou', smooth=True)
        """
        if smooth:
            metrics = monitor._smooth_metrics
        else:
            metrics = monitor._raw_metrics

        values = [m[key] for m in metrics]
        sortx = np.argsort(values)
        if key in monitor.maximize:
            sortx = np.argsort(values)[::-1]
        elif key in monitor.minimize:
            sortx = np.argsort(values)
        else:
            raise KeyError(type)
        ranked_epochs = np.array(monitor._epochs)[sortx]
        return ranked_epochs

    def _BROKEN_rank_epochs(monitor):
        """
        FIXME:
            broken - implement better rank aggregation with custom weights

        Example:
            >>> monitor = demodata_monitor()
            >>> monitor._BROKEN_rank_epochs()
        """
        rankings = {}
        for key, value in monitor.best_epochs(smooth=False).items():
            rankings[key + '_raw'] = value

        for key, value in monitor.best_epochs(smooth=True).items():
            rankings[key + '_smooth'] = value

        # borda-like weighted rank aggregation.
        # probably could do something better.
        epoch_to_weight = ub.ddict(lambda: 0)
        for key, ranking in rankings.items():
            # weights = np.linspace(0, 1, num=len(ranking))[::-1]
            weights = np.logspace(0, 2, num=len(ranking))[::-1] / 100
            for epoch, w in zip(ranking, weights):
                epoch_to_weight[epoch] += w

        agg_ranking = ub.argsort(epoch_to_weight)[::-1]
        return agg_ranking

if __name__ == '__main__':
    # Runs the doctests of this module when it is executed as a script.
    #
    # CommandLine:
    #     xdoctest -m netharn.monitor all
    import xdoctest
    xdoctest.doctest_module(__file__)
