from __future__ import division, print_function
import os
import numpy as np
import cv2
import time
from PIL import Image as PILImage
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import glob
from scipy import interpolate
from scipy.optimize import minimize_scalar
import copy
import gtsam
import pickle


# ADAPT imports.
from colmap_processing.calibration import horn
import colmap_processing.camera_models as camera_models
from colmap_processing.colmap_interface import read_images_binary, \
    read_points3D_binary, read_cameras_binary, qvec2rotmat, \
    standard_cameras_from_colmap
from colmap_processing.colmap_interface import Image as ColmapImage
from colmap_processing.platform_pose import PlatformPoseInterp
from colmap_processing.rotations import quaternion_slerp, quaternion_matrix, \
    quaternion_inverse, quaternion_multiply, euler_from_quaternion
import colmap_processing.slam
from colmap_processing.slam import reprojection_error, rescale_sfm_to_ins, \
    map_to_pinhole_problem, read_colmap_results, show_reproj_error_on_images, \
    show_solution_errors

%matplotlib auto


# Dataset selection. Flip this flag to process the 2022-04-21 flight instead
# of the 2022_jan_flight data.
use_2022_04_21_flight = False

if use_2022_04_21_flight:
    # Other reconstructions of this flight that were tried:
    #   '/mnt/data2tb/adapt/2022-04-21/snapshots/0444925779'    # smaller
    #   '/mnt/data2tb/adapt/2022-04-21/snapshots/0441352111'    # much smaller
    recon_dir = '/mnt/data2tb/adapt/2022-04-21/snapshots/2028832250'
    image_dir = '/mnt/data2tb/adapt/2022-04-21/images0'
    odometry_txt = '/mnt/homenas2/noaa_adapt/adapt_flights/2022-04-21/odometry.txt'
    imu_txt = '/mnt/homenas2/noaa_adapt/adapt_flights/2022-04-21/imu.txt'
else:
    recon_dir = '/mnt/data2tb/libraries/adapt/adapt_ros_ws/data/2022_jan_flight/sparse/0'    # much smaller
    image_dir = '/mnt/data2tb/libraries/adapt/adapt_ros_ws/data/2022_jan_flight/images0'
    odometry_txt = '/mnt/homenas2/noaa_adapt/adapt_flights/2022_jan_flight/odometry.txt'
    # NOTE(review): this path is under the 2022-04-21 flight while the other
    # paths in this branch are under 2022_jan_flight - confirm it is intended.
    imu_txt = '/mnt/homenas2/noaa_adapt/adapt_flights/2022-04-21/imu.txt'

camera_model_fname = '/mnt/data2tb/libraries/adapt/adapt_ros_ws/data/2022_jan_flight/camera_model.yaml'

# Limits passed to read_colmap_results. None means no limit is requested for
# that quantity (a cap of 2000 images was used previously).
max_images = None
max_image_pts = 1000

# The odometry txt file encodes the INS outputs recorded during the flight.
# Wrapping it in a 'PlatformPoseInterp' lets us query the INS state at any
# time from the flight; the camera model is loaded with it as pose provider.
ins = PlatformPoseInterp.from_odometry_llh_txt(odometry_txt)
cm_ins = camera_models.load_from_file(camera_model_fname, ins)
#cm_ins.cam_quat = np.array([0.0661142045242736, 0.05316447060215157, 0.5972116203253665, 0.7975843100099612])

# IMU measurements are not used in this run.
#imu_data = np.loadtxt(imu_txt)
imu_data = None

# Load the structure-from-motion reconstruction generated by Colmap, which
# provides the starting point for the image-based measurements. The remap to
# a pinhole camera is currently disabled (see the commented call below).
cm_sfm, im_pts_at_time, wrld_pts0, image_names = read_colmap_results(
    recon_dir,
    use_camera_id=1,
    max_images=max_images,
    max_image_pts=max_image_pts,
    min_track_len=4)
#cm_sfm, im_pts_at_time = map_to_pinhole_problem(cm_sfm, im_pts_at_time)
cm_sfm, wrld_pts = rescale_sfm_to_ins(cm_sfm, ins, wrld_pts0)

