#include "vtkContinuousPointRegistration.h"

#include "vtkCellData.h"
#include "vtkDoubleArray.h"
#include "vtkFieldData.h"
#include "vtkFloatArray.h"
#include "vtkInformation.h"
#include "vtkInformationVector.h"
#include "vtkLandmarkTransform.h"
#include "vtkLinearTransform.h"
#include "vtkMatrix4x4.h"
#include "vtkObjectFactory.h"
#include "vtkPointData.h"
#include "vtkPolyData.h"
#include "vtkRobustIterativeClosestPointTransform.h"
#include "vtkSMPThreadLocalObject.h"
#include "vtkSMPTools.h"
#include "vtkSmartPointer.h"
#include "vtkStreamingDemandDrivenPipeline.h"
#include "vtkTransform.h"
#include "vtkTransformPolyDataFilter.h"

#include <sstream>
#include <string.h>

// Standard VTK object-factory boilerplate for this class (provides the static New()).
vtkStandardNewMacro(vtkContinuousPointRegistration);

// Generates SetRegistrationAlgorithm(vtkLinearTransform*), the setter for the
// RegistrationAlgorithm member (used by the destructor to drop the reference).
vtkCxxSetObjectMacro(vtkContinuousPointRegistration, RegistrationAlgorithm, vtkLinearTransform);

//------------------------------------------------------------------------------
void vtkContinuousPointRegistration::CreateDefaultRegistration()
{
  if (this->RegistrationAlgorithm)
  {
    this->RegistrationAlgorithm->Delete();
    this->RegistrationAlgorithm = nullptr;
  }

  // Default registration algorithm
  auto registrationAlgorithm = vtkRobustIterativeClosestPointTransform::New();
  registrationAlgorithm->GetLandmarkTransform()->SetModeToRigidBody();
  registrationAlgorithm->SetMaximumNumberOfIterations(30);
  registrationAlgorithm->StartByMatchingCentroidsOff();
  registrationAlgorithm->SetThresholdParameter(0.5);
  registrationAlgorithm->SetMaximumMeanDistance(0.001);
  registrationAlgorithm->SetCheckMeanDistance(true);
  registrationAlgorithm->SetMeanDistanceModeToRMS();
  registrationAlgorithm->SetMaximumNumberOfLandmarks(500);
  this->RegistrationAlgorithm = registrationAlgorithm;
}

//------------------------------------------------------------------------------
vtkContinuousPointRegistration::vtkContinuousPointRegistration()
{
  // One input port, one output port.
  this->SetNumberOfInputPorts(1);
  this->SetNumberOfOutputPorts(1);

  // Behaviour flags
  this->DeepCopyInput = false;
  this->TemporalRegistration = true;
  this->ExplicitTransform = false;

  // Time bookkeeping; the step buffer is filled in RequestUpdateTime.
  this->TimeStep = -1.0; // TODO: Can we use this for repeated registrations
  this->TimeStepIndex = 0;
  this->NTimeSteps = 0;
  this->TimeSteps = nullptr;

  // No previous frame to register against yet.
  this->LastTarget = nullptr;

  // Accumulated transform, starting out as identity.
  this->UserTransform = vtkTransform::New();
  this->UserTransform->Identity();

  // The member must be null before CreateDefaultRegistration() inspects it.
  this->RegistrationAlgorithm = nullptr;
  this->CreateDefaultRegistration();
}

//------------------------------------------------------------------------------
vtkContinuousPointRegistration::~vtkContinuousPointRegistration()
{
  // Drop our reference to the registration algorithm.
  this->SetRegistrationAlgorithm(nullptr);

  // Drop our reference to the previous frame, if any.
  this->ReleaseLastTarget();

  // TimeSteps is allocated with new double[] in RequestUpdateTime, so it must
  // be released with the array form of delete (delete[] on nullptr is a no-op).
  delete[] this->TimeSteps;
  this->TimeSteps = nullptr;

  if (this->UserTransform)
  {
    this->UserTransform->Delete();
    this->UserTransform = nullptr;
  }
}

