//=========================================================================
//
// Copyright 2018 Kitware, Inc.
// Author: Guilbert Pierre (spguilbert@gmail.com)
// Date: 03-27-2018
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//=========================================================================

// This slam algorithm is inspired by the LOAM algorithm:
// J. Zhang and S. Singh. LOAM: Lidar Odometry and Mapping in Real-time.
// Robotics: Science and Systems Conference (RSS). Berkeley, CA, July 2014.

// The algorithm is composed of three sequential steps:
//
// - Keypoints extraction: this step consists of extracting keypoints over
// the point clouds. To do that, the laser lines / scans are treated independently.
// The laser lines are projected onto the XY plane and are rescaled depending on
// their vertical angle. Then we compute their curvature and create two classes of
// keypoints: the edge keypoints, which correspond to points with a high curvature,
// and the planar keypoints, which correspond to points with a low curvature.
//
// - Ego-Motion: this step consists of recovering the motion of the lidar
// sensor between two frames (two sweeps). The motion is modeled by a constant
// velocity and angular velocity between two frames (i.e. null acceleration).
// Hence, we can parameterize the motion by a rotation and translation per sweep / frame
// and interpolate the transformation inside a frame using the timestamp of the points.
// Since the point clouds generated by a lidar are sparse, we can't design a
// pairwise match between keypoints of two successive frames. Hence, we decided to use
// a closest-point matching between the keypoints of the current frame
// and the geometric features derived from the keypoints of the previous frame.
// The geometric features are lines or planes and are computed using the edge keypoints
// and planar keypoints of the previous frame. Once the matching is done, a keypoint
// of the current frame is matched with a plane / line (depending on the
// nature of the keypoint) from the previous frame. Then, we recover R and T by
// minimizing the function f(R, T) = sum(d(point, line)^2) + sum(d(point, plane)^2),
// which can be written f(R, T) = sum((R*X+T-P).t*A*(R*X+T-P)) where:
// - X is a keypoint of the current frame
// - P is a point of the corresponding line / plane
// - A = (n*n.t) with n being the normal of the plane
// - A = (I - n*n.t).t * (I - n*n.t) with n being a director vector of the line
// Since the function f(R, T) is a non-linear mean square error function
// we decided to use the Levenberg-Marquardt algorithm to recover its argmin.
//
// - Mapping: This step consists of refining the motion recovered in the Ego-Motion
// step and to add the new frame in the environment map. Thanks to the ego-motion
// recovered at the previous step it is now possible to estimate the new position of
// the sensor in the map. We use this estimation as an initial point (R0, T0) and we
// perform an optimization again using the keypoints of the current frame and the matched
// keypoints of the map (and not only the previous frame this time!). Once the position in the
// map has been refined from the first estimation it is then possible to update the map by
// adding the keypoints of the current frame into the map.
//
// In the following programs: "vtkSlam.h" and "vtkSlam.cxx", the lidar
// coordinate system {L} is a 3D coordinate system with its origin at the
// geometric center of the lidar. The world coordinate system {W} is a 3D
// coordinate system which coincides with {L} at the initial position. The
// points will be denoted by the ending letter L or W if they belong to
// the corresponding coordinate system.

// LOCAL
#include "vtkSlam.h"
#include "vtkSpinningSensorKeypointExtractor.h"

// STD
#include <algorithm>
#include <cstring>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

// VTK
#include <vtkCellArray.h>
#include <vtkDataArray.h>
#include <vtkDoubleArray.h>
#include <vtkIdTypeArray.h>
#include <vtkInformation.h>
#include <vtkInformationVector.h>
#include <vtkLine.h>
#include <vtkMath.h>
#include <vtkNew.h>
#include <vtkObjectFactory.h>
#include <vtkPointData.h>
#include <vtkPoints.h>
#include <vtkPolyData.h>
#include <vtkSmartPointer.h>
#include <vtkTable.h>
#include <vtkTransform.h>
#include <vtkTransformPolyDataFilter.h>

// PCL
#include<pcl/common/transforms.h>