# Reprojection error of the rescaled SfM solution, before any optimization.
sfm_reproj_err = reprojection_error(cm_sfm, im_pts_at_time, wrld_pts)
print('Reprojection error', sfm_reproj_err)

# We use the INS-reported state as our starting point.
#cm._platform_pose_provider = ins


if False:
    # Disabled experiment: relative pose (and its covariance) between the
    # first and the eleventh image, computed from the points they share.
    # NOTE(review): 'stereo_pair_pose_constr' is not imported anywhere in the
    # visible file; it must be brought into scope before enabling this.
    cm_sfm, im_pts_at_time2 = map_to_pinhole_problem(cm_sfm, im_pts_at_time)
    image_times = sorted(list(im_pts_at_time2.keys()))

    # Each entry unpacks into two parts; the second is used as the lookup
    # key for the columns of the first.
    pts_first, ids_first = im_pts_at_time2[image_times[0]]
    one = dict(zip(ids_first, pts_first.T))
    pts_second, ids_second = im_pts_at_time2[image_times[10]]
    two = dict(zip(ids_second, pts_second.T))

    # Keep only the keys present in both images, in one common order.
    keys = set(one.keys()).intersection(set(two.keys()))
    im_pts1 = np.array([one[key] for key in keys]).T
    im_pts2 = np.array([two[key] for key in keys]).T

    K = cm_sfm.K
    pose, pose_cov = stereo_pair_pose_constr(
        K, im_pts1, im_pts2, pixel_sigma=3, max_sep=100, max_viz_dist=1e4)
    print(pose_cov[:3, :3])



# Noise settings handed to OfflineSLAM. Only the final value of each setting
# is kept here; values tried previously are recorded alongside.

# Drift rates are not supplied.
#   imu_drift_rate tried: [1000, 100]
#   ins_drift_rate tried: np.array([0.000005, 0.0000025])
imu_drift_rate = None
ins_drift_rate = None

# Per-axis minimum position standard deviation (scalars 2 and 3 were also
# tried).
min_pos_std = np.array([5, 5, 10])

# A very large scalar is used instead of np.array([3, 3, 20])/180*np.pi.
# NOTE(review): presumably this makes the INS orientation nearly
# unconstrained - confirm against OfflineSLAM.
min_orientation_std = 1000000000

pixel_sigma = 10
# Build the offline SLAM problem around the INS-driven camera model, solve
# it, and convert the solution back to a camera model plus world points.
slam_options = dict(min_pos_std=min_pos_std,
                    min_orientation_std=min_orientation_std,
                    pixel_sigma=pixel_sigma,
                    imu_drift_rate=imu_drift_rate,
                    ins_drift_rate=ins_drift_rate,
                    robust_pixels_k=None,
                    robust_ins_k=None,
                    balance_measurements=True,
                    estimate_camera=True)
slam = colmap_processing.slam.OfflineSLAM(cm_ins, **slam_options)

# IMU data is explicitly left out of the problem definition.
slam.define_problem(im_pts_at_time, wrld_pts, imu_data=None,
                    time_uncertainty=2)

print('Running solver')
final_cost = slam.solve()
print('Final cost:', final_cost)
#slam.result = slam.initial_estimate
cm2, wrld_pts2 = slam.convert_solution()

#print(slam)
#print('Original reprojection error', reprojection_error(cm_sfm, im_pts_at_time,
#                                                        wrld_pts, plot_results=False))
final_reproj_err = reprojection_error(cm2, im_pts_at_time, wrld_pts2,
                                      plot_results=True)
print('Final reprojection error', final_reproj_err)

#show_solution_errors(cm2, cm_ins, im_pts_at_time, wrld_pts2)

if False:
    # Pickle the SLAM object into the parent directory of image_dir, with the
    # run settings encoded in the file name.
    # max_images and max_image_pts may be None (no limit), which '%i' cannot
    # format (TypeError), so they are formatted with '%s'. For integer values
    # the resulting name is unchanged.
    save_fname = '%s/max_images=%s_max_image_pts=%s_pixel_sigma=%i' % (
        os.path.split(image_dir)[0],
        max_images,
        max_image_pts,
        pixel_sigma)
    with open(save_fname, 'wb') as handle:
        pickle.dump(slam, handle, protocol=pickle.HIGHEST_PROTOCOL)

