/*=========================================================================

  Program:   Visualization Toolkit
  Module:    vtkAnisotropicLandmarkTransform.cxx

  Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
  All rights reserved.
  See Copyright.txt or http://www.kitware.com/Copyright.htm for details.

  This software is distributed WITHOUT ANY WARRANTY; without even
  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
  PURPOSE.  See the above copyright notice for more information.

  =========================================================================*/
#include "vtkAnisotropicLandmarkTransform.h"

#include "vtkMath.h"
#include "vtkMatrix4x4.h"
#include "vtkObjectFactory.h"
#include "vtkPoints.h"

#include "vtk_eigen.h"
#include VTK_EIGEN(Dense)

// Provides the standard VTK New() factory method for this class
// (MakeTransform() relies on it to create new instances).
vtkStandardNewMacro(vtkAnisotropicLandmarkTransform);

//----------------------------------------------------------------------------
vtkAnisotropicLandmarkTransform::vtkAnisotropicLandmarkTransform()
{
  // Without landmarks the transform is the identity; the convergence
  // threshold of the rotation majorisation starts at 1e-9.
  this->SourceLandmarks = nullptr;
  this->TargetLandmarks = nullptr;
  this->Threshold = 1e-9;
  this->Matrix->Identity();

  // Separate rotation and scaling factors of the fitted transform, both
  // starting as the identity. They are owned by this object and released in
  // the destructor.
  // NOTE(review): Mode (copied in InternalDeepCopy) is not set here; confirm
  // it has a default elsewhere (e.g. in the header).
  this->RotationMatrix = vtkMatrix3x3::New();
  this->RotationMatrix->Identity();
  this->ScalingMatrix = vtkMatrix3x3::New();
  this->ScalingMatrix->Identity();
}

//----------------------------------------------------------------------------
vtkAnisotropicLandmarkTransform::~vtkAnisotropicLandmarkTransform()
{
  // Release the 3x3 matrices allocated in the constructor.
  // NOTE(review): SourceLandmarks/TargetLandmarks are not released here. If
  // the setters declared in the header Register() the points (as
  // vtkSetObjectMacro does), a reference is leaked — confirm.
  if (this->ScalingMatrix != nullptr)
  {
    this->ScalingMatrix->Delete();
  }
  if (this->RotationMatrix != nullptr)
  {
    this->RotationMatrix->Delete();
  }
}

//----------------------------------------------------------------------------
void vtkAnisotropicLandmarkTransform::PrintSelf(ostream& os, vtkIndent indent)
{
  this->Superclass::PrintSelf(os, indent);

  // Every instance variable is printed at the current indent, following the
  // VTK PrintSelf convention; owned sub-objects print one level deeper.
  os << indent << "Threshold: " << this->Threshold << "\n";

  os << indent << "SourceLandmarks: " << this->SourceLandmarks << "\n";
  if (this->SourceLandmarks)
  {
    this->SourceLandmarks->PrintSelf(os, indent.GetNextIndent());
  }
  os << indent << "TargetLandmarks: " << this->TargetLandmarks << "\n";
  if (this->TargetLandmarks)
  {
    this->TargetLandmarks->PrintSelf(os, indent.GetNextIndent());
  }

  os << indent << "RotationMatrix: " << this->RotationMatrix << "\n";
  if (this->RotationMatrix)
  {
    this->RotationMatrix->PrintSelf(os, indent.GetNextIndent());
  }
  os << indent << "ScalingMatrix: " << this->ScalingMatrix << "\n";
  if (this->ScalingMatrix)
  {
    this->ScalingMatrix->PrintSelf(os, indent.GetNextIndent());
  }
}

//----------------------------------------------------------------------------
// Update the 4x4 matrix. Updates are only done as necessary.

