# --------------------------------------------------------
# R-C3D
# Copyright (c) 2017 Boston University
# Licensed under The MIT License [see LICENSE for details]
# Written by Huijuan Xu
# --------------------------------------------------------

"""Transform a roidb into a trainable roidb by adding a bunch of metadata."""

from __future__ import print_function
import numpy as np
from tdcnn.config import cfg
from tdcnn.twin_transform import twin_transform
from utils.cython_twin import twin_overlaps
import PIL

def prepare_roidb(imdb):
    """Enrich the video database's roidb with derived training quantities.

    For every index ``i`` in ``range(len(imdb.video_index))`` the entry
    ``imdb.roidb[i]`` is updated in place with:

    * ``'video'``        -- result of ``imdb.video_path_at(i)``
    * ``'start_frame'``  -- result of ``imdb.start_frame(i)``
    * ``'max_overlaps'`` -- for each window, the maximum of its row in
      ``gt_overlaps`` (i.e. the max taken over classes)
    * ``'max_classes'``  -- for each window, the class (column) index that
      attains that maximum

    Args:
        imdb: database object exposing ``roidb``, ``video_index``,
            ``video_path_at(i)`` and ``start_frame(i)``.  Each roidb entry
            must hold a ``'gt_overlaps'`` matrix supporting ``.toarray()``.

    Raises:
        AssertionError: if a window with zero max overlap is not assigned
            class 0, or a window with positive max overlap is assigned
            class 0.
    """
    roidb = imdb.roidb
    # ``range`` (not ``xrange``) keeps this loop working on Python 2 and 3.
    for i in range(len(imdb.video_index)):
        entry = roidb[i]
        entry['video'] = imdb.video_path_at(i)
        entry['start_frame'] = imdb.start_frame(i)
        # need gt_overlaps as a dense array for argmax
        gt_overlaps = entry['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        entry['max_classes'] = max_classes
        entry['max_overlaps'] = max_overlaps
        # sanity checks
        # max overlap of 0 => class should be zero (background)
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0), \
            'window with zero overlap must be background (class 0)'
        # max overlap > 0 => class should not be zero (must be a fg class)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0), \
            'window with positive overlap must be a foreground class'

def add_twin_regression_targets(roidb):
    """Add information needed to train time window regressors.

    Every entry of ``roidb`` gains a ``'twin_targets'`` array (one row per
    window in ``item['wins']``) whose column 0 holds the class label and
    whose remaining columns hold the regression targets produced by
    ``_compute_targets``.  When ``cfg.TRAIN.TWIN_NORMALIZE_TARGETS`` is set,
    those targets are additionally centered and scaled in place, per class.

    Args:
        roidb: non-empty sequence of roidb entries already processed by
            ``prepare_roidb``; each must provide ``'wins'``,
            ``'max_overlaps'`` and ``'max_classes'``.

    Returns:
        Tuple ``(means, stds)`` holding the flattened per-class target means
        and standard deviations.  They are needed at prediction time to
        un-normalize and un-center the regressor outputs.

    Raises:
        AssertionError: if ``roidb`` is empty or was not prepared first.
    """
    assert len(roidb) > 0
    assert 'max_classes' in roidb[0], 'Did you call prepare_roidb first?'

    # The number of classes is taken from the global configuration
    num_classes = cfg.NUM_CLASSES
    for item in roidb:
        item['twin_targets'] = _compute_targets(
            item['wins'], item['max_overlaps'], item['max_classes'])

    if cfg.TRAIN.TWIN_NORMALIZE_TARGETS_PRECOMPUTED:
        # Use fixed / precomputed "means" and "stds" instead of empirical values
        means = np.tile(
            np.array(cfg.TRAIN.TWIN_NORMALIZE_MEANS), (num_classes, 1))
        stds = np.tile(
            np.array(cfg.TRAIN.TWIN_NORMALIZE_STDS), (num_classes, 1))
    else:
        # Compute values needed for means and stds
        # var(x) = E(x^2) - E(x)^2
        # Counts start at EPS so classes without samples do not divide by 0.
        class_counts = np.zeros((num_classes, 1)) + cfg.EPS
        sums = np.zeros((num_classes, 2))
        squared_sums = np.zeros((num_classes, 2))
        for item in roidb:
            targets = item['twin_targets']
            # Class 0 (background) is skipped; ``range`` instead of
            # ``xrange`` keeps this working on Python 2 and 3.
            for cls in range(1, num_classes):
                cls_inds = np.where(targets[:, 0] == cls)[0]
                if cls_inds.size > 0:
                    cls_targets = targets[cls_inds, 1:]
                    class_counts[cls] += cls_inds.size
                    sums[cls, :] += cls_targets.sum(axis=0)
                    squared_sums[cls, :] += (cls_targets ** 2).sum(axis=0)

        means = sums / class_counts
        stds = np.sqrt(squared_sums / class_counts - means ** 2)

    print('twin target means:')
    print(means)
    print(means[1:, :].mean(axis=0)) # ignore bg class
    print('twin target stdevs:')
    print(stds)
    print(stds[1:, :].mean(axis=0)) # ignore bg class

    # Normalize targets
    if cfg.TRAIN.TWIN_NORMALIZE_TARGETS:
        print("Normalizing targets")
        for item in roidb:
            targets = item['twin_targets']
            for cls in range(1, num_classes):
                cls_inds = np.where(targets[:, 0] == cls)[0]
                # NOTE(review): a class whose std is 0 would turn this
                # division into inf/nan -- confirm that cannot occur for the
                # datasets in use.
                targets[cls_inds, 1:] -= means[cls, :]
                targets[cls_inds, 1:] /= stds[cls, :]
    else:
        print("NOT normalizing targets")

    # These values will be needed for making predictions
    # (the predicts will need to be unnormalized and uncentered)
    return means.ravel(), stds.ravel()

def _compute_targets(rois, overlaps, labels):
    """Compute bounding-box regression targets for an image."""
    # Indices of ground-truth ROIs
    gt_inds = np.where(overlaps == 1)[0]
    if len(gt_inds) == 0:
        # Bail if the image has no ground-truth ROIs
        return np.zeros((rois.shape[0], 3), dtype=np.float32)
    # Indices of examples for which we try to make predictions
    ex_inds = np.where(overlaps >= cfg.TRAIN.TWIN_THRESH)[0]

    # Get IoU overlap between each ex ROI and gt ROI
    ex_gt_overlaps = twin_overlaps(
        np.ascontiguousarray(rois[ex_inds, :], dtype=np.float),
        np.ascontiguousarray(rois[gt_inds, :], dtype=np.float))

    # Find which gt ROI each ex ROI has max overlap with:
    # this will be the ex ROI's gt target
    gt_assignment = ex_gt_overlaps.argmax(axis=1)
    gt_rois = rois[gt_inds[gt_assignment], :]
    ex_rois = rois[ex_inds, :]

    targets = np.zeros((rois.shape[0], 3), dtype=np.float32)
    targets[ex_inds, 0] = labels[ex_inds]
    targets[ex_inds, 1:] = twin_transform(ex_rois, gt_rois)
    return targets