if False:
    # Re-solve repeatedly while stepping robust_pixels_k through five
    # log-spaced values from 1e-2 to 2, then convert the last solution.
    k_values = np.logspace(-2, np.log10(2), 5)
    for robust_pixels_k in k_values:
        print('robust_pixels_k', robust_pixels_k)
        slam.update_reproj_sigma(pixel_sigma=pixel_sigma,
                                 robust_pixels_k=robust_pixels_k)
        sweep_cost = slam.solve()
        print('Final cost:', sweep_cost)

    #slam.result = slam.initial_estimate
    cm2, wrld_pts2 = slam.convert_solution()


if False:
    # Write images annotated with the reprojection error of the optimized
    # solution into a debug directory (created if missing).
    out_dir = '/mnt/homenas2/noaa_adapt/adapt_flights/2022-04-21/debug_images2'
    os.makedirs(out_dir, exist_ok=True)
    show_reproj_error_on_images(
        cm2, im_pts_at_time, wrld_pts2, image_names, image_dir, out_dir)

if False:
    # Compare INS to SfM: estimate the IMU output implied by each pose
    # provider with identical settings and overlay them, one subplot per
    # axis. Column 0 of each output is plotted as the time axis.
    imu_kwargs = dict(rate=100, s=1/500, gravity=0, plot_results=False)
    out1 = cm_ins.platform_pose_provider.estimate_imu_output(**imu_kwargs)
    out2 = cm_sfm.platform_pose_provider.estimate_imu_output(**imu_kwargs)

    plt.figure(num=None, figsize=(15.3, 10.7), dpi=80)
    plt.rc('font', **{'size': 20})
    plt.rc('axes', linewidth=4)
    txt = ['X', 'Y', 'Z', 'W']
    # Only columns 1..3 (labelled X, Y, Z) are shown.
    for col, axis_name in enumerate(txt[:3], start=1):
        plt.subplot(3, 1, col)
        plt.plot(out1[:, 0], out1[:, col], '.', label='INS')
        plt.plot(out2[:, 0], out2[:, col], '.', label='SfM')
        plt.xlabel('Time (s)')
        plt.ylabel('%s-Accel (m/s^2)' % axis_name)
        plt.legend(fontsize=14)
        plt.tight_layout()


if False:
    # Disabled scratch analysis of the linearized factor graph and of the
    # marginal covariances of the solution.
    # The linearized version of the negative log-likelihood is 1/2|A*x-b|^2.
    a = slam.graph.linearize(slam.result)
    b = a.sparseJacobian_()
    print(b.shape[1], slam.graph.size(), slam.result.size())

    # NOTE(review): 'c' is never used below; element [0] of a.jacobian() is
    # what the plots use - presumably the A matrix, confirm against gtsam.
    c = a.jacobian()

    # Largest entry in each column of element [0] of the Jacobian.
    plt.plot(np.max(a.jacobian()[0], axis=0))

    plt.imshow(np.dot(a.jacobian()[0], a.jacobian()[0].T))

    # NOTE(review): the next line is a bare tuple expression with no effect.
    160, 45
    num_poses = len(slam.pose_times)
    num_landmarks = len(slam.wrld_pts_orig_ind)

    # Sum of the dimensions of every factor in the graph.
    N = 0
    for i in range(slam.graph.size()):
        N += slam.graph.at(i).dim()

    # 6 per pose plus 3 per landmark.
    M = num_poses*6 + num_landmarks*3

    # NOTE(review): the next line is a bare expression with no effect.
    6*num_poses + num_landmarks*2 + 6*num_poses


    marginals = gtsam.Marginals(slam.graph, slam.result)

    # Largest eigenvalue of the marginal covariance for each index in
    # slam.wrld_pts_orig_ind.
    # NOTE(review): 'L' is not defined or imported in the visible file;
    # presumably a gtsam symbol shorthand - confirm before enabling.
    covs = []
    for i in slam.wrld_pts_orig_ind:
        pose_cov = marginals.marginalCovariance(L(i))
        covs.append(max(np.linalg.eig(pose_cov)[0]))


def _positions_at(provider, times):
    """Return provider.pos(t) for every t in times, one column per time."""
    return np.array([provider.pos(t) for t in times]).T


# Sorted times of all images that have point measurements.
image_times = np.array(sorted(im_pts_at_time.keys()))

plt.close('all')

if False:
    # 3-D view of the world points after (red) and before (blue) solving.
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    plt.plot(wrld_pts2[0], wrld_pts2[1], wrld_pts2[2], 'ro')
    plt.plot(wrld_pts[0], wrld_pts[1], wrld_pts[2], 'bo')


# Platform position at every image time according to the rescaled SfM
# solution (pos1), the optimized solution (pos2) and the INS (pos3).
pos1 = _positions_at(cm_sfm.platform_pose_provider, image_times)
pos2 = _positions_at(cm2.platform_pose_provider, image_times)
pos3 = _positions_at(ins, image_times)

if False:
    # 3-D view of the three trajectories.
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    plt.plot(pos1[0], pos1[1], pos1[2], 'ro', label='Original SfM')
    plt.plot(pos2[0], pos2[1], pos2[2], 'bo', markersize=8, label='GTSAM')
    plt.plot(pos3[0], pos3[1], pos3[2], 'go', label='INS')
    plt.xlabel('Easting (m)', fontsize=40)
    plt.ylabel('Northing (m)', fontsize=40)
    plt.legend(fontsize=20)

# Two 2-D views of the trajectories: row 0 against row 1 (labelled Easting vs
# Northing) and row 0 against row 2 (labelled Easting vs Up). Only the first
# view gets connector lines between the optimized and the INS position of
# each image.
for row, ylabel, draw_connectors in ((1, 'Northing (m)', True),
                                     (2, 'Up (m)', False)):
    plt.figure(num=None, figsize=(15.3, 10.7), dpi=80)
    plt.rc('font', **{'size': 20})
    plt.rc('axes', linewidth=4)

    plt.plot(pos2[0], pos2[row], 'bo', markersize=8, label='GTSAM')
    plt.plot(pos1[0], pos1[row], 'ro', label='Original SfM')
    plt.plot(pos3[0], pos3[row], 'go', label='INS')
    plt.axis('image')
    plt.xlabel('Easting (m)', fontsize=40)
    plt.ylabel(ylabel, fontsize=40)
    plt.legend(fontsize=20)

    if draw_connectors:
        for k in range(pos1.shape[1]):
            #plt.plot([pos2[0, k], pos1[0, k]], [pos2[1, k], pos1[1, k]], 'r-')
            xs = [pos2[0, k], pos3[0, k]]
            ys = [pos2[1, k], pos3[1, k]]
            plt.plot(xs, ys, 'k-')

# Platform orientation quaternion at every image time according to the
# rescaled SfM solution (quat1), the optimized solution (quat2) and the INS
# (quat3). After the transpose each array is 4 x N: one row per quaternion
# component, one column per image. The last row is the component this script
# labels 'Quat-W'.
quat1 = np.array([cm_sfm.platform_pose_provider.quat(t) for t in image_times]).T
quat2 = np.array([cm2.platform_pose_provider.quat(t) for t in image_times]).T
quat3 = np.array([ins.quat(t) for t in image_times]).T

# q and -q describe the same rotation, so flip whole quaternions such that
# the last component is non-negative. The sign has to be taken per column
# from the last row; indexing the last *column* instead would scale each
# component row by the sign of the final image's quaternion components and
# thereby change the rotations. np.where is used rather than np.sign so that
# a zero last component leaves the quaternion untouched instead of zeroing
# it.
quat1 = quat1*np.where(quat1[-1] < 0, -1.0, 1.0)
quat2 = quat2*np.where(quat2[-1] < 0, -1.0, 1.0)
quat3 = quat3*np.where(quat3[-1] < 0, -1.0, 1.0)

