#!/usr/bin/env python
"""
ckwg +31
Copyright 2019 by Kitware, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

 * Redistributions of source code must retain the above copyright notice,
   this list of conditions and the following disclaimer.

 * Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

 * Neither name of Kitware, Inc. nor the names of any contributors may be used
   to endorse or promote products derived from this software without specific
   prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS''
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

==============================================================================

Library handling projection operations of a standard camera model.

"""
from __future__ import division, print_function
import os
import numpy as np
import cv2
import StringIO
from PIL import Image as PILImage
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import glob
from scipy import interpolate

# ROS imports
from rosbag import Bag
from cv_bridge import CvBridge
from tf.transformations import quaternion_multiply, quaternion_matrix, \
    quaternion_from_euler, quaternion_inverse, euler_matrix, quaternion_slerp

# ADAPT imports.
from sensor_models.nav_conversions import enu_to_llh, llh_to_enu
from sensor_models.nav_state import NavStateINS, trajectory_to_kmz
from sensor_models.utilities import rotation_between_quats, horn
from colmap_processing.colmap_interface import read_images_binary, \
    read_points3d_binary, read_cameras_binary, qvec2rotmat, \
    standard_cameras_from_colmap


bridge = CvBridge()

# ------------------------------ Bag information -----------------------------
# ROS bag to process and the topics that are read from it.
bag_fname = '/home/user/adapt_ws/data/2021-07-01-16-36-40.bag'
imagery_topic = '/image_raw'
imu_topic = '/an_device/Imu'
gps_topic = '/an_device/NavSatFix'
filter_status_topic = '/an_device/FilterStatus'
ins_system_status_topic = '/an_device/SystemStatus'

# Outputs go into a directory named after the bag (its path minus extension).
out_dir = os.path.splitext(bag_fname)[0]

# ----------------------------------------------------------------------------

# You should have a Colmap directory where all of the Colmap-generated files
# reside.

colmap_dir = '/home/user/adapt_ws/data/2021-06-23-16-51-24_rs'

# Sub-directory of colmap_dir with all of the raw images.
colmap_images_subdir = 'images0'

# Sub-directory containing the images.bin and cameras.bin. Set to '' if in the
# top-level Colmap directory.
#sparse_recon_subdir = 'sparse/1'
sparse_recon_subdir = 'snapshots/1256684511'

# Sub-directory of colmap_dir that will hold the aligned reconstruction.
aligned_sparse_recon_subdir = 'aligned'
# ----------------------------------------------------------------------------


image_dir = '%s/images' % out_dir
try:
    os.makedirs(image_dir)
except (IOError, OSError):
    # Python 2 has no exist_ok: tolerate only an already-existing directory
    # and re-raise anything else (e.g., permission denied) instead of hiding it.
    if not os.path.isdir(image_dir):
        raise


# ------------------------------- Read the bag -------------------------------
# msgs maps topic -> {header time (s): message} for every non-image message
# that has a header. The INS status topics are instead keyed by the time the
# message was recorded into the bag.
msgs = {}
images = {}
system_status = {}
filter_status = {}
topics = set()
with Bag(bag_fname, 'r') as bag_in:
    for topic, msg, t in bag_in:
        topics.add(topic)

        if topic == ins_system_status_topic:
            system_status[t.to_sec()] = msg
            continue
        if topic == filter_status_topic:
            filter_status[t.to_sec()] = msg
            continue

        # Everything else is keyed by its header stamp; header-less messages
        # are skipped.
        try:
            t = msg.header.stamp.to_sec()
        except AttributeError:
            continue

        if topic != imagery_topic:
            msgs.setdefault(topic, {})[t] = msg
            continue

        # Messages carrying a 'format' attribute are decoded with PIL
        # (presumably sensor_msgs/CompressedImage); others via cv_bridge.
        if hasattr(msg, 'format'):
            image = np.array(PILImage.open(StringIO.StringIO(msg.data)))
        else:
            image = bridge.imgmsg_to_cv2(msg, "rgb8")

        frame_id = msg.header.frame_id

        # Image caching/saving is currently disabled; uncomment to keep the
        # frames in memory and/or write them as <microseconds>.png.
        #images[t] = image
        fname = '%s/%i.png' % (image_dir, int(np.round(t*1000000)))
        #cv2.imwrite(fname, image[:, :, ::-1])