// vtkSlam filter input ports (vtkPolyData and vtkTable)
// These indices select the input information vectors read in RequestData() and
// the data types declared in FillInputPortInformation(). INPUT_PORT_COUNT is the
// value passed to SetNumberOfInputPorts() in the constructor.
#define LIDAR_FRAME_INPUT_PORT 0       ///< Current LiDAR frame
#define CALIBRATION_INPUT_PORT 1       ///< LiDAR calibration (vtkTable)
#define INPUT_PORT_COUNT 2

// vtkSlam filter output ports (vtkPolyData)
// These indices select the output data objects filled in RequestData().
// OUTPUT_PORT_COUNT is the value passed to SetNumberOfOutputPorts() in the
// constructor, so it must stay equal to the number of ports listed here.
#define SLAM_FRAME_OUTPUT_PORT 0       ///< Current transformed SLAM frame enriched with debug arrays
#define SLAM_TRAJECTORY_OUTPUT_PORT 1  ///< Trajectory (with position, orientation, covariance and time)
#define EDGE_MAP_OUTPUT_PORT 2         ///< Edge keypoints map
#define PLANE_MAP_OUTPUT_PORT 3        ///< Plane keypoints map
#define BLOB_MAP_OUTPUT_PORT 4         ///< Blob keypoints map
#define EDGE_KEYPOINTS_OUTPUT_PORT 5   ///< Extracted edge keypoints from current frame
#define PLANE_KEYPOINTS_OUTPUT_PORT 6  ///< Extracted plane keypoints from current frame
#define BLOB_KEYPOINTS_OUTPUT_PORT 7   ///< Extracted blob keypoints from current frame
#define OUTPUT_PORT_COUNT 8

//-----------------------------------------------------------------------------
// VTK object factory macro (from vtkObjectFactory.h) instantiated for vtkSlam.
vtkStandardNewMacro(vtkSlam)

namespace {

//-----------------------------------------------------------------------------
// Create a named VTK array of type T.
//
// Name               : name given to the array
// NumberOfComponents : number of components per tuple (default 1)
// NumberOfTuples     : number of tuples the array is sized to (default 0)
//
// Returns a smart pointer owning the new array. The number of components is
// set before the number of tuples, since the tuple size depends on it.
template <typename T>
vtkSmartPointer<T> createArray(const std::string& Name, int NumberOfComponents = 1, int NumberOfTuples = 0)
{
  auto result = vtkSmartPointer<T>::New();
  result->SetName(Name.c_str());
  result->SetNumberOfComponents(NumberOfComponents);
  result->SetNumberOfTuples(NumberOfTuples);
  return result;
}

//-----------------------------------------------------------------------------
// Convert an angle expressed in radians into degrees.
inline double Rad2Deg(double val)
{
  // Same evaluation order as (val / pi) * 180 so the rounding is unchanged
  const double halfTurns = val / vtkMath::Pi();
  return halfTurns * 180.;
}

//-----------------------------------------------------------------------------
// Append a pose to a trajectory polydata.
//
// pose : pose to append (its translation becomes a new point, its rotation a
//        new "Orientation" tuple)
// poly : trajectory to extend; it must already own a vtkPoints, a lines cell
//        array and a point data array named "Orientation"
//
// When the trajectory holds at least two points after the insertion, a line
// cell linking the previous point to the new one is added as well.
void AddPoseToPolyData(const Eigen::Isometry3d& pose, vtkPolyData* poly)
{
  // Position of the pose
  const Eigen::Vector3d position = pose.translation();
  poly->GetPoints()->InsertNextPoint(position.x(), position.y(), position.z());

  // Orientation of the pose, stored as a quaternion in (w, x, y, z) order
  const Eigen::Quaterniond quaternion(pose.linear());
  double wxyz[] = {quaternion.w(), quaternion.x(), quaternion.y(), quaternion.z()};
  poly->GetPointData()->GetArray("Orientation")->InsertNextTuple(wxyz);

  // Nothing to link while the trajectory has a single point
  const vtkIdType lastId = poly->GetNumberOfPoints() - 1;
  if (lastId < 1)
  {
    return;
  }

  // Segment joining the previous pose to the one just added
  vtkSmartPointer<vtkLine> segment = vtkSmartPointer<vtkLine>::New();
  segment->GetPointIds()->SetId(0, lastId - 1);
  segment->GetPointIds()->SetId(1, lastId);
  poly->GetLines()->InsertNextCell(segment);
}

//-----------------------------------------------------------------------------
// Fill a vtkPolyData with the points of a PCL point cloud.
//
// pc   : source point cloud (only the x, y, z coordinates are copied)
// poly : destination polydata; its points and vertex cells are replaced
//
// One vertex cell is created per point. The cell connectivity is written in
// the legacy (count, id) layout: each point i contributes the pair (1, i).
void PolyDataFromPointCloud(pcl::PointCloud<Slam::Point>::Ptr pc, vtkPolyData* poly)
{
  // Read the size once, in the index type used by VTK, to avoid comparing
  // signed and unsigned values in the loop below
  const vtkIdType nbPoints = static_cast<vtkIdType>(pc->size());

  // Size the points and the connectivity up front instead of growing them
  // one insertion at a time
  auto pts = vtkSmartPointer<vtkPoints>::New();
  pts->SetNumberOfPoints(nbPoints);
  vtkNew<vtkIdTypeArray> cells;
  cells->SetNumberOfValues(nbPoints * 2);
  vtkIdType* ids = cells->GetPointer(0);

  // Set points and cells in a single pass
  vtkIdType i = 0;
  for (const Slam::Point& p: *pc)
  {
    pts->SetPoint(i, p.x, p.y, p.z);
    // TODO : add other fields (intensity, time, laserId)?
    ids[i * 2] = 1;      // number of points in the cell
    ids[i * 2 + 1] = i;  // id of the single point of the cell
    ++i;
  }
  poly->SetPoints(pts);

  auto cellArray = vtkSmartPointer<vtkCellArray>::New();
  cellArray->SetCells(nbPoints, cells.GetPointer());
  poly->SetVerts(cellArray);
}

//-----------------------------------------------------------------------------
// Append the points of a LiDAR frame (vtkPolyData) to a PCL point cloud.
//
// poly : input frame; its point data must contain the arrays "adjustedtime",
//        "laser_id" and "intensity"
// pc   : output point cloud, the converted points are appended to it
//
// The "adjustedtime" values are scaled by 1e-6 to obtain the point time in
// seconds, while the cloud header stamp receives the unscaled time of the
// first point.
// If one of the required arrays is missing, a warning is emitted and the
// cloud is left untouched. If the frame has no point, the cloud (including
// its header stamp) is left untouched as well.
void PointCloudFromPolyData(vtkPolyData* poly, pcl::PointCloud<Slam::Point>::Ptr pc)
{
  auto arrayTime = poly->GetPointData()->GetArray("adjustedtime");
  auto arrayLaserId = poly->GetPointData()->GetArray("laser_id");
  auto arrayIntensity = poly->GetPointData()->GetArray("intensity");
  if (!arrayTime || !arrayLaserId || !arrayIntensity)
  {
    // GetArray() returns a null pointer when no array has the requested name
    vtkGenericWarningMacro(<< "The input frame must have the point data arrays "
                              "'adjustedtime', 'laser_id' and 'intensity'");
    return;
  }

  // Nothing to convert, and no first point to read the header stamp from
  const vtkIdType nbPoints = poly->GetNumberOfPoints();
  if (nbPoints == 0)
  {
    return;
  }

  pc->reserve(pc->size() + static_cast<size_t>(nbPoints));
  for (vtkIdType i = 0; i < nbPoints; i++)
  {
    double pos[3];
    poly->GetPoint(i, pos);
    Slam::Point p;
    p.x = pos[0];
    p.y = pos[1];
    p.z = pos[2];
    p.time = arrayTime->GetTuple1(i) * 1e-6; // time in seconds
    p.laserId = arrayLaserId->GetTuple1(i);
    p.intensity = arrayIntensity->GetTuple1(i);
    pc->push_back(p);
  }
  pc->header.stamp = arrayTime->GetTuple1(0);
}

//-----------------------------------------------------------------------------
// Return the indices of v ordered by decreasing value, i.e. result[0] is the
// index of the largest element of v and result.back() the index of the
// smallest one. The input vector itself is not modified.
// The relative order of indices whose values compare equal is whatever
// std::sort produces (it is not a stable sort).
template <typename T>
std::vector<size_t> sortIdx(const std::vector<T> &v)
{
  // Identity permutation: 0, 1, ..., v.size() - 1
  std::vector<size_t> order;
  order.reserve(v.size());
  for (size_t position = 0; position < v.size(); ++position)
  {
    order.push_back(position);
  }

  // Reorder the indices so that the values they refer to are decreasing
  const auto refersToLargerValue = [&v](size_t lhs, size_t rhs)
  {
    return v[lhs] > v[rhs];
  };
  std::sort(order.begin(), order.end(), refersToLargerValue);

  return order;
}
} // end of anonymous namespace

//-----------------------------------------------------------------------------
// Set the VTK keypoint extractor wrapper used by this filter and forward the
// extractor it wraps to the SLAM algorithm.
//
// _arg : new extractor wrapper, may be null
//
// The wrapper is stored through vtkSetObjectBodyMacro, which accepts a null
// pointer. In that case there is nothing to forward, so the SLAM algorithm is
// not updated.
// NOTE(review): with a null _arg the SLAM algorithm keeps the extractor it was
// previously given - confirm this is the desired behavior.
void vtkSlam::SetKeyPointsExtractor (vtkSpinningSensorKeypointExtractor* _arg)
{
  vtkSetObjectBodyMacro(KeyPointsExtractor,vtkSpinningSensorKeypointExtractor,_arg);
  if (this->KeyPointsExtractor)
  {
    this->SlamAlgo->SetKeyPointsExtractor(this->KeyPointsExtractor->GetExtractor());
  }
}

//-----------------------------------------------------------------------------
// Compute the laser id mapping from the LiDAR calibration.
//
// calib : calibration table, expected to hold a numeric column named
//         "verticalCorrection" with one value per laser
//
// Returns the row indices of the "verticalCorrection" column ordered by
// decreasing value (see sortIdx()). An empty vector is returned, and a VTK
// error is reported, when the table is null or when the column is missing.
std::vector<size_t> vtkSlam::GetLaserIdMapping(vtkTable *calib)
{
  std::vector<size_t> laserIdMapping;
  if (!calib)
  {
    vtkErrorMacro(<< "No calibration data provided");
    return laserIdMapping;
  }

  auto array = vtkDataArray::SafeDownCast(calib->GetColumnByName("verticalCorrection"));
  if (!array)
  {
    vtkErrorMacro(<< "The calibration data has no column named 'verticalCorrection'");
    return laserIdMapping;
  }

  // Copy the column into a std::vector so that it can be sorted by index
  const vtkIdType nbLasers = array->GetNumberOfTuples();
  std::vector<double> verticalCorrection(nbLasers);
  for (vtkIdType i = 0; i < nbLasers; ++i)
  {
    verticalCorrection[i] = array->GetTuple1(i);
  }
  laserIdMapping = sortIdx(verticalCorrection);
  return laserIdMapping;
}

//-----------------------------------------------------------------------------
// Run one SLAM iteration on the current LiDAR frame and fill all the outputs.
//
// inputVector  : holds the LiDAR frame (LIDAR_FRAME_INPUT_PORT) and the
//                calibration table (CALIBRATION_INPUT_PORT)
// outputVector : holds the 8 vtkPolyData outputs (see the *_OUTPUT_PORT macros)
//
// Returns 1 on success, 0 if one of the inputs is missing.
//
// Steps:
//  1. convert the input frame to a PCL point cloud and feed it to the SLAM
//  2. convert the resulting world transform to a vtkTransform and apply it to
//     the input frame
//  3. append the new pose (position, orientation, time, covariance and, in
//     display mode, the SLAM debug information) to the trajectory
//  4. export the frame, the trajectory, the keypoint maps and the keypoints
//     of the current frame (moved to world coordinates)
//  5. in display mode, attach the debug arrays to the frame and to the
//     keypoints outputs
int vtkSlam::RequestData(vtkInformation *vtkNotUsed(request),
vtkInformationVector **inputVector, vtkInformationVector *outputVector)
{
  // Get the input, and stop here rather than dereferencing a missing one
  vtkPolyData* input = vtkPolyData::GetData(inputVector[LIDAR_FRAME_INPUT_PORT], 0);
  vtkTable* calib = vtkTable::GetData(inputVector[CALIBRATION_INPUT_PORT], 0);
  if (!input || !calib)
  {
    vtkErrorMacro(<< "A LiDAR frame and a calibration table are both required to run the SLAM");
    return 0;
  }
  std::vector<size_t> laserMapping = GetLaserIdMapping(calib);

  // Conversion vtkPolyData -> PCL pointcloud
  pcl::PointCloud<Slam::Point>::Ptr pc(new pcl::PointCloud<Slam::Point>);
  PointCloudFromPolyData(input, pc);

  // Run SLAM
  this->SlamAlgo->AddFrame(pc, laserMapping);

  // Get SLAM transform
  // With PostMultiply, the operations below compose as
  // Translate * RotateZ * RotateY * RotateX, which matches the (Z, Y, X)
  // Euler angles decomposition requested from Eigen.
  Transform Tworld = this->SlamAlgo->GetWorldTransform();
  vtkSmartPointer<vtkTransform> transform = vtkSmartPointer<vtkTransform>::New();
  transform->PostMultiply();
  Eigen::Vector3d ypr = Tworld.GetRotation().toRotationMatrix().eulerAngles(2, 1, 0);
  transform->RotateX(Rad2Deg(ypr[2]));
  transform->RotateY(Rad2Deg(ypr[1]));
  transform->RotateZ(Rad2Deg(ypr[0]));
  transform->Translate(Tworld.GetPosition().data());

  // Transform the current frame to world coordinates
  vtkSmartPointer<vtkTransformPolyDataFilter> transformFilter = vtkSmartPointer<vtkTransformPolyDataFilter>::New();
  transformFilter->SetInputData(input);
  transformFilter->SetTransform(transform);
  transformFilter->Update();

  // Update Trajectory
  AddPoseToPolyData(Tworld.GetIsometry(), this->Trajectory);
  vtkPointData* trajectoryData = this->Trajectory->GetPointData();
  trajectoryData->GetArray("Time")->InsertNextTuple(&Tworld.time);
  trajectoryData->GetArray("Covariance")->InsertNextTuple(this->SlamAlgo->GetTransformCovariance().data());

  // General SLAM info (number of keypoints used in ICP and optimization, max variance, ...)
  // Info added as PointData array of the trajectory. This is done before the
  // trajectory is copied to the output so that arrays created here are exported.
  if (this->DisplayMode)
  {
    // Number of poses stored before the one added above
    const vtkIdType nbPreviousPoses = this->Trajectory->GetNumberOfPoints() - 1;
    auto debugInfo = this->SlamAlgo->GetDebugInformation();
    for (const auto& it : debugInfo)
    {
      vtkDataArray* infoArray = trajectoryData->GetArray(it.first.c_str());
      if (!infoArray)
      {
        // The array is only created by Reset() when display mode was already
        // enabled and the key already reported: create it on demand otherwise
        auto missingArray = createArray<vtkDoubleArray>(it.first);
        trajectoryData->AddArray(missingArray);
        infoArray = missingArray.GetPointer();
      }
      // Mark the poses for which no value was recorded, so that the array
      // keeps one tuple per trajectory point
      for (vtkIdType pose = infoArray->GetNumberOfTuples(); pose < nbPreviousPoses; ++pose)
      {
        infoArray->InsertNextTuple1(vtkMath::Nan());
      }
      infoArray->InsertNextTuple1(it.second);
    }
  }

  // Fill SLAM filter outputs
  Slam::PointCloud::Ptr tmpPcl(new Slam::PointCloud);
  auto keypointExtractor = this->SlamAlgo->GetKeyPointsExtractor();

  // ===== SLAM frame and pose =====
  // Output : Current LiDAR frame in world coordinates
  auto* slamFrame = vtkPolyData::GetData(outputVector, SLAM_FRAME_OUTPUT_PORT);
  slamFrame->ShallowCopy(transformFilter->GetOutput());
  // Output : SLAM Trajectory
  auto* slamTrajectory = vtkPolyData::GetData(outputVector, SLAM_TRAJECTORY_OUTPUT_PORT);
  slamTrajectory->ShallowCopy(this->Trajectory);

  // ===== Aggregated Keypoints maps =====
  // Output : Edges points map
  auto* edgeMap = vtkPolyData::GetData(outputVector, EDGE_MAP_OUTPUT_PORT);
  PolyDataFromPointCloud(this->SlamAlgo->GetEdgesMap(), edgeMap);
  // Output : Planar points map
  auto* planarMap = vtkPolyData::GetData(outputVector, PLANE_MAP_OUTPUT_PORT);
  PolyDataFromPointCloud(this->SlamAlgo->GetPlanarsMap(), planarMap);
  // Output : Blob points map
  auto* blobMap = vtkPolyData::GetData(outputVector, BLOB_MAP_OUTPUT_PORT);
  PolyDataFromPointCloud(this->SlamAlgo->GetBlobsMap(), blobMap);

  // ===== Extracted keypoints from current frame =====
  // Output : Current edge keypoints
  auto* edgePoints = vtkPolyData::GetData(outputVector, EDGE_KEYPOINTS_OUTPUT_PORT);
  pcl::transformPointCloud(*keypointExtractor->GetEdgePoints(), *tmpPcl, Tworld.GetMatrix());
  PolyDataFromPointCloud(tmpPcl, edgePoints);
  // Output : Current planar keypoints
  auto* planarPoints = vtkPolyData::GetData(outputVector, PLANE_KEYPOINTS_OUTPUT_PORT);
  pcl::transformPointCloud(*keypointExtractor->GetPlanarPoints(), *tmpPcl, Tworld.GetMatrix());
  PolyDataFromPointCloud(tmpPcl, planarPoints);
  // Output : Current blob keypoints
  auto* blobPoints = vtkPolyData::GetData(outputVector, BLOB_KEYPOINTS_OUTPUT_PORT);
  pcl::transformPointCloud(*keypointExtractor->GetBlobPoints(), *tmpPcl, Tworld.GetMatrix());
  PolyDataFromPointCloud(tmpPcl, blobPoints);

  // add debug information if displayMode is enabled
  if (this->DisplayMode)
  {
    // Keypoints extraction debug array (curvatures, depth gap, intensity gap...)
    // Info added as PointData array of output0
    auto keypointsExtractionDebugArray = keypointExtractor->GetDebugArray();
    for (const auto& it : keypointsExtractionDebugArray)
    {
      auto array = createArray<vtkDoubleArray>(it.first.c_str(), 1, it.second.size());
      // memcpy is a better alternative than looping on all tuples.
      // It is skipped for empty arrays, whose buffers may be null.
      if (it.second.size() > 0)
      {
        std::memcpy(array->GetVoidPointer(0), it.second.data(), sizeof(double) * it.second.size());
      }
      slamFrame->GetPointData()->AddArray(array);
    }

    // ICP keypoints matching results for ego-motion or mapping steps
    // Info added as PointData array of output5-7
    // A fixed list is used so that the arrays are always added in the same order.
    const std::pair<std::string, vtkPolyData*> matchOutputs[] = {
      {"EgoMotion: edges matches", edgePoints},
      {"Mapping: edges matches", edgePoints},
      {"EgoMotion: planes matches", planarPoints},
      {"Mapping: planes matches", planarPoints},
    };
    auto debugArray = this->SlamAlgo->GetDebugArray();
    for (const auto& match : matchOutputs)
    {
      // Single lookup; as before, an entry that does not exist yields an empty array
      const auto& values = debugArray[match.first];
      auto array = createArray<vtkDoubleArray>(match.first.c_str(), 1, values.size());
      // memcpy is a better alternative than looping on all tuples
      if (values.size() > 0)
      {
        std::memcpy(array->GetVoidPointer(0), values.data(), sizeof(double) * values.size());
      }
      match.second->GetPointData()->AddArray(array);
    }
  }

  return 1;
}

//-----------------------------------------------------------------------------
// Print the state of the filter: the superclass information, the value of
// each SLAM parameter listed below (read from the SLAM algorithm through its
// Get<Parameter>() accessor), then the keypoint extractor.
//
// os     : stream to print to
// indent : indentation of this object; the parameters are printed one level
//          deeper
void vtkSlam::PrintSelf(ostream& os, vtkIndent indent)
{
  this->Superclass::PrintSelf(os, indent);
  os << indent << "Slam parameters: " << std::endl;
  vtkIndent paramIndent = indent.GetNextIndent();
  // Prints "<name>\t<value>" for the parameter, using the matching getter
  #define PrintParameter(param) os << paramIndent << #param << "\t" << this->SlamAlgo->Get##param() << std::endl;

  PrintParameter(FastSlam)
  PrintParameter(Undistortion)

  PrintParameter(MaxDistanceForICPMatching)

  PrintParameter(EgoMotionLMMaxIter)
  PrintParameter(EgoMotionICPMaxIter)
  PrintParameter(EgoMotionLineDistanceNbrNeighbors)
  PrintParameter(EgoMotionMinimumLineNeighborRejection)
  PrintParameter(EgoMotionMaxLineDistance)
  PrintParameter(EgoMotionLineDistancefactor)
  PrintParameter(EgoMotionPlaneDistanceNbrNeighbors)
  PrintParameter(EgoMotionMaxPlaneDistance)
  PrintParameter(EgoMotionPlaneDistancefactor1)
  PrintParameter(EgoMotionPlaneDistancefactor2)
  PrintParameter(EgoMotionInitLossScale)
  PrintParameter(EgoMotionFinalLossScale)

  PrintParameter(MappingLMMaxIter)
  PrintParameter(MappingICPMaxIter)
  PrintParameter(MappingLineDistanceNbrNeighbors)
  PrintParameter(MappingMinimumLineNeighborRejection)
  PrintParameter(MappingMaxLineDistance)
  PrintParameter(MappingLineMaxDistInlier)
  PrintParameter(MappingLineDistancefactor)
  PrintParameter(MappingPlaneDistanceNbrNeighbors)
  PrintParameter(MappingPlaneDistancefactor1)
  PrintParameter(MappingPlaneDistancefactor2)
  PrintParameter(MappingMaxPlaneDistance)
  PrintParameter(MappingInitLossScale)
  PrintParameter(MappingFinalLossScale)

  // Keep the helper macro local to this function
  #undef PrintParameter

  // The extractor may not have been set (or may have been reset to null)
  auto extractor = this->GetKeyPointsExtractor();
  if (extractor)
  {
    extractor->PrintSelf(os, indent);
  }
  else
  {
    os << indent << "KeyPointsExtractor: (none)" << std::endl;
  }
}