//------------------------------------------------------------------------------
int vtkContinuousPointRegistration::RequestUpdateTime(vtkInformation* vtkNotUsed(request),
  vtkInformationVector** inputVector, vtkInformationVector* outputVector)
{
  // Caches the upstream time steps in this->TimeSteps / this->NTimeSteps and
  // advertises the same time steps and time range on the output information.
  // Always returns 1.
  vtkDebugMacro(<< __PRETTY_FUNCTION__);

  // We acquire upstream time steps here.
  vtkInformation* inInfo = inputVector[0]->GetInformationObject(0);
  vtkInformation* outInfo = outputVector->GetInformationObject(0);

  if (!inInfo->Has(vtkStreamingDemandDrivenPipeline::TIME_STEPS()))
  {
    // Nothing temporal upstream; keep whatever we cached earlier.
    return 1;
  }

  const int upstreamNTimeSteps = inInfo->Length(vtkStreamingDemandDrivenPipeline::TIME_STEPS());
  double* upstreamTimeSteps = inInfo->Get(vtkStreamingDemandDrivenPipeline::TIME_STEPS());
  vtkDebugMacro(<< "Upstream time steps: " << upstreamNTimeSteps);

  // Replace our internal buffer with time steps. We could keep data in a circular buffer.
  // The buffer is an array allocation, so it has to be released with delete[].
  delete[] this->TimeSteps;
  this->TimeSteps = nullptr;
  this->NTimeSteps = 0;

  if (upstreamNTimeSteps <= 0 || upstreamTimeSteps == nullptr)
  {
    // An empty step list gives us nothing to cache or to advertise downstream.
    return 1;
  }

  this->NTimeSteps = upstreamNTimeSteps;
  this->TimeSteps = new double[this->NTimeSteps];
  memcpy(this->TimeSteps, upstreamTimeSteps, this->NTimeSteps * sizeof(double));

  // Set downstream time range. Start from the range spanned by the steps and
  // prefer the upstream TIME_RANGE when it is actually present.
  double timeRange[2] = { this->TimeSteps[0], this->TimeSteps[this->NTimeSteps - 1] };
  if (inInfo->Has(vtkStreamingDemandDrivenPipeline::TIME_RANGE()))
  {
    double* upstreamTimeRange = inInfo->Get(vtkStreamingDemandDrivenPipeline::TIME_RANGE());
    if (upstreamTimeRange != nullptr)
    {
      // Sanity check: the advertised range is expected to match the step list.
      assert(this->TimeSteps[0] == upstreamTimeRange[0]);
      assert(this->TimeSteps[this->NTimeSteps - 1] == upstreamTimeRange[1]);
      timeRange[0] = upstreamTimeRange[0];
      timeRange[1] = upstreamTimeRange[1];
    }
  }
  outInfo->Set(vtkStreamingDemandDrivenPipeline::TIME_RANGE(), timeRange, 2);

  // Set downstream time-steps
  outInfo->Set(
    vtkStreamingDemandDrivenPipeline::TIME_STEPS(), upstreamTimeSteps, upstreamNTimeSteps);

  return 1;
}

//------------------------------------------------------------------------------
int vtkContinuousPointRegistration::RequestUpdateTimeDependentInformation(
  vtkInformation* vtkNotUsed(request), vtkInformationVector** vtkNotUsed(inputVector),
  vtkInformationVector* vtkNotUsed(outputVector))
{
  // No time-dependent information is produced here; the pass is only traced
  // in debug mode. The unused arguments are wrapped in vtkNotUsed, as is done
  // for `request` elsewhere in this file, to avoid unused-parameter warnings.
  vtkDebugMacro(<< __PRETTY_FUNCTION__);
  // Called after ::RequestUpdateTime
  return 1;
}

//------------------------------------------------------------------------------
int vtkContinuousPointRegistration::RequestInformation(
  vtkInformation* request, vtkInformationVector** inputVector, vtkInformationVector* outputVector)
{
  // Trace the pass in debug mode, then let the superclass do the actual work.
  vtkDebugMacro(<< __PRETTY_FUNCTION__);

  const int superclassStatus =
    this->Superclass::RequestInformation(request, inputVector, outputVector);
  return superclassStatus;
}

//------------------------------------------------------------------------------
int vtkContinuousPointRegistration::RequestUpdateExtent(
  vtkInformation* request, vtkInformationVector** inputVector, vtkInformationVector* outputVector)
{
  // Translates the downstream time request into an upstream one. When a time
  // step is requested and an earlier step exists in the cached step list, a
  // shallow copy of the poly data currently on the input is stored as the
  // last target. The result of the superclass implementation is returned.
  vtkDebugMacro(<< __PRETTY_FUNCTION__);

  vtkInformation* inInfo = inputVector[0]->GetInformationObject(0);
  vtkInformation* outInfo = outputVector->GetInformationObject(0);

  if (outInfo->Has(vtkStreamingDemandDrivenPipeline::UPDATE_TIME_STEP()))
  {
    const double nextTimeStep = outInfo->Get(vtkStreamingDemandDrivenPipeline::UPDATE_TIME_STEP());

    vtkDebugMacro(<< "Time is requested: " << nextTimeStep);

    // Walk to the first cached step that is not before the requested time
    // (capped at the last index), then step back one to get its predecessor.
    // With an empty step list the loop does not run and iPrevious ends at -1.
    int iPrevious = 0;
    while (iPrevious < (this->NTimeSteps - 1) && this->TimeSteps[iPrevious] < nextTimeStep)
    {
      ++iPrevious;
    }
    --iPrevious;

    if (iPrevious >= 0)
    {
      // TODO: Verify last target exists
      const double previousTimeStep = this->TimeSteps[iPrevious];
      inInfo->Set(vtkStreamingDemandDrivenPipeline::UPDATE_TIME_STEP(), previousTimeStep);

      // The input may hold no data object yet, or one that is not poly data;
      // in that case there is nothing to copy and the last target is left alone.
      vtkPolyData* currentInput =
        vtkPolyData::SafeDownCast(inInfo->Get(vtkDataObject::DATA_OBJECT()));
      if (currentInput)
      {
        vtkPolyData* lastTarget = currentInput->NewInstance();
        lastTarget->ShallowCopy(currentInput);
        this->SetLastTarget(lastTarget);
        // SetLastTarget() registers its own reference; release the one from
        // NewInstance() so the copy is freed when the target is replaced.
        lastTarget->Delete();
      }
      else
      {
        vtkDebugMacro(<< "No poly data on the input; last target left unchanged");
      }
      vtkDebugMacro(<< "Previous time: " << previousTimeStep);
    }
    inInfo->Set(vtkStreamingDemandDrivenPipeline::UPDATE_TIME_STEP(), nextTimeStep);
  }
  else
  {
    vtkDebugMacro(<< "Fallback if no time is requested");
    vtkDataObject* dataObject = inInfo->Get(vtkDataObject::DATA_OBJECT());
    if (dataObject)
    {
      // If the data object carries a time stamp, request that time upstream.
      vtkInformation* inDataInformation = dataObject->GetInformation();
      if (inDataInformation && inDataInformation->Has(vtkDataObject::DATA_TIME_STEP()))
      {
        const double timeStep = inDataInformation->Get(vtkDataObject::DATA_TIME_STEP());
        inInfo->Set(vtkStreamingDemandDrivenPipeline::UPDATE_TIME_STEP(), timeStep);
      }
    }
  }

  return this->Superclass::RequestUpdateExtent(request, inputVector, outputVector);
}

//------------------------------------------------------------------------------
int vtkContinuousPointRegistration::RequestData(vtkInformation* vtkNotUsed(request),
  vtkInformationVector** inputVector, vtkInformationVector* outputVector)
{
  // Registers the current input against the last target (when one is set),
  // accumulates the result into UserTransform and fills the output.
  //
  // The output always receives a 4x4 field-data array named "RegTransform":
  //  - identity when there is no last target, or when ExplicitTransform is on
  //    (the geometry itself has then been transformed),
  //  - the accumulated UserTransform matrix otherwise.
  //
  // Returns 1 on success, 0 when the input or output data set is missing or
  // the registration algorithm is of an unsupported type.
  vtkDebugMacro(<< __PRETTY_FUNCTION__);

  vtkInformation* inInfo = inputVector[0]->GetInformationObject(0);
  vtkInformation* outInfo = outputVector->GetInformationObject(0);

  vtkDataSet* input = vtkDataSet::SafeDownCast(inInfo->Get(vtkDataObject::DATA_OBJECT()));
  vtkDataSet* output = vtkDataSet::SafeDownCast(outInfo->Get(vtkDataObject::DATA_OBJECT()));

  if (!input || !output)
  {
    return 0;
  }

  // Point and cell data must be propagated - always
  vtkPointData* pd = input->GetPointData();
  vtkPointData* outPD = output->GetPointData();
  vtkCellData* cd = input->GetCellData();
  vtkCellData* outCD = output->GetCellData();

  // Field data with user transform: 4 tuples of 4 components = 16 values.
  // The component count is set before sizing the array, which is the order
  // VTK expects.
  vtkNew<vtkDoubleArray> affineArray;
  affineArray->SetName("RegTransform");
  affineArray->SetNumberOfComponents(4);
  affineArray->SetNumberOfTuples(4);

  vtkDataSet* tmpOutput = nullptr;

  // Declared at function scope because tmpOutput may point at its output,
  // which has to stay alive until the geometry has been handed to `output`.
  vtkNew<vtkTransformPolyDataFilter> correctTransform;

  // TODO: Check time of dataobject
  //       Only marks last target valid if registered
  if (!this->LastTarget)
  {
    // Output = input
    tmpOutput = input;

    // TODO: Could we simply write this->UserTransform???
    vtkNew<vtkMatrix4x4> userTransform;
    userTransform->Identity();
    memcpy(affineArray->GetPointer(0), userTransform->GetData(), 16 * sizeof(double));
  }
  else
  {
    vtkDebugMacro(<< "Executing registration");

    // Perform ICP

    // HACK since VTK does not have vtkLinearDataSetTransform
    // SetRegistrationAlgorithm() accepts any vtkLinearTransform, so the cast
    // can fail; source and target can only be set on the ICP type.
    vtkRobustIterativeClosestPointTransform* algorithm =
      vtkRobustIterativeClosestPointTransform::SafeDownCast(this->RegistrationAlgorithm);
    if (!algorithm)
    {
      vtkErrorMacro(<< "RegistrationAlgorithm must be a vtkRobustIterativeClosestPointTransform");
      return 0;
    }
    algorithm->SetTarget(this->LastTarget);
    algorithm->SetSource(input);
    algorithm->Modified();
    algorithm->Update();

    // The absolute registration since first frame
    this->UserTransform->PostMultiply();
    this->UserTransform->Concatenate(algorithm);
    this->UserTransform->PreMultiply();

    vtkDebugMacro(<< "Done registration");

    // Re-set the transform from its own current matrix.
    // NOTE(review): presumably this bakes the concatenation into a plain matrix
    // so the transform no longer refers to `algorithm` - confirm.
    this->UserTransform->SetMatrix(this->UserTransform->GetMatrix());

    if (this->ExplicitTransform)
    {
      // Apply the accumulated transform to the geometry; the stored matrix is
      // then identity because nothing is left to apply downstream.
      correctTransform->SetInputData(input);
      correctTransform->SetTransform(this->UserTransform);
      correctTransform->Update();
      tmpOutput = correctTransform->GetOutput();
      vtkNew<vtkTransform> identity;
      identity->Identity();
      memcpy(affineArray->GetPointer(0), identity->GetMatrix()->GetData(), 16 * sizeof(double));
    }
    else
    {
      // Leave the geometry untouched and publish the accumulated matrix instead.
      tmpOutput = input;
      memcpy(affineArray->GetPointer(0), this->UserTransform->GetMatrix()->GetData(),
        16 * sizeof(double));
    }
    vtkDebugMacro(<< "Copied data for LastTarget");
  }

  // Update data for output. Only poly data geometry/topology is forwarded.
  vtkPolyData* polyDataOutput = vtkPolyData::SafeDownCast(output);
  vtkPolyData* polyTmpOutput = vtkPolyData::SafeDownCast(tmpOutput);
  if (polyDataOutput && polyTmpOutput)
  {
    polyDataOutput->SetPoints(polyTmpOutput->GetPoints());
    polyDataOutput->SetVerts(polyTmpOutput->GetVerts());
    polyDataOutput->SetLines(polyTmpOutput->GetLines());
    polyDataOutput->SetPolys(polyTmpOutput->GetPolys());
    polyDataOutput->SetStrips(polyTmpOutput->GetStrips());
  }

  // Pass point and cell data.
  // NOTE(review): the attributes always come from the untransformed input, even
  // when ExplicitTransform is on - confirm this is intended for normals/vectors.
  outPD->PassData(pd);
  outCD->PassData(cd);

  this->TimeStepIndex++;

  // Update field data
  vtkFieldData* fieldData = output->GetFieldData();
  affineArray->Modified();
  fieldData->AddArray(affineArray);

  return 1;
}

//------------------------------------------------------------------------------
void vtkContinuousPointRegistration::PrintSelf(ostream& os, vtkIndent indent)
{
  // Superclass state first, then the DeepCopyInput flag as "on"/"off".
  this->Superclass::PrintSelf(os, indent);

  const char* deepCopyState = this->DeepCopyInput ? "on" : "off";
  os << indent << "DeepCopyInput: " << deepCopyState << endl;
}

//------------------------------------------------------------------------------
void vtkContinuousPointRegistration::SetLastTarget(vtkDataSet* lastTarget)
{
  // TODO: Try using deep copy
  //       Next step is to fetch this from upstream time information

  // Same object as before: nothing to do.
  if (lastTarget == this->LastTarget)
  {
    return;
  }

  // Give up the reference to the previous target, if there is one.
  if (this->LastTarget != nullptr)
  {
    this->ReleaseLastTarget();
  }

  // Clearing the target also resets the accumulated transform.
  if (lastTarget == nullptr)
  {
    this->UserTransform->Identity();
    return;
  }

  // Take a reference to the new target and remember it.
  lastTarget->Register(this);
  this->LastTarget = lastTarget;
}

//------------------------------------------------------------------------------
void vtkContinuousPointRegistration::ReleaseLastTarget()
{
  vtkDebugMacro(<< __PRETTY_FUNCTION__);

  // Nothing held, nothing to release.
  if (this->LastTarget == nullptr)
  {
    return;
  }

  // Drop the reference taken in SetLastTarget() and forget the pointer.
  this->LastTarget->UnRegister(this);
  this->LastTarget = nullptr;
}
