#include "vtkCellTreeLocator.h"
#include "vtkDataArray.h"
#include "vtkGenericCell.h"
#include "vtkNew.h"
#include "vtkPointData.h"
#include "vtkSphereSource.h"

#include <array>
#include <cstdlib>
#include <iostream>

// Note: vtkCellTreeLocator moved from the vtkFiltersGeneral module to the
// vtkCommonDataModel module in VTK commit
// 4a29e6f7dd9acb460644fe487d2e80aac65f7be9.

int main(int, char*[])
{
  vtkNew<vtkSphereSource> sphere0;
  sphere0->SetCenter(0.0, 0.0, 0.0);
  sphere0->SetRadius(1.0);
  sphere0->Update();

  // Create the cell locator tree.
  vtkNew<vtkCellTreeLocator> cellTree;
  cellTree->SetDataSet(sphere0->GetOutput());
  cellTree->BuildLocator();

  //  These two points should not be on the sphere.
  double testInside[3] = {0.5, 0.0, 0.0};
  double testOutside[3] = {10.0, 0.0, 0.0};
  double tol = 0.0;
  double pcoords[3];
  double weights[3];

  vtkIdType cellId;

  vtkNew<vtkGenericCell> cell;

  int returnValue = EXIT_SUCCESS;

  //  A point on the sphere.
  std::array<double, 3> sourcePt{0.0, 0.0, 0.0};
  sphere0->GetOutput()->GetPoint(0, sourcePt.data());
  cellId = cellTree->FindCell(sourcePt.data(), tol, cell, pcoords, weights);
  if (cellId >= 0)
  {
    std::cout << "Point 0 on the sphere is in cell " << cellId << "."
              << std::endl;
    // Find the midpoint in the cell and check if it is in the same cell.
    if (cellId >= 0)
    {
      auto bounds = cell->GetBounds();
      std::array<double, 3> midPt{0.0, 0.0, 0.0};
      auto i = 0;
      for (auto j = 0, inc = 2; j < 6; j += inc)
      {
        midPt[i] = bounds[j] + (bounds[j + 1] - bounds[j]) / 2.0;
        ++i;
      }
      auto cellIdMidPt =
          cellTree->FindCell(midPt.data(), tol, cell, pcoords, weights);
      if (cellIdMidPt != cellId)
      {
        std::cout << "ERROR: The cell midpoint should be in the same cell."
                  << std::endl;
        returnValue = EXIT_FAILURE;
      }
    }
  }
  else
  {
    std::cout << "ERROR: The cell corresponding to point 0 on the sphere"
              << " was not found but should have been." << std::endl;
    returnValue = EXIT_FAILURE;
  }

  // Should be inside the sphere.
  cellId = cellTree->FindCell(testInside, tol, cell, pcoords, weights);
  if (cellId >= 0)
  {
    std::cout << "testInside point is in cell " << cellId
              << " of the sphere but it should not be in the cell."
              << std::endl;
    returnValue = EXIT_FAILURE;
  }
  else
  {
    std::cout << "testInside point is inside the sphere." << std::endl;
  }

  // Should be outside.
  cellId = cellTree->FindCell(testOutside, 0, cell, pcoords, weights);
  if (cellId >= 0)
  {
    std::cout << "testOutside point is in cell " << cellId
              << " of the sphere but it should not be in the cell."
              << std::endl;
    returnValue = EXIT_FAILURE;
  }
  else
  {
    std::cout << "testOutside point is outside the sphere." << std::endl;
  }

  auto numberOfPoints = sphere0->GetOutput()->GetNumberOfPoints();
  auto countOfPoints = 0;
  for (auto i = 0; i < numberOfPoints; ++i)
  {
    sphere0->GetOutput()->GetPoint(0, sourcePt.data());
    cellId = cellTree->FindCell(sourcePt.data(), tol, cell, pcoords, weights);
    if (cellId >= 0)
    {
      ++countOfPoints;
    }
  }

  if (countOfPoints != numberOfPoints)
  {
    auto numMissed = numberOfPoints - countOfPoints;
    std::cout << "ERROR: " << numMissed
              << " points should have been on the sphere!" << std::endl;
    returnValue = EXIT_FAILURE;
  }
  else
  {
    std::cout << "Passed: A total of " << countOfPoints
              << " points on the sphere were detected." << std::endl;
  }

  // This is based on
  // [CellTreeLocator](https://gitlab.kitware.com/vtk/vtk/-/blob/master/Common/DataModel/Testing/Cxx/CellTreeLocator.cxx)
  // Kuhnan's sample code is used to test
  // vtkCellLocator::IntersectWithLine(...9 params...)

  // sphere1: the outer sphere
  vtkNew<vtkSphereSource> sphere1;
  sphere1->SetThetaResolution(100);
  sphere1->SetPhiResolution(100);
  sphere1->SetRadius(1);
  sphere1->Update();

  // sphere2: the inner sphere
  vtkNew<vtkSphereSource> sphere2;
  sphere2->SetThetaResolution(100);
  sphere2->SetPhiResolution(100);
  sphere2->SetRadius(0.8);
  sphere2->Update();

  // The normals obtained from the outer sphere.
  vtkDataArray* sphereNormals =
      sphere1->GetOutput()->GetPointData()->GetNormals();

  // Create the  cell locator.
  vtkNew<vtkCellTreeLocator> locator;
  locator->SetDataSet(sphere2->GetOutput());
  locator->AutomaticOn();
  locator->BuildLocator();

  // Initialise the counter and ray length.
  int numIntersected = 0;
  tol = 0.0000001;
  double rayLen = 1.0 - 0.8 + tol; // = 1 - 0.8 + error tolerance
  int sub_id;
  vtkIdType cell_id;
  double param_t, intersect[3], paraCoord[3];
  double sourcePnt[3], destinPnt[3], normalVec[3];

  // This loop traverses each point on the outer sphere (sphere1)
  //  and looks for an intersection on the inner sphere (sphere2).
  numberOfPoints = sphere1->GetOutput()->GetNumberOfPoints();
  for (int i = 0; i < numberOfPoints; i++)
  {
    sphere1->GetOutput()->GetPoint(i, sourcePnt);
    sphereNormals->GetTuple(i, normalVec);

    // Cast a ray in the negative direction toward sphere1.
    destinPnt[0] = sourcePnt[0] - rayLen * normalVec[0];
    destinPnt[1] = sourcePnt[1] - rayLen * normalVec[1];
    destinPnt[2] = sourcePnt[2] - rayLen * normalVec[2];

    if (locator->IntersectWithLine(sourcePnt, destinPnt, 0.0010, param_t,
                                   intersect, paraCoord, sub_id, cell_id, cell))
    {
      numIntersected++;
    }
  }

  if (numIntersected != numberOfPoints)
  {
    int numMissed = numberOfPoints - numIntersected;
    std::cout << "ERROR: " << numMissed << " ray-sphere intersections missed!"
              << std::endl;
    returnValue = EXIT_FAILURE;
  }
  else
  {
    std::cout << "Passed: A total of " << numberOfPoints
              << " ray-sphere intersections detected." << std::endl;
  }

  if (returnValue != EXIT_FAILURE)
  {
    std::cout << "All checks passed." << std::endl;
  }
  else
  {
    std::cout << "Some checks failed." << std::endl;
  }

  return returnValue;
}