//-----------------------------------------------------------------------------
// Build the filter: allocate the Slam algorithm instance, declare
// INPUT_PORT_COUNT input ports and OUTPUT_PORT_COUNT output ports, then call
// Reset() to initialize the algorithm state and the trajectory.
// NOTE(review): Reset() reads this->DisplayMode, which is not initialized
// here; presumably it has a default value in the class declaration - confirm.
vtkSlam::vtkSlam()
: SlamAlgo(new Slam)
{
  this->SetNumberOfInputPorts(INPUT_PORT_COUNT);
  this->SetNumberOfOutputPorts(OUTPUT_PORT_COUNT);
  // Must come last: it uses SlamAlgo, allocated above
  this->Reset();
}

//-----------------------------------------------------------------------------
void vtkSlam::Reset()
{
  this->SlamAlgo->Reset();

  // init the output SLAM trajectory
  this->Trajectory = vtkSmartPointer<vtkPolyData>::New();
  auto pts = vtkSmartPointer<vtkPoints>::New();
  this->Trajectory->SetPoints(pts);
  auto cellArray = vtkSmartPointer<vtkCellArray>::New();
  this->Trajectory->SetLines(cellArray);
  this->Trajectory->GetPointData()->AddArray(createArray<vtkDoubleArray>("Time", 1));
  this->Trajectory->GetPointData()->AddArray(createArray<vtkDoubleArray>("Orientation", 4));
  this->Trajectory->GetPointData()->AddArray(createArray<vtkDoubleArray>("Covariance", 36));

  // add the required array in the trajectory
  if (this->DisplayMode)
  {
    auto debugInfo = this->SlamAlgo->GetDebugInformation();
    for (const auto& it : debugInfo)
    {
      this->Trajectory->GetPointData()->AddArray(createArray<vtkDoubleArray>(it.first));
    }
  }
}

//-----------------------------------------------------------------------------
// Declare the data type expected on each input port.
//
// port : index of the input port
// info : information object receiving the data type name
//
// Returns 1 when the port is known (LiDAR frame or calibration), 0 otherwise,
// in which case info is left untouched.
int vtkSlam::FillInputPortInformation(int port, vtkInformation *info)
{
  switch (port)
  {
    // Pointcloud data
    case LIDAR_FRAME_INPUT_PORT:
      info->Set(vtkDataObject::DATA_TYPE_NAME(), "vtkPolyData" );
      return 1;

    // LiDAR calibration
    case CALIBRATION_INPUT_PORT:
      info->Set(vtkDataObject::DATA_TYPE_NAME(), "vtkTable" );
      return 1;

    default:
      return 0;
  }
}

//-----------------------------------------------------------------------------
// Forward `size` to Slam::SetVoxelGridLeafSizeEdges() and record the change
// by updating ParametersModificationTime.
// NOTE(review): this->Modified() is not called; presumably the filter's
// modification time takes ParametersModificationTime into account - confirm
// in the class declaration.
void vtkSlam::SetVoxelGridLeafSizeEdges(double size)
{
  this->SlamAlgo->SetVoxelGridLeafSizeEdges(size);
  this->ParametersModificationTime.Modified();
}

//-----------------------------------------------------------------------------
// Forward `size` to Slam::SetVoxelGridLeafSizePlanes() and record the change
// by updating ParametersModificationTime.
// NOTE(review): this->Modified() is not called; presumably the filter's
// modification time takes ParametersModificationTime into account - confirm
// in the class declaration.
void vtkSlam::SetVoxelGridLeafSizePlanes(double size)
{
  this->SlamAlgo->SetVoxelGridLeafSizePlanes(size);
  this->ParametersModificationTime.Modified();
}

//-----------------------------------------------------------------------------
// Forward `size` to Slam::SetVoxelGridLeafSizeBlobs() and record the change
// by updating ParametersModificationTime.
// NOTE(review): this->Modified() is not called; presumably the filter's
// modification time takes ParametersModificationTime into account - confirm
// in the class declaration.
void vtkSlam::SetVoxelGridLeafSizeBlobs(double size)
{
  this->SlamAlgo->SetVoxelGridLeafSizeBlobs(size);
  this->ParametersModificationTime.Modified();
}

//-----------------------------------------------------------------------------
// Forward `size` to Slam::SetVoxelGridSize() and record the change by
// updating ParametersModificationTime.
// NOTE(review): this->Modified() is not called; presumably the filter's
// modification time takes ParametersModificationTime into account - confirm
// in the class declaration.
void vtkSlam::SetVoxelGridSize(int size)
{
  this->SlamAlgo->SetVoxelGridSize(size);
  this->ParametersModificationTime.Modified();
}

//-----------------------------------------------------------------------------
// Forward `resolution` to Slam::SetVoxelGridResolution() and record the
// change by updating ParametersModificationTime.
// NOTE(review): this->Modified() is not called; presumably the filter's
// modification time takes ParametersModificationTime into account - confirm
// in the class declaration.
void vtkSlam::SetVoxelGridResolution(double resolution)
{
  this->SlamAlgo->SetVoxelGridResolution(resolution);
  this->ParametersModificationTime.Modified();
}