if False:
    # One subplot per quaternion component, overlaying the three sources
    # against image time.
    plt.figure(num=None, figsize=(15.3, 10.7), dpi=80)
    plt.rc('font', **{'size': 20})
    plt.rc('axes', linewidth=4)
    txt = ['Quat-X', 'Quat-Y', 'Quat-Z', 'Quat-W']
    for row, component_name in enumerate(txt):
        plt.subplot(4, 1, row + 1)
        plt.plot(image_times, quat1[row], 'ro', label='Original SfM')
        plt.plot(image_times, quat3[row], 'go', label='INS')
        plt.plot(image_times, quat2[row], 'bo', label='GTSAM')
        plt.xlabel(component_name, fontsize=20)
        plt.legend(fontsize=15)

    plt.tight_layout()


# Six stacked subplots of the difference between the INS and the optimized
# solution at every image time: three rotation components (degrees) followed
# by three position components.
plt.figure(num=None, figsize=(15.3, 10.7), dpi=80)
plt.rc('font', **{'size': 12})
plt.rc('axes', linewidth=4)
plt.title('Difference Between INS and GTSAM', fontsize=30)
quat2_ = quat2.T
quat3_ = quat3.T
dxyz = []
for j in range(len(quat2_)):
    # Relative rotation between the optimized and the INS orientation.
    dq = quaternion_multiply(quaternion_inverse(quat2_[j]), quat3_[j])
    dq = dq/np.linalg.norm(dq)

    # Keep the last component non-negative. An explicit flip is used because
    # multiplying by np.sign would zero the quaternion when that component is
    # exactly 0.
    if dq[-1] < 0:
        dq = -dq

    # Rounding can push the component marginally above 1, which would make
    # arccos return NaN, so clip it first.
    theta = 2*np.arccos(np.clip(dq[3], -1.0, 1.0))*180/np.pi

    # Scale the unit rotation axis by the angle. When both orientations
    # coincide the vector part has zero length; report a zero difference
    # instead of dividing by zero.
    axis_norm = np.linalg.norm(dq[:3])
    if axis_norm > 0:
        dxyz.append(dq[:3]/axis_norm*theta)
    else:
        dxyz.append(np.zeros(3))

dxyz = np.array(dxyz).T
labels = ['X', 'Y', 'Z']
for i in range(3):
    plt.subplot(6, 1, i+1)
    plt.plot(image_times, dxyz[i])
    plt.ylabel('Rot-%s (deg)' % labels[i], fontsize=15)

# Position difference, optimized solution minus INS.
dxyz = pos2 - pos3
for i in range(3, 6):
    plt.subplot(6, 1, i+1)
    plt.plot(image_times, dxyz[i-3])
    plt.ylabel('Disp-%s (m)' % labels[i-3], fontsize=15)

plt.tight_layout()


#cm2.platform_pose_provider.estimate_imu_output(plot_results=True)

# -----------------------------------------------------------------------------
# Analyze local time offsets.
# For every image, slide the INS query time over a window of +/-10 s around
# the image time and record the offset at which the INS position is closest
# to the optimized position of that image.
dts = []
dts_ = np.linspace(-10, 10, 1001)
# Index of the zero offset (the midpoint of the symmetric window).
arg0 = len(dts_)//2
err0 = []
err = []
for i, image_time in enumerate(image_times):
    # INS positions over the window, one column per candidate offset. This is
    # kept in its own variable so that 'pos3' (INS positions at the image
    # times) is not overwritten by the loop.
    ins_window = np.array([ins.pos(t) for t in image_time + dts_]).T
    err__ = np.atleast_2d(pos2[:, i]).T - ins_window
    err_ = np.sum((err__)**2, axis=0)
    ind = np.argmin(err_)
    dts.append(dts_[ind])
    # Position error at zero offset and at the best offset.
    err0.append(err__[:, arg0])
    err.append(err__[:, ind])

plt.figure()
plt.plot(dts)
plt.xlabel('Image Index', fontsize=24)
plt.ylabel('Optimal Time Offset (s)', fontsize=24)
plt.tight_layout()
plt.savefig('%s/time_offset.png' % os.path.split(image_dir)[0])

# Per-axis position error without (red) and with (green) the best offset.
plt.figure()
err0 = np.array(err0).T
err = np.array(err).T
labels = ['X', 'Y', 'Z']
for i in range(3):
    plt.subplot(3, 1, i + 1)
    plt.plot(err0[i], 'r-')
    plt.plot(err[i], 'g-')
    plt.ylabel('%s-Error (m)' % labels[i], fontsize=18)
