/*=========================================================================

  Program:   PCL Plugin
  Module:    vtkPCLWallFitting.cxx

  Copyright (c) Kitware, Inc.
  All rights reserved.
  See LICENSE or http://www.apache.org/licenses/LICENSE-2.0 for details.

  This software is distributed WITHOUT ANY WARRANTY; without even
  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
  PURPOSE.  See the above copyright notice for more information.

=========================================================================*/

// Local includes
#include "vtkPCLWallFitting.h"
#include "vtkPCLFittingModel.h"

// VTK includes
#include <vtkInformation.h>
#include <vtkInformationVector.h>
#include <vtkMath.h>
#include <vtkPointData.h>
#include <vtkPolyData.h>
#include <vtkSmartPointer.h>
#include <vtkUnsignedIntArray.h>

// STD includes
#include <algorithm>
#include <cmath>
#include <functional>

//-----------------------------------------------------------------------------
// Defines vtkPCLWallFitting::New() through VTK's standard object-factory macro.
vtkStandardNewMacro(vtkPCLWallFitting);

//-----------------------------------------------------------------------------
void vtkPCLWallFitting::PrintSelf(ostream& os, vtkIndent indent)
{
  // Print the parent state first, then the wall-fitting parameters that drive
  // RequestData (VTK convention: every class prints its own parameters).
  this->Superclass::PrintSelf(os, indent);
  os << indent << "NormalAxis: (" << this->NormalAxis[0] << ", " << this->NormalAxis[1] << ", "
     << this->NormalAxis[2] << ")\n";
  os << indent << "ThresholdAngle: " << this->ThresholdAngle << "\n";
  os << indent << "NumberOfFloorCeillingPlanes: " << this->NumberOfFloorCeillingPlanes << "\n";
  os << indent << "NumberOfWallPlanes: " << this->NumberOfWallPlanes << "\n";
  os << indent << "ConstraintsBetweenWalls: " << this->ConstraintsBetweenWalls << "\n";
}

//-----------------------------------------------------------------------------
int vtkPCLWallFitting::RequestData(vtkInformation* vtkNotUsed(request),
  vtkInformationVector** inputVector,
  vtkInformationVector* outputVector)
{
  // Horizontal and vertical planes are both defined relative to NormalAxis,
  // so a null axis makes every constraint below meaningless.
  if (this->NormalAxis[0] == 0 && this->NormalAxis[1] == 0 && this->NormalAxis[2] == 0)
  {
    vtkErrorMacro("NormalAxis is not set");
    return 0;
  }

  // Get the input and output; the output starts as a shallow copy of the input
  vtkPolyData* input = vtkPolyData::GetData(inputVector[0]->GetInformationObject(0));
  vtkPolyData* output = vtkPolyData::GetData(outputVector->GetInformationObject(0));
  if (!input || !output)
  {
    vtkErrorMacro("Input or output is not a vtkPolyData");
    return 0;
  }
  output->ShallowCopy(input);

  // Reset all per-execution state. FirstWall must be re-armed too: otherwise a
  // re-execution would keep the NormalWallAxis captured during a previous run.
  this->OutlierIds.clear();
  this->InstanceCount = 0;
  this->FirstWall = true;

  // Per-point label: 0 = not assigned to any plane, otherwise the 1-based index
  // of the fitted plane (floor/ceiling planes first, then walls).
  const vtkIdType numberOfPoints = output->GetNumberOfPoints();
  vtkSmartPointer<vtkUnsignedIntArray> objectArray = vtkSmartPointer<vtkUnsignedIntArray>::New();
  objectArray->SetName("wallIndex");
  objectArray->SetNumberOfTuples(numberOfPoints);
  objectArray->Fill(0);
  output->GetPointData()->AddArray(objectArray);

  using ConstraintFunction = std::function<bool(const Eigen::VectorXf&)>;
  const ConstraintFunction horizontalConstraints = [this](const Eigen::VectorXf& model)
  { return this->HorizontalConstraints(model); };
  const ConstraintFunction verticalConstraints = [this](const Eigen::VectorXf& model)
  { return this->VerticalConstraints(model); };
  const ConstraintFunction wallConstraints = [this](const Eigen::VectorXf& model)
  { return this->WallConstraints(model); };

  // Fit floor and ceiling planes
  this->FittingPlanesWithConstraints(this->NumberOfFloorCeillingPlanes,
    horizontalConstraints,
    output,
    objectArray,
    PlaneType::Floor);

  // Fit wall planes
  if (this->ConstraintsBetweenWalls && this->NumberOfWallPlanes > 1)
  {
    // The first wall only has to be vertical; its normal then becomes the
    // reference (NormalWallAxis) for the remaining walls.
    this->FittingPlanesWithConstraints(
      1, verticalConstraints, output, objectArray, PlaneType::Wall);
    this->FittingPlanesWithConstraints(
      this->NumberOfWallPlanes - 1, wallConstraints, output, objectArray, PlaneType::Wall);
  }
  else
  {
    // Fit all the walls without constraints between them
    this->FittingPlanesWithConstraints(
      this->NumberOfWallPlanes, verticalConstraints, output, objectArray, PlaneType::Wall);
  }

  output->GetPointData()->SetActiveScalars("wallIndex");

  return 1;
}

//----------------------------------------------------------------------------
void vtkPCLWallFitting::FittingPlanesWithConstraints(unsigned int instanceCount,
  std::function<bool(const Eigen::VectorXf&)> constraintsFunction,
  vtkPolyData* poly,
  vtkUnsignedIntArray* objectArray,
  PlaneType planeType)
{
  // Fit up to `instanceCount` planes one after another. Each new plane is fitted
  // only on the points left unassigned by the previous ones (OutlierIds), and its
  // inliers are labelled in `objectArray`.
  for (unsigned int i = 0; i < instanceCount; i++)
  {
    // Every point is already labelled: there is nothing left to fit.
    if (this->InstanceCount > 0 && this->OutlierIds.empty())
    {
      break;
    }

    vtkSmartPointer<vtkPCLFittingModel> fittingModel = this->InitFittingModel(poly);
    fittingModel->SetShowMeshes(false);
    fittingModel->SetModelConstraintsFunction(constraintsFunction);
    if (this->InstanceCount > 0)
    {
      // Restrict the fitting to the outliers of the previous fitting
      fittingModel->SetIndices(this->OutlierIds);
    }
    fittingModel->Update();

    // The per-fit "instance" array flags the inliers of the new plane; without it
    // nothing can be merged, so stop instead of dereferencing a null pointer.
    vtkPolyData* fittedModelOutput = vtkPolyData::SafeDownCast(fittingModel->GetOutput(0));
    vtkUnsignedIntArray* newInstanceArray = fittedModelOutput
      ? vtkUnsignedIntArray::SafeDownCast(fittedModelOutput->GetPointData()->GetArray("instance"))
      : nullptr;
    if (!newInstanceArray)
    {
      vtkWarningMacro("Plane fitting did not produce an \"instance\" array; "
                      "skipping the remaining planes.");
      break;
    }
    poly->ShallowCopy(fittedModelOutput);

    // Get the first wall normal axis to use it as constraint for the next walls.
    // The fit may have found no model, so check before indexing the coefficients.
    if (this->ConstraintsBetweenWalls && planeType == PlaneType::Wall && this->FirstWall)
    {
      const auto& coefficients = fittingModel->GetModelCoefficients();
      if (!coefficients.empty())
      {
        const Eigen::VectorXf model = coefficients.at(0);
        if (model.size() >= 3)
        {
          this->FirstWall = false;
          this->NormalWallAxis = Eigen::Vector3d(model[0], model[1], model[2]);
        }
      }
      if (this->FirstWall)
      {
        vtkWarningMacro("No plane coefficients were found for the first wall; "
                        "constraints between walls are not initialized.");
      }
    }

    this->MergeInstanceDataArrays(objectArray, newInstanceArray);
    this->InstanceCount++;
  }
}

//----------------------------------------------------------------------------
vtkSmartPointer<vtkPCLFittingModel> vtkPCLWallFitting::InitFittingModel(vtkPolyData* poly)
{
  // Build a single-instance fitter of type NormalPlane on `poly`, forwarding this
  // filter's consensus settings. The caller is responsible for calling Update().
  auto planeFitter = vtkSmartPointer<vtkPCLFittingModel>::New();

  // Consensus settings copied from this filter
  planeFitter->SetMaxIterations(this->MaxIterations);
  planeFitter->SetProbability(this->Probability);

  // Model description: one plane per fitter
  planeFitter->SetModelType(vtkPCLFittingModel::NormalPlane);
  planeFitter->SetSampleConsensusType(this->SampleConsensusType);
  planeFitter->SetDistanceThreshold(this->DistanceThreshold);
  planeFitter->SetInstanceCount(1);
  planeFitter->SetProjectInliers(this->ProjectInliers);

  planeFitter->SetInputData(poly);
  return planeFitter;
}

//-----------------------------------------------------------------------------
double vtkPCLWallFitting::ModelVectorAngle(const Eigen::VectorXf& model, Eigen::Vector3d vector)
{
  // Return the angle (radians, in [0, pi]) between `vector` and the plane normal
  // stored in the first three coefficients of `model`.
  // Both directions must be unit length for their dot product to be a cosine.
  // `vector` is taken by value, so normalizing it does not affect the caller.
  vector.normalize();
  Eigen::Vector3d planeNormal(model[0], model[1], model[2]);
  planeNormal.normalize();

  // Rounding can push the dot product slightly outside [-1, 1], where acos
  // returns NaN; clamp to the valid domain.
  const double cosAngle = std::max(-1.0, std::min(1.0, vector.dot(planeNormal)));
  return std::acos(cosAngle);
}

//-----------------------------------------------------------------------------
bool vtkPCLWallFitting::AngularConstraints(const Eigen::VectorXf& model,
  Eigen::Vector3d vector,
  double constraintAngle)
{
  // True when the angle between `vector` and the plane normal of `model` is within
  // ThresholdAngle (degrees) of `constraintAngle` (radians).
  const double thresholdAngle = vtkMath::RadiansFromDegrees(this->ThresholdAngle);
  const double angle = this->ModelVectorAngle(model, vector);

  // A plane normal has no preferred orientation: n and -n describe the same
  // plane, so the supplementary angle is an equally valid measurement.
  const double flippedAngle = vtkMath::Pi() - angle;

  // std::abs (not the unqualified C abs) so the double overload is always used.
  return std::abs(angle - constraintAngle) < thresholdAngle ||
    std::abs(flippedAngle - constraintAngle) < thresholdAngle;
}

//-----------------------------------------------------------------------------
bool vtkPCLWallFitting::HorizontalConstraints(const Eigen::VectorXf& model)
{
  // A horizontal plane has its normal aligned with NormalAxis: the target angle
  // between the two is 0 rad, with the tolerance applied by AngularConstraints.
  const double alignedAngle = 0.0;
  return this->AngularConstraints(model,
    Eigen::Vector3d(this->NormalAxis[0], this->NormalAxis[1], this->NormalAxis[2]),
    alignedAngle);
}

//-----------------------------------------------------------------------------
bool vtkPCLWallFitting::VerticalConstraints(const Eigen::VectorXf& model)
{
  // A vertical plane has its normal orthogonal to NormalAxis: the target angle
  // between the two is pi/2 rad, with the tolerance applied by AngularConstraints.
  const double orthogonalAngle = vtkMath::Pi() / 2.0;
  return this->AngularConstraints(model,
    Eigen::Vector3d(this->NormalAxis[0], this->NormalAxis[1], this->NormalAxis[2]),
    orthogonalAngle);
}

//----------------------------------------------------------------------------
bool vtkPCLWallFitting::WallConstraints(const Eigen::VectorXf& model)
{
  // A wall candidate is accepted when it is vertical and either parallel or
  // perpendicular (within ThresholdAngle) to the first fitted wall.

  // Test if the wall is orthogonal to the floor or ceiling
  if (!this->VerticalConstraints(model))
  {
    return false;
  }

  const double thresholdAngle = vtkMath::RadiansFromDegrees(this->ThresholdAngle);
  const double halfPi = vtkMath::Pi() / 2.0;

  // Compare against the first wall's normal (NormalWallAxis), not against the
  // candidate's own normal, which would always yield an angle of 0.
  const double angle = this->ModelVectorAngle(model, this->NormalWallAxis);

  // Distance from `angle` to the nearest multiple of pi/2 (0, pi/2 or pi), so
  // angles slightly below a multiple are accepted as well as slightly above.
  const double remainder = std::fmod(std::abs(angle), halfPi);
  const double deviation = std::min(remainder, halfPi - remainder);
  return deviation < thresholdAngle;
}

//----------------------------------------------------------------------------
void vtkPCLWallFitting::MergeInstanceDataArrays(vtkUnsignedIntArray* instanceArray,
  vtkUnsignedIntArray* newInstanceArray)
{
  // Label every point flagged (non-zero) in `newInstanceArray` with
  // 1 + InstanceCount in `instanceArray`, then rebuild OutlierIds as the indices
  // still labelled 0, i.e. the points available for the next fit.
  this->OutlierIds.clear();
  if (!instanceArray)
  {
    return;
  }

  const vtkIdType numberOfPoints = instanceArray->GetNumberOfTuples();
  // Only read `newInstanceArray` where it has data: a missing or shorter array
  // must not be read out of bounds.
  const vtkIdType mergeCount =
    newInstanceArray ? std::min(numberOfPoints, newInstanceArray->GetNumberOfTuples()) : 0;

  for (vtkIdType k = 0; k < numberOfPoints; ++k)
  {
    if (k < mergeCount && newInstanceArray->GetValue(k) != 0)
    {
      instanceArray->SetValue(k, 1 + this->InstanceCount);
    }
    if (instanceArray->GetValue(k) == 0)
    {
      this->OutlierIds.push_back(k);
    }
  }
}
