//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================

#include "IcoSphere.h"

#include <vtkm/Transform3D.h>
#include <vtkm/VectorAnalysis.h>
#include <vtkm/cont/DataSetBuilderExplicit.h>
#include <vtkm/filter/clean_grid/CleanGrid.h>

namespace beams
{
namespace source
{
namespace detail
{
class IcoSphereCreater
{
public:
  IcoSphereCreater(vtkm::IdComponent refinementLevel)
    : RefinementLevel(refinementLevel)
  {
  }

  void SetData(const std::vector<vtkm::Vec3f>& centers,
               const std::vector<vtkm::FloatDefault>& radii,
               const std::vector<vtkm::FloatDefault>& fieldValues,
               const std::string& fieldName)
  {
    this->Centers = centers;
    this->Radii = radii;
    this->FieldValues = fieldValues;
    this->FieldName = fieldName;
  }

  vtkm::cont::DataSet Create()
  {
    std::vector<vtkm::Id3> finalIndices;
    for (std::size_t sphereIdx = 0; sphereIdx < this->Centers.size(); ++sphereIdx)
    {
      const vtkm::Float32 fieldValue = this->FieldValues[sphereIdx];

      std::vector<vtkm::Id3> triangleIndices;
      const vtkm::FloatDefault phi = (1.0f + vtkm::Sqrt(5.0f)) * 0.5f;

      // Create 12 vertices of the icosahedron
      vtkm::Id startIndex = this->AddVertex(vtkm::Vec3f(-1.0f, phi, 0), fieldValue);
      this->AddVertex(vtkm::Vec3f(1.0f, phi, 0), fieldValue);
      this->AddVertex(vtkm::Vec3f(-1.0f, -phi, 0), fieldValue);
      this->AddVertex(vtkm::Vec3f(1.0f, -phi, 0), fieldValue);
      this->AddVertex(vtkm::Vec3f(0.0f, -1.0f, phi), fieldValue);
      this->AddVertex(vtkm::Vec3f(0.0f, 1.0f, phi), fieldValue);
      this->AddVertex(vtkm::Vec3f(0.0f, -1.0f, -phi), fieldValue);
      this->AddVertex(vtkm::Vec3f(0.0f, 1.0f, -phi), fieldValue);
      this->AddVertex(vtkm::Vec3f(phi, 0.0f, -1.0f), fieldValue);
      this->AddVertex(vtkm::Vec3f(phi, 0.0f, 1.0f), fieldValue);
      this->AddVertex(vtkm::Vec3f(-phi, 0.0f, -1.0f), fieldValue);
      this->AddVertex(vtkm::Vec3f(-phi, 0.0f, 1.0f), fieldValue);

      // Create 20 faces of the icosahedron
      vtkm::Id3 indexOffset{ startIndex };
      // 5 faces around point 0
      triangleIndices.push_back(indexOffset + vtkm::Id3(0, 11, 5));
      triangleIndices.push_back(indexOffset + vtkm::Id3(0, 5, 1));
      triangleIndices.push_back(indexOffset + vtkm::Id3(0, 1, 7));
      triangleIndices.push_back(indexOffset + vtkm::Id3(0, 7, 10));
      triangleIndices.push_back(indexOffset + vtkm::Id3(0, 10, 11));

      // 5 adjacent faces
      triangleIndices.push_back(indexOffset + vtkm::Id3(1, 5, 9));
      triangleIndices.push_back(indexOffset + vtkm::Id3(5, 11, 4));
      triangleIndices.push_back(indexOffset + vtkm::Id3(11, 10, 2));
      triangleIndices.push_back(indexOffset + vtkm::Id3(10, 7, 6));
      triangleIndices.push_back(indexOffset + vtkm::Id3(7, 1, 8));

      // 5 faces around point 3
      triangleIndices.push_back(indexOffset + vtkm::Id3(3, 9, 4));
      triangleIndices.push_back(indexOffset + vtkm::Id3(3, 4, 2));
      triangleIndices.push_back(indexOffset + vtkm::Id3(3, 2, 6));
      triangleIndices.push_back(indexOffset + vtkm::Id3(3, 6, 8));
      triangleIndices.push_back(indexOffset + vtkm::Id3(3, 8, 9));

      // 5 adjacent faces
      triangleIndices.push_back(indexOffset + vtkm::Id3(4, 9, 5));
      triangleIndices.push_back(indexOffset + vtkm::Id3(2, 4, 11));
      triangleIndices.push_back(indexOffset + vtkm::Id3(6, 2, 10));
      triangleIndices.push_back(indexOffset + vtkm::Id3(8, 6, 7));
      triangleIndices.push_back(indexOffset + vtkm::Id3(9, 8, 1));

      // refine triangles
      for (vtkm::IdComponent i = 0; i < this->RefinementLevel; ++i)
      {
        std::vector<vtkm::Id3> refinedTriangleIndices;
        for (const vtkm::Id3& triIndices : triangleIndices)
        {
          vtkm::Id v1 = this->GetMidPointId(triIndices[0], triIndices[1], fieldValue);
          vtkm::Id v2 = this->GetMidPointId(triIndices[1], triIndices[2], fieldValue);
          vtkm::Id v3 = this->GetMidPointId(triIndices[2], triIndices[0], fieldValue);

          refinedTriangleIndices.push_back(vtkm::Id3(triIndices[0], v1, v3));
          refinedTriangleIndices.push_back(vtkm::Id3(triIndices[1], v2, v1));
          refinedTriangleIndices.push_back(vtkm::Id3(triIndices[2], v3, v2));
          refinedTriangleIndices.push_back(vtkm::Id3(v1, v2, v3));
        }
        triangleIndices = refinedTriangleIndices;
      }
      vtkm::Id endIndex = this->Vertices.size();

      vtkm::Vec3f center = this->Centers[sphereIdx];
      vtkm::Float32 radius = this->Radii[sphereIdx];
      vtkm::Matrix<vtkm::FloatDefault, 4, 4> transform;
      vtkm::MatrixIdentity(transform);
      transform = vtkm::MatrixMultiply(transform, vtkm::Transform3DTranslate(center));
      transform = vtkm::MatrixMultiply(transform, vtkm::Transform3DScale(radius, radius, radius));
      for (vtkm::Id ti = startIndex; ti < endIndex; ++ti)
      {
        this->Vertices[ti] = vtkm::Transform3DPoint(transform, this->Vertices[ti]);
      }

      finalIndices.insert(finalIndices.end(), triangleIndices.begin(), triangleIndices.end());
    }

    vtkm::cont::DataSetBuilderExplicitIterative builder;
    for (const vtkm::Vec3f& vertex : this->Vertices)
    {
      builder.AddPoint(vertex);
    }
    for (const vtkm::Id3& triIndices : finalIndices)
    {
      builder.AddCell(vtkm::CELL_SHAPE_TRIANGLE);
      builder.AddCellPoint(triIndices[0]);
      builder.AddCellPoint(triIndices[1]);
      builder.AddCellPoint(triIndices[2]);
    }

    vtkm::cont::DataSet result = builder.Create();
    result.AddPointField(this->FieldName,
                         vtkm::cont::make_ArrayHandle(this->VertexValues, vtkm::CopyFlag::On));
    return result;
  }

private:
  vtkm::Id GetMidPointId(vtkm::Id id1, vtkm::Id id2, vtkm::Float32 fieldValue)
  {
    vtkm::Int64 key = this->CreateCacheKey(id1, id2);
    if (this->MidPointIndexCache.find(key) != this->MidPointIndexCache.end())
    {
      return this->MidPointIndexCache[key];
    }

    vtkm::Vec3f& p1 = this->Vertices[id1];
    vtkm::Vec3f& p2 = this->Vertices[id2];
    vtkm::Vec3f midPoint = (p1 + p2) * 0.5f;
    vtkm::Id midPointId = this->AddVertex(midPoint, fieldValue);
    this->MidPointIndexCache.insert({ key, midPointId });
    return midPointId;
  }

  vtkm::Int64 CreateCacheKey(vtkm::Id id1, vtkm::Id id2)
  {
    vtkm::Id smallerId = vtkm::Min(id1, id2);
    vtkm::Id largerId = vtkm::Max(id1, id2);
    vtkm::Int64 key = (smallerId << 32) + largerId;
    return key;
  }

  vtkm::Id AddVertex(const vtkm::Vec3f& vertex, const vtkm::Float32& fieldValue)
  {
    vtkm::Vec3f normalizedVertex = vtkm::Normal(vertex);
    this->Vertices.push_back(normalizedVertex);
    this->VertexValues.push_back(fieldValue);
    return (this->Vertices.size() - 1);
  }

  std::vector<vtkm::Vec3f> Centers;
  std::vector<vtkm::FloatDefault> Radii;
  std::vector<vtkm::FloatDefault> FieldValues;
  std::string FieldName;
  vtkm::IdComponent RefinementLevel;

  std::vector<vtkm::Vec3f> Vertices;
  std::vector<vtkm::Float32> VertexValues;
  std::unordered_map<vtkm::Int64, vtkm::Id> MidPointIndexCache;
};
} // namespace detail

IcoSphere::IcoSphere(vtkm::IdComponent refinementLevel)
  : RefinementLevel(refinementLevel)
  , FieldName("ico_spheres")
{
}

// Name of the point field written by DoExecute().
const std::string& IcoSphere::GetFieldName() const
{
  const std::string& name = this->FieldName;
  return name;
}

// Sets the name of the point field written by DoExecute().
void IcoSphere::SetFieldName(const std::string& fieldName)
{
  this->FieldName.assign(fieldName);
}

// Number of subdivision passes applied to each icosahedron.
vtkm::IdComponent IcoSphere::GetRefinementLevel() const
{
  const vtkm::IdComponent level = this->RefinementLevel;
  return level;
}

// Sets the number of subdivision passes applied to each icosahedron.
void IcoSphere::SetRefinementLevel(vtkm::IdComponent refinementLevel)
{
  RefinementLevel = refinementLevel;
}

// Queues one sphere for generation. The three per-sphere vectors are always
// appended together so they stay the same length.
void IcoSphere::AddSphere(vtkm::Vec3f center,
                          vtkm::FloatDefault radius,
                          vtkm::FloatDefault fieldValue)
{
  Centers.emplace_back(center);
  Radii.emplace_back(radius);
  FieldValues.emplace_back(fieldValue);
}

// Builds the mesh for every queued sphere by delegating to IcoSphereCreater.
vtkm::cont::DataSet IcoSphere::DoExecute() const
{
  detail::IcoSphereCreater sphereBuilder(this->RefinementLevel);
  sphereBuilder.SetData(this->Centers, this->Radii, this->FieldValues, this->FieldName);
  vtkm::cont::DataSet output = sphereBuilder.Create();
  return output;
}
}
}