void vtkAnisotropicLandmarkTransform::InternalUpdate()
{
  vtkIdType i;
  int j;

  if (this->SourceLandmarks == nullptr || this->TargetLandmarks == nullptr)
  {
    this->Matrix->Identity();
    return;
  }

  // --- compute the necessary transform to match the two sets of landmarks ---

  const vtkIdType N_PTS = this->SourceLandmarks->GetNumberOfPoints();
  if(N_PTS != this->TargetLandmarks->GetNumberOfPoints())
  {
    vtkErrorMacro("Update: Source and Target Landmarks contain a different number of points");
    return;
  }

  // -- if no points, stop here

  if (N_PTS == 0)
  {
    this->Matrix->Identity();
    return;
  }

  /*
    The solution is based on:
    Mohammed Bennani Dosse and Jos Ten Berge (2010),
    "Anisotropic Orthogonal Procrustes Analysis"
    Journal of Classification 27:111-128
  */

  //if there is only one point just find the translation
  if(N_PTS == 1) {
    double p1[3];
    double p2[3];
    this->SourceLandmarks->GetPoint(0, p1);
    this->TargetLandmarks->GetPoint(0, p2);
    this->Matrix->Element[0][3] = p2[0] - p1[0];
    this->Matrix->Element[1][3] = p2[1] - p1[1];
    this->Matrix->Element[2][3] = p2[2] - p1[2];
  }

  //put the points into eigen matrices
  Eigen::MatrixXd X(3, N_PTS);
  Eigen::MatrixXd Y(3, N_PTS);
  double p1[3];
  double p2[3];
  for(int i=0; i < N_PTS; i++)
  {
    this->SourceLandmarks->GetPoint(i, p1);
    this->TargetLandmarks->GetPoint(i, p2);
    for(int j=0; j < 3; j++)
    {
      X(j, i) = p1[j];
      Y(j, i) = p2[j];
    }
  }

  //find centroids
  Eigen::Vector3d X_centroid = X.rowwise().mean();
  Eigen::Vector3d Y_centroid = Y.rowwise().mean();

  //translate input by centroids
  Eigen::MatrixXd X_trans = X.colwise() - X_centroid;
  Eigen::MatrixXd Y_trans = Y.colwise() - Y_centroid;

  //normalise translated source input
  Eigen::MatrixXd X_norm = X_trans.rowwise().normalized();

  //Find the cross covariance matrix
  Eigen::Matrix3d B = Y_trans * X_norm.transpose();

  //use SVD to decompose B such that B=USV^T
  Eigen::JacobiSVD<Eigen::Matrix3d> svd(B, Eigen::ComputeFullU | Eigen::ComputeFullV);
  Eigen::Matrix3d U = svd.matrixU();
  Eigen::Matrix3d V = svd.matrixV().transpose();

  //create D where the last element is the determinant of UV
  Eigen::Matrix3d D(3, 3);
  D = Eigen::Matrix3d::Identity();
  D(2, 2) = (U*V).determinant();

  //Find the rotation
  Eigen::Matrix3d Q(3, 3);
  Q = U * D * V;

  //Calculate the FRE for the given rotation
  Eigen::MatrixXd FRE_vect(3, N_PTS);
  FRE_vect = Y_trans - Q.transpose() * X_norm;
  double FRE = sqrt(FRE_vect.squaredNorm()/N_PTS);

  //starting FRE value to compare to
  double FRE_orig = 2.0 * (FRE + this->Threshold);

  //majorisation to solve for rotation
  while(fabs(FRE_orig - FRE) > this->Threshold) {
    //recompute SVD values
    Eigen::Matrix3d QB = Q.transpose() * B;
    Eigen::Matrix3d I = Eigen::Matrix3d::Zero();
    I(0,0) = QB(0,0);
    I(1,1) = QB(1,1);
    I(2,2) = QB(2,2);

    svd.compute(B * I, Eigen::ComputeFullU | Eigen::ComputeFullV);
    U = svd.matrixU();
    V = svd.matrixV().transpose();
    //recompute rotation
    D(2, 2) = (U*V).determinant();
    Q = U * D * V;

    //recompute FRE value
    FRE_vect = Y_trans - Q.transpose() * X_norm;
    FRE_orig = FRE;
    FRE = 0.0f;
    FRE = sqrt(FRE_vect.squaredNorm()/N_PTS);
  }

  //calculate final scaling
  B = Y_trans * X_trans.transpose();
  U = B.transpose() * Q;
  V = X_trans * X_trans.transpose();
  Eigen::Matrix3d A(3, 3);
  A = Eigen::Matrix3d::Zero();
  A(0, 0) = U(0, 0)/V(0, 0);
  A(1, 1) = U(1, 1)/V(1, 1);
  A(2, 2) = U(2, 2)/V(2, 2);

  //calculate final translation
  Eigen::Vector3d t(3);
  t =  Y_centroid - (Q * (A * X_centroid));

  //calculate final FRE values
  FRE_vect = Y - ((Q * (A * X)).colwise() + t);
  FRE = sqrt(FRE_vect.squaredNorm()/N_PTS);

  //set scaling and rotation matrices sperate to regular
  //transformation matrix output
  this->RotationMatrix->Identity();
  this->ScalingMatrix->Identity();
  for(int i=0; i < 3; i ++)
  {
    this->RotationMatrix->SetElement(i, 0, Q(i, 0));
    this->RotationMatrix->SetElement(i, 1, Q(i, 1));
    this->RotationMatrix->SetElement(i, 2 ,Q(i, 2));
  }
  this->ScalingMatrix->SetElement(0, 0, A(0, 0));
  this->ScalingMatrix->SetElement(1, 1, A(1, 1));
  this->ScalingMatrix->SetElement(2, 2, A(2, 2));

  Q = Q * A;

  //fill rotation into the matrix
  for(int i=0; i < 3; i ++)
  {
    this->Matrix->Element[i][0] = Q(i, 0);
    this->Matrix->Element[i][1] = Q(i, 1);
    this->Matrix->Element[i][2] = Q(i, 2);
  }

  //fill translation into the matrix
  this->Matrix->Element[0][3] = t(0);
  this->Matrix->Element[1][3] = t(1);
  this->Matrix->Element[2][3] = t(2);

  // fill the bottom row of the 4x4 matrix
  this->Matrix->Element[3][0] = 0.0;
  this->Matrix->Element[3][1] = 0.0;
  this->Matrix->Element[3][2] = 0.0;
  this->Matrix->Element[3][3] = 1.0;

  this->Matrix->Modified();
}

