/*=========================================================================

  Program:   PCL Plugin
  Module:    vtkPCLPlaneAlignment.cxx

  Copyright (c) Kitware, Inc.
  All rights reserved.
  See LICENSE or http://www.apache.org/licenses/LICENSE-2.0 for details.

  This software is distributed WITHOUT ANY WARRANTY; without even
  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
  PURPOSE.  See the above copyright notice for more information.

=========================================================================*/

// Local includes
#include "vtkPCLPlaneAlignment.h"
#include "vtkPCLFittingModel.h"

// STL includes
#include <cmath>

// VTK includes
#include <vtkInformation.h>
#include <vtkInformationVector.h>
#include <vtkNew.h>
#include <vtkPointData.h>
#include <vtkPoints.h>
#include <vtkPolyData.h>
#include <vtkSMPTools.h>
#include <vtkUnsignedIntArray.h>

// Eigen includes
#include <Eigen/Geometry>

//----------------------------------------------------------------------------
// Standard VTK object-factory macro: provides vtkPCLPlaneAlignment::New().
vtkStandardNewMacro(vtkPCLPlaneAlignment);

//-----------------------------------------------------------------------------
void vtkPCLPlaneAlignment::PrintSelf(ostream& os, vtkIndent indent)
{
  // Print the superclass state first, then every parameter driving RequestData so that the
  // filter configuration can be inspected (VTK PrintSelf convention).
  this->Superclass::PrintSelf(os, indent);
  os << indent << "PlaneNumber: " << this->PlaneNumber << "\n";
  os << indent << "UseApproximateDistance: " << this->UseApproximateDistance << "\n";
  os << indent << "ApproximateDistance: " << this->ApproximateDistance << "\n";
  os << indent << "DefaultPlaneNumber: " << this->DefaultPlaneNumber << "\n";
  os << indent << "UseNormalAxis: " << this->UseNormalAxis << "\n";
  os << indent << "NormalAxis: (" << this->NormalAxis[0] << ", " << this->NormalAxis[1] << ", "
     << this->NormalAxis[2] << ")\n";
  os << indent << "ThresholdAngle: " << this->ThresholdAngle << "\n";
  os << indent << "DistanceThreshold: " << this->DistanceThreshold << "\n";
  os << indent << "Probability: " << this->Probability << "\n";
  os << indent << "MaxIterations: " << this->MaxIterations << "\n";
  os << indent << "RefineCoefficients: " << this->RefineCoefficients << "\n";
  os << indent << "TemporalAveraging: " << this->TemporalAveraging << "\n";
  os << indent << "PrevEstimationWeight: " << this->PrevEstimationWeight << "\n";
  os << indent << "MaxTemporalAngleChange: " << this->MaxTemporalAngleChange << "\n";
  os << indent << "TranslateToOrigin: " << this->TranslateToOrigin << "\n";
  os << indent << "Flip: " << this->Flip << "\n";
}

//-----------------------------------------------------------------------------
int vtkPCLPlaneAlignment::RequestData(vtkInformation* vtkNotUsed(request),
  vtkInformationVector** inputVector,
  vtkInformationVector* outputVector)
{
  // Aligns the input cloud so that a fitted plane becomes the XY plane:
  //  1. fit up to `instanceCount` planes with vtkPCLFittingModel,
  //  2. select the plane of interest (last requested one, or the one whose signed distance to
  //     the origin is closest to ApproximateDistance),
  //  3. optionally refine it (PCA) and smooth it over time,
  //  4. rotate (optionally flip and translate) the points so the plane maps onto z = const.
  vtkPolyData* input = vtkPolyData::GetData(inputVector[0]->GetInformationObject(0));
  vtkPolyData* output = vtkPolyData::GetData(outputVector->GetInformationObject(0));
  if (!input || !input->GetPoints())
  {
    vtkErrorMacro("Input has no points");
    return 0;
  }

  const int instanceCount =
    this->UseApproximateDistance ? this->DefaultPlaneNumber : this->PlaneNumber;
  if (instanceCount <= 0)
  {
    vtkErrorMacro("At least one plane must be fitted");
    return 0;
  }

  // ---- 1. Plane fitting -----------------------------------------------------------------------
  const int modelType =
    this->UseNormalAxis ? vtkPCLFittingModel::PerpendicularPlane : vtkPCLFittingModel::Plane;
  vtkNew<vtkPCLFittingModel> fittingModel;
  fittingModel->SetInputData(input);
  fittingModel->SetOutputMode(vtkPCLFittingModel::AllData);
  fittingModel->SetInstanceCount(instanceCount);
  fittingModel->SetModelType(modelType);
  fittingModel->SetDistanceThreshold(this->DistanceThreshold);
  fittingModel->SetProbability(this->Probability);
  fittingModel->SetMaxIterations(this->MaxIterations);
  fittingModel->SetAxis(this->NormalAxis);
  fittingModel->SetThresholdAngle(this->ThresholdAngle);
  fittingModel->SetShowMeshes(false);
  fittingModel->Update();

  // Shallow copy: the output references the same point data arrays (including "instance") as
  // the fitting model output.
  output->ShallowCopy(fittingModel->GetOutput());

  const std::vector<Eigen::VectorXf> coefficients = fittingModel->GetModelCoefficients();
  if (coefficients.empty())
  {
    vtkErrorMacro("Plane not found");
    return 0;
  }

  // ---- 2. Plane selection ---------------------------------------------------------------------
  unsigned int planeIndex = 0;
  if (this->UseApproximateDistance)
  {
    // Keep the plane whose signed distance to the origin (sign taken along NormalAxis) is the
    // closest to the requested ApproximateDistance.
    float minDeltaDistance = std::numeric_limits<float>::max();
    const Eigen::Vector3f normalAxis(
      this->NormalAxis[0], this->NormalAxis[1], this->NormalAxis[2]);
    for (unsigned int i = 0; i < coefficients.size(); ++i)
    {
      const Eigen::VectorXf& coeffs = coefficients[i];
      // Foot of the perpendicular from the origin, assuming a unit normal (a, b, c)
      const Eigen::Vector3f pointInPlane = -coeffs.head(3) * coeffs[3];
      const int sign = pointInPlane.normalized().dot(normalAxis) > 0 ? 1 : -1;
      const float signedDistanceToOrigin = pointInPlane.norm() * sign;
      const float deltaDistance = std::abs(signedDistanceToOrigin - this->ApproximateDistance);
      if (deltaDistance < minDeltaDistance)
      {
        minDeltaDistance = deltaDistance;
        planeIndex = i;
      }
    }
  }
  else
  {
    // The plane of interest is the last requested one; RANSAC may have found fewer planes.
    if (static_cast<std::size_t>(instanceCount) > coefficients.size())
    {
      vtkErrorMacro("Only " << coefficients.size() << " plane(s) found, cannot select plane #"
                            << instanceCount);
      return 0;
    }
    planeIndex = static_cast<unsigned int>(instanceCount - 1);
  }

  Eigen::VectorXf planeCoefficients = coefficients[planeIndex];

  // ---- 3. Refinement and temporal smoothing ---------------------------------------------------
  if (this->RefineCoefficients)
  {
    // Least-squares (PCA) refinement; also relabels the "instance" array (1 = inlier, 0 = other)
    this->RefineRansac(output, planeCoefficients);
  }

  // Plane in Hessian normal form n.x + d = 0. Normalize n and d together so they stay consistent.
  Eigen::Vector3f n = planeCoefficients.head(3);
  float d = planeCoefficients[3];
  const float normalLength = n.norm();
  if (!(normalLength > 0.f))
  {
    vtkErrorMacro("Degenerate plane coefficients");
    return 0;
  }
  n /= normalLength;
  d /= normalLength;

  if (this->PrevPlaneCoeffs.size() != 4)
  {
    this->PrevPlaneCoeffs = Eigen::VectorXf::Zero(4);
  }
  const Eigen::Vector3f nPrev = this->PrevPlaneCoeffs.head(3);
  const float dPrev = this->PrevPlaneCoeffs[3];

  // A previous estimate exists only once a unit normal has been stored (SetTemporalAveraging
  // resets the history to zero).
  if (this->TemporalAveraging && std::abs(nPrev.norm() - 1) < 1e-3)
  {
    // (n, d) and (-n, -d) describe the same plane: orient the new estimate like the previous one
    // so a sign flip of the fit neither passes the angle test wrongly nor cancels the average.
    if (n.dot(nPrev) < 0)
    {
      n = -n;
      d = -d;
    }
    // atan2 stays well defined where asin(|n x nPrev|) could get an argument slightly above 1
    const float angle = std::atan2(n.cross(nPrev).norm(), n.dot(nPrev));
    if (angle > vtkMath::RadiansFromDegrees(this->MaxTemporalAngleChange))
    {
      // Change too large: discard the current estimate and keep the previous plane
      n = nPrev;
      d = dPrev;
    }
    else
    {
      // Blend the current estimate with the previous one, then renormalize the plane
      const float t = static_cast<float>(1 - this->PrevEstimationWeight);
      n = t * n + (1 - t) * nPrev;
      d = t * d + (1 - t) * dPrev;
      const float blendedLength = n.norm(); // >= 1/sqrt(2) since n.nPrev >= 0
      n /= blendedLength;
      d /= blendedLength;
    }
  }

  this->PrevPlaneCoeffs[0] = n[0];
  this->PrevPlaneCoeffs[1] = n[1];
  this->PrevPlaneCoeffs[2] = n[2];
  this->PrevPlaneCoeffs[3] = d;

  // ---- 4. Alignment ---------------------------------------------------------------------------
  // By convention the plane normal points to the positive z axis before aligning it with z
  const Eigen::Vector3f zAxis = Eigen::Vector3f::UnitZ();
  Eigen::Vector3f planeNormal = n;
  if (planeNormal.dot(zAxis) < 0)
  {
    planeNormal = -planeNormal;
  }
  Eigen::Quaternionf rotation = Eigen::Quaternionf::FromTwoVectors(planeNormal, zAxis);
  if (this->Flip)
  {
    // Half turn around the x axis, applied after the alignment rotation (built once, not per point)
    const Eigen::Quaternionf flip(
      Eigen::AngleAxisf(static_cast<float>(vtkMath::Pi()), Eigen::Vector3f::UnitX()));
    rotation = flip * rotation;
  }

  vtkPoints* points = input->GetPoints();
  const vtkIdType numberOfPoints = points->GetNumberOfPoints();

  vtkNew<vtkPoints> rotatedPoints;
  // Set the data type before the size: changing the type replaces the underlying array
  rotatedPoints->SetDataTypeToFloat();
  rotatedPoints->SetNumberOfPoints(numberOfPoints);

  vtkSMPTools::For(0,
    numberOfPoints,
    [&](vtkIdType begin, vtkIdType end)
    {
      double point[3];
      for (vtkIdType i = begin; i < end; ++i)
      {
        points->GetPoint(i, point);
        const Eigen::Vector3f rotated = rotation * Eigen::Vector3f(point[0], point[1], point[2]);
        rotatedPoints->SetPoint(i, rotated[0], rotated[1], rotated[2]);
      }
    });

  // Move the selected plane to height 0, using the average height of its inliers
  if (this->TranslateToOrigin)
  {
    vtkUnsignedIntArray* instanceArray =
      vtkUnsignedIntArray::SafeDownCast(output->GetPointData()->GetArray("instance"));
    if (!instanceArray || instanceArray->GetNumberOfValues() < numberOfPoints)
    {
      vtkWarningMacro("No valid \"instance\" array, the cloud is not translated to the origin");
    }
    else
    {
      // Plane i is labelled i + 1 by the fitting model; after RefineRansac the inliers of the
      // selected plane are labelled 1 in this same array.
      const unsigned int inlierLabel = this->RefineCoefficients ? 1u : planeIndex + 1;
      double heightSum = 0.0;
      vtkIdType numberOfInliers = 0;
      double point[3];
      for (vtkIdType i = 0; i < numberOfPoints; ++i)
      {
        if (instanceArray->GetValue(i) == inlierLabel)
        {
          rotatedPoints->GetPoint(i, point);
          heightSum += point[2];
          ++numberOfInliers;
        }
      }

      if (numberOfInliers > 0)
      {
        const double averageHeight = heightSum / static_cast<double>(numberOfInliers);
        vtkSMPTools::For(0,
          numberOfPoints,
          [&](vtkIdType begin, vtkIdType end)
          {
            double p[3];
            for (vtkIdType i = begin; i < end; ++i)
            {
              rotatedPoints->GetPoint(i, p);
              p[2] -= averageHeight;
              rotatedPoints->SetPoint(i, p);
            }
          });
      }
    }
  }

  output->SetPoints(rotatedPoints);
  output->GetPointData()->SetActiveScalars("intensity");

  return 1;
}

//----------------------------------------------------------------------------
void vtkPCLPlaneAlignment::RefineRansac(vtkPolyData* poly, Eigen::VectorXf& plane)
{
  // Refines `plane` (coefficients a, b, c, d of a.x + b.y + c.z + d = 0) with a least-squares
  // fit (PCA) on the points of `poly` lying within DistanceThreshold of it.
  // Side effect: the "instance" point data array of `poly` is relabelled (1 for the inliers of
  // the input plane, 0 for every other point), even when the refinement itself is skipped.
  // `plane` is left unchanged when fewer than 3 inliers are found or the solver fails.
  // The refined normal keeps the orientation of the input normal.
  vtkUnsignedIntArray* instanceArray =
    vtkUnsignedIntArray::SafeDownCast(poly->GetPointData()->GetArray("instance"));
  const vtkIdType numberOfPoints = poly->GetNumberOfPoints();
  const bool canRelabel =
    instanceArray != nullptr && instanceArray->GetNumberOfValues() >= numberOfPoints;
  if (!canRelabel)
  {
    vtkWarningMacro("No valid \"instance\" array, inliers are not relabelled");
  }

  const Eigen::Vector3d initialNormal(plane[0], plane[1], plane[2]);
  const double initialOffset = plane[3];
  // Point-to-plane distance is |n.x + d| / |n| even for a non-unit n; comparing against
  // threshold * |n| avoids the division and rejects every point for a null normal.
  const double scaledThreshold = this->DistanceThreshold * initialNormal.norm();

  // Collect the inliers of the input plane
  std::vector<Eigen::Vector3d> inliers;
  double coords[3];
  for (vtkIdType k = 0; k < numberOfPoints; ++k)
  {
    poly->GetPoint(k, coords);
    const Eigen::Vector3d point(coords[0], coords[1], coords[2]);
    const bool isInlier = std::abs(initialNormal.dot(point) + initialOffset) < scaledThreshold;
    if (isInlier)
    {
      inliers.push_back(point);
    }
    if (canRelabel)
    {
      instanceArray->SetValue(k, isInlier ? 1 : 0);
    }
  }

  if (inliers.size() < 3)
  {
    vtkWarningMacro("Not enough inliers to refine the plane");
    return;
  }

  // Centroid and covariance of the inliers
  const double count = static_cast<double>(inliers.size());
  Eigen::Vector3d center = Eigen::Vector3d::Zero();
  for (const Eigen::Vector3d& inlier : inliers)
  {
    center += inlier;
  }
  center /= count;

  Eigen::Matrix3d covariance = Eigen::Matrix3d::Zero();
  for (const Eigen::Vector3d& inlier : inliers)
  {
    const Eigen::Vector3d centered = inlier - center;
    covariance += centered * centered.transpose();
  }
  covariance /= count;

  // The covariance matrix is real symmetric, hence diagonalizable in an orthonormal basis.
  // Eigenvalues are sorted in increasing order: the first eigenvector is the direction of least
  // variance, i.e. the best-fit plane normal.
  Eigen::SelfAdjointEigenSolver<Eigen::Matrix3d> eigenSolver(covariance);
  if (eigenSolver.info() != Eigen::Success)
  {
    vtkWarningMacro("Plane refinement failed, keeping the RANSAC estimate");
    return;
  }
  Eigen::Vector3d normal = eigenSolver.eigenvectors().col(0);
  // Eigenvector sign is arbitrary: keep the orientation of the input plane
  if (normal.dot(initialNormal) < 0)
  {
    normal = -normal;
  }

  plane[0] = static_cast<float>(normal(0));
  plane[1] = static_cast<float>(normal(1));
  plane[2] = static_cast<float>(normal(2));
  plane[3] = static_cast<float>(-normal.dot(center));
}

//----------------------------------------------------------------------------
void vtkPCLPlaneAlignment::SetTemporalAveraging(bool value)
{
  // Every call (even with an unchanged value) discards the stored plane estimate: a zero vector
  // has no unit normal, so RequestData will not blend the next frame with stale history.
  this->TemporalAveraging = value;
  this->PrevPlaneCoeffs = Eigen::VectorXf::Zero(4);
  this->Modified();
}