# --------------------- Read Existing Colmap Reconstruction ------------------
# Load every registered image from the Colmap model and turn it into one row
# [t, x, y, z, qx, qy, qz, qw] of sfm_poses.
images_bin_fname = '%s/%s/images.bin' % (colmap_dir, sparse_recon_subdir)
colmap_images = read_images_binary(images_bin_fname)

sfm_poses = []
for image_num, image in colmap_images.items():
    # Image names encode the capture time in microseconds.
    t = float(os.path.splitext(image.name)[0])/1000000

    # Camera center in world coordinates from Colmap's world-to-camera pose.
    R = qvec2rotmat(image.qvec)
    pos = -np.dot(R.T, image.tvec)

    # The qvec used by Colmap is a (w, x, y, z) quaternion representing the
    # rotation of a vector defined in the world coordinate system into the
    # camera coordinate system. The 'camera_models' module instead assumes
    # (x, y, z, w) quaternions representing a coordinate system rotation, so
    # an inverse is needed. Negating w gives the negated conjugate, which is
    # the same rotation as the inverse of a unit quaternion.
    qw, qx, qy, qz = image.qvec / np.linalg.norm(image.qvec)
    quat = [qx, qy, qz, -qw]

    sfm_pose = np.hstack([t, pos, quat])
    sfm_poses.append(sfm_pose)

sfm_poses = np.array(sfm_poses)

# Camera centers in the order Colmap listed the images.
fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(sfm_poses[:, 1], sfm_poses[:, 2], sfm_poses[:, 3])

# Put the poses in chronological order; the consecutive-pair comparisons
# further down depend on it.
sfm_poses = sfm_poses[np.argsort(sfm_poses[:, 0])]

fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(sfm_poses[:, 1], sfm_poses[:, 2], sfm_poses[:, 3], 'ro')

# Easting-like coordinate (column 1) versus time.
fig = plt.figure()
plt.plot(sfm_poses[:, 0], sfm_poses[:, 1], 'k-')


# Visualize coordinate system for every 20th pose.
plt.close('all')
pose_stride = 20
rotation_matrices = [quaternion_matrix(quaternion_inverse(q))[:3, :3]
                     for q in sfm_poses[::pose_stride, 4:]]
# NOTE(review): draw_moving_coordinate_system is neither defined nor imported
# in this file, so this line raises NameError unless the name is provided
# elsewhere (e.g., an interactive session). TODO confirm where it lives.
draw_moving_coordinate_system(sfm_poses[::pose_stride, 1:4],
                              rotation_matrices, arrow_scale=5)


# NOTE(review): not used anywhere else in this file.
rodrigues_vec = []

# ----------------------------------------------------------------------------


# ------------------------------ Display Position ----------------------------
# Collect the GPS fixes in time order and convert them to a local ENU frame.
if not msgs.get(gps_topic):
    raise RuntimeError('No messages on topic %s in %s' % (gps_topic,
                                                          bag_fname))

pos_times = sorted(msgs[gps_topic].keys())
gps_msgs = [msgs[gps_topic][t] for t in pos_times]

# NOTE(review): not used anywhere else in this file.
pos_err_rad = []

pos_times = np.array(pos_times)
lats = np.array([m.latitude for m in gps_msgs])
lons = np.array([m.longitude for m in gps_msgs])
alts = np.array([m.altitude for m in gps_msgs])
pos_covs = np.array([np.array(m.position_covariance).reshape(3, 3)
                     for m in gps_msgs])

# Set to True to write the navigation trajectory to a kmz file to view in
# Google Earth.
if False:
    filename = bag_fname.replace('.bag', '.kmz')
    trajectory_to_kmz(filename, pos_times, lats, lons, alts,
                      pos_covs=pos_covs, min_dt=0.25, name='flight_path')

# Origin of the local ENU frame: median horizontal position, lowest altitude.
lat0 = np.median(lats)
lon0 = np.median(lons)
alt0 = np.min(alts)

# We take the INS-reported position (converted from latitude, longitude, and
# altitude into easting/northing/up coordinates) and later interpolate it to
# each image's timestamp.
print('Latitude of ENU coordinate system:', lat0, 'degrees')
print('Longitude of ENU coordinate system:', lon0, 'degrees')
print('Height above the WGS84 ellipsoid of the ENU coordinate system:', alt0,
      'meters')

# enu has shape (3, N): rows are east, north, up.
enu = np.array([llh_to_enu(lat, lon, alt, lat0, lon0, alt0)
                for lat, lon, alt in zip(lats, lons, alts)]).T

fig = plt.figure()
ax = plt.axes(projection='3d')
plt.plot(enu[0], enu[1], enu[2])

fig = plt.figure()
plt.plot(pos_times - pos_times.min(), enu[1])

# Linear interpolation of ENU position versus time along enu's last axis.
# With scipy's defaults it raises ValueError outside [min, max] of pos_times.
enu_v_time = interpolate.interp1d(pos_times, enu)
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
# Collect the IMU messages in time order.
if not msgs.get(imu_topic):
    raise RuntimeError('No messages on topic %s in %s' % (imu_topic,
                                                          bag_fname))

imu_times = sorted(msgs[imu_topic].keys())
imu_msgs = [msgs[imu_topic][t] for t in imu_times]

imu_times = np.array(imu_times)
linear_accel = np.array([[m.linear_acceleration.x,
                          m.linear_acceleration.y,
                          m.linear_acceleration.z] for m in imu_msgs])
angular_vel = np.array([[m.angular_velocity.x,
                         m.angular_velocity.y,
                         m.angular_velocity.z] for m in imu_msgs])
# Orientation quaternions stored as (x, y, z, w).
orientation = np.array([[m.orientation.x, m.orientation.y,
                         m.orientation.z, m.orientation.w]
                        for m in imu_msgs])
orientation_cov = np.array([np.array(m.orientation_covariance).reshape(3, 3)
                            for m in imu_msgs])

# Linear acceleration components versus sample index.
plt.figure()
plt.subplot(311)
plt.plot(linear_accel[:, 0])
plt.ylabel('X-Acceleration (m/s^2)')
plt.subplot(312)
plt.plot(linear_accel[:, 1])
plt.ylabel('Y-Acceleration (m/s^2)')
plt.subplot(313)
plt.plot(linear_accel[:, 2])
plt.ylabel('Z-Acceleration (m/s^2)')


# The GPS and IMU streams are paired sample-by-sample below (both for the
# visualization and in nav_states), so their timestamps must match exactly.
# An explicit check survives `python -O`, and np.array_equal is also correct
# when the two arrays have different lengths.
if not np.array_equal(pos_times, imu_times):
    raise ValueError('Timestamps on %s and %s do not match'
                     % (gps_topic, imu_topic))

# Visualize coordinate system.
plt.close('all')
downsampling = 200
quats = orientation[::downsampling]
rotation_matrices = [quaternion_matrix(quaternion_inverse(quat))[:3, :3]
                     for quat in quats]
# NOTE(review): draw_moving_coordinate_system is neither defined nor imported
# in this file, so this line raises NameError unless the name is provided
# elsewhere (e.g., an interactive session). TODO confirm where it lives.
draw_moving_coordinate_system(enu.T[::downsampling], rotation_matrices,
                              arrow_scale=5)

# One row per sample: [t, latitude, longitude, altitude, qx, qy, qz, qw].
nav_states = np.column_stack([pos_times, lats, lons, alts, orientation])

nav_state_provider = NavStateINS(nav_states)
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
# Compare angular rates.
#
# For every pair of consecutive (time-sorted) SfM poses, compare the relative
# rotation recovered by SfM with the relative rotation reported by the INS
# between the same two times.


def _relative_rotation_vectors(poses, nav_provider):
    """Return per-step relative rotations from SfM and from the INS.

    :param poses: Time-sorted array whose rows are
        [t, x, y, z, qx, qy, qz, qw], as built from the Colmap images.
    :param nav_provider: Object exposing ``quat(t)`` that returns the INS
        orientation quaternion at time ``t`` (e.g., ``NavStateINS``).
    :return: ``(sfm_rot_vecs, ins_rot_vecs)``, two arrays with one entry per
        consecutive pair of poses (``len(poses) - 1`` entries each).
    """
    sfm_vecs = []
    ins_vecs = []
    for i in range(len(poses) - 1):
        t1, t2 = poses[i:i + 2, 0]
        sfm_quat1, sfm_quat2 = poses[i:i + 2, 4:]

        # NOTE(review): np.prod over rotation_between_quats' result presumably
        # multiplies an (axis, angle) pair into a rotation vector; confirm
        # against sensor_models.utilities. np.prod replaces the np.product
        # alias (deprecated in NumPy 1.25, removed in 2.0) with identical
        # results.
        sfm_vecs.append(np.prod(rotation_between_quats(sfm_quat1, sfm_quat2)))
        ins_vecs.append(np.prod(rotation_between_quats(nav_provider.quat(t1),
                                                       nav_provider.quat(t2))))

    return np.array(sfm_vecs), np.array(ins_vecs)


sfm_rot_vecs, ins_rot_vecs = _relative_rotation_vectors(sfm_poses,
                                                        nav_state_provider)

# Magnitude of each relative rotation, in degrees.
plt.close('all')
plt.plot(np.degrees(np.linalg.norm(sfm_rot_vecs, axis=1)), '.',
         label='SFM Rotation Degrees')
plt.plot(np.degrees(np.linalg.norm(ins_rot_vecs, axis=1)), '.',
         label='INS Rotation Degrees')
plt.legend(fontsize=16)

# For each window of `width` consecutive steps, fit a rotation R with horn()
# (no translation), applied below as R . INS ~ SfM, and record the mean
# squared residual. range(... + 1) includes the last full window.
width = 50
err = []
for i in range(len(ins_rot_vecs) - width + 1):
    ins_window = ins_rot_vecs[i:i + width].T
    sfm_window = sfm_rot_vecs[i:i + width].T
    R = horn(ins_window, sfm_window, fit_translation=False)
    err.append(np.mean((np.dot(R, ins_window) - sfm_window)**2))

plt.close('all')
plt.plot(err)

# Single fit over all steps from index fit_start_index onward.
fit_start_index = 800
R = horn(ins_rot_vecs[fit_start_index:].T, sfm_rot_vecs[fit_start_index:].T,
         fit_translation=False)

# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
# Raw images (png and jpg) in the Colmap image directory, sorted by path.
raw_image_dir = '%s/%s' % (colmap_dir, colmap_images_subdir)
img_fnames = sorted(glob.glob('%s/*.png' % raw_image_dir) +
                    glob.glob('%s/*.jpg' % raw_image_dir))

# Only images whose timestamps (seconds) fall in [min_time, max_time] are used.
min_time = 1624481549022702/1000000
max_time = 1624481588594812/1000000


# Write one "<image name> <east> <north> <up>" line per usable image, pairing
# each image with the INS position interpolated to its timestamp. Colmap then
# uses this pairing to solve for a similarity transform that best matches the
# SfM poses it recovered to these positions. All Colmap coordinates in this
# aligned version of its reconstruction will then be in easting/northing/up
# meters coordinates.
align_fname = '%s/image_locations.txt' % colmap_dir
with open(align_fname, 'w') as fo:
    for img_fname in img_fnames:
        name = os.path.basename(img_fname)
        t = float(os.path.splitext(name)[0])/1000000

        if not min_time <= t <= max_time:
            continue

        try:
            pos = enu_v_time(t)
        except ValueError:
            # Outside the time span covered by the GPS fixes; interp1d does
            # not extrapolate, so the image is skipped.
            continue

        fo.write('%s %0.8f %0.8f %0.8f\n' % (name, pos[0], pos[1], pos[2]))

aligned_dir = '%s/%s' % (colmap_dir, aligned_sparse_recon_subdir)
try:
    os.makedirs(aligned_dir)
except (OSError, IOError):
    # Tolerate only an already-existing directory; re-raise anything else.
    if not os.path.isdir(aligned_dir):
        raise

# '/host_filesystem' is removed from colmap_dir in the suggested command
# (presumably a container mount prefix; confirm).
print('Now run\nnoaa_kamera/src/kitware-ros-pkg/postflight_scripts/scripts/'
      'colmap/model_aligner.sh %s %s %s %s' % (colmap_dir.replace('/host_filesystem', ''),
                                               sparse_recon_subdir,
                                               'image_locations.txt',
                                               aligned_sparse_recon_subdir))
# ----------------------------------------------------------------------------