//------------------------------------------------------------------------
vtkMTimeType vtkAnisotropicLandmarkTransform::GetMTime()
{
  // The transform is out of date whenever either landmark set changes, so the
  // reported time is the latest of this object's and both point sets'.
  vtkMTimeType latest = this->vtkLinearTransform::GetMTime();

  vtkPoints* const landmarkSets[2] = { this->SourceLandmarks, this->TargetLandmarks };
  for (vtkPoints* landmarks : landmarkSets)
  {
    if (landmarks == nullptr)
    {
      continue;
    }
    const vtkMTimeType candidate = landmarks->GetMTime();
    latest = (candidate > latest) ? candidate : latest;
  }
  return latest;
}

//----------------------------------------------------------------------------
void vtkAnisotropicLandmarkTransform::SetThreshold(double threshold)
{
  // The threshold controls convergence of the rotation solve, so changing it
  // must mark the transform modified; otherwise the next update would not
  // recompute the matrix. Mirrors vtkSetMacro: only act on an actual change.
  if (this->Threshold != threshold)
  {
    this->Threshold = threshold;
    this->Modified();
  }
}

//----------------------------------------------------------------------------
void vtkAnisotropicLandmarkTransform::Inverse()
{
  vtkPoints *tmp1 = this->SourceLandmarks;
  vtkPoints *tmp2 = this->TargetLandmarks;
  this->TargetLandmarks = tmp1;
  this->SourceLandmarks = tmp2;
  this->Modified();
}

//----------------------------------------------------------------------------
vtkAbstractTransform *vtkAnisotropicLandmarkTransform::MakeTransform()
{
  // Hand back a new instance of this concrete class, obtained from New().
  vtkAnisotropicLandmarkTransform* freshTransform = vtkAnisotropicLandmarkTransform::New();
  return freshTransform;
}

//----------------------------------------------------------------------------
void vtkAnisotropicLandmarkTransform::InternalDeepCopy(vtkAbstractTransform *transform)
{
  // NOTE(review): assumes the caller passes a vtkAnisotropicLandmarkTransform
  // (presumably guaranteed by the DeepCopy machinery) — confirm.
  vtkAnisotropicLandmarkTransform* t = static_cast<vtkAnisotropicLandmarkTransform*>(transform);

  this->SetMode(t->Mode);
  this->SetSourceLandmarks(t->SourceLandmarks);
  this->SetTargetLandmarks(t->TargetLandmarks);
  // The convergence threshold influences the computed matrix, so it must be
  // copied too. Modified() below covers the change.
  this->Threshold = t->Threshold;

  this->Modified();
}
