/*=========================================================================

   Library: iMSTK

   Copyright (c) Kitware, Inc. & Center for Modeling, Simulation,
   & Imaging in Medicine, Rensselaer Polytechnic Institute.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0.txt

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

=========================================================================*/

#include "imstkLevelSetCH.h"
#include "imstkCollisionData.h"
#include "imstkImageData.h"
#include "imstkLevelSetDeformableObject.h"
#include "imstkLevelSetModel.h"
#include "imstkRbdConstraint.h"
#include "imstkRigidObject2.h"

namespace imstk
{
namespace expiremental
{
LevelSetCH::LevelSetCH(const Side&                               side,
                       const std::shared_ptr<CollisionData>      colData,
                       std::shared_ptr<LevelSetDeformableObject> lvlSetObj,
                       std::shared_ptr<RigidObject2>             rigidObj) :
    CollisionHandling(Type::LevelSet, side, colData),
    m_lvlSetObj(lvlSetObj),
    m_rigidObj(rigidObj)
{
    // Build the impulse kernel from the current size/sigma members and enable
    // every vertex of the rigid object's colliding geometry
    setKernel(m_kernelSize, m_kernelSigma);
    maskAllPoints();

    // Validate each cast before dereferencing it so a geometry/scalar type
    // mismatch is reported instead of crashing on a null pointer
    std::shared_ptr<LevelSetModel>       lvlSetModel = m_lvlSetObj->getLevelSetModel();
    std::shared_ptr<SignedDistanceField> sdf = std::dynamic_pointer_cast<SignedDistanceField>(lvlSetModel->getModelGeometry());
    std::shared_ptr<ImageData>           grid;
    if (sdf != nullptr)
    {
        grid = sdf->getImage();
    }
    if (grid == nullptr)
    {
        LOG(FATAL) << "Error: level set model geometry is not ImageData";
        return;
    }

    auto imgScalarsPtr = std::dynamic_pointer_cast<DataArray<double>>(grid->getScalars());
    if (imgScalarsPtr == nullptr)
    {
        LOG(FATAL) << "Error: level set image scalars are not of type double";
        return;
    }
    const DataArray<double>& imgScalars = *imgScalarsPtr;

    // Widen [m_scalarMin, m_scalarMax] to cover the initial level set values;
    // processCollisionData uses this range to normalize sampled values when
    // m_useProportionalForce is enabled
    const int numScalars = static_cast<int>(imgScalarsPtr->size());
    for (int i = 0; i < numScalars; ++i)
    {
        if (imgScalars[i] > m_scalarMax) { m_scalarMax = imgScalars[i]; }
        if (imgScalars[i] < m_scalarMin) { m_scalarMin = imgScalars[i]; }
    }
}

LevelSetCH::~LevelSetCH()
{
    // Free the kernel weight table that setKernel allocates with new[].
    // delete[] on a null pointer is defined to do nothing, so no guard is needed.
    // NOTE(review): the class owns this raw buffer; if LevelSetCH copy operations
    // are not deleted in the header (not visible here), copies would double-free it.
    delete[] m_kernelWeights;
}

///
/// \brief Rebuilds the cubic impulse kernel used by processCollisionData.
/// \param size  edge length of the kernel in voxels; an even size is bumped to the
///              next odd value and a non-positive size is replaced by 1
/// \param sigma falloff parameter of the weights
///
/// The weight table holds m_kernelSize^3 doubles, laid out z-major, then y, then x.
/// Element (x, y, z) is exp(-|(x,y,z)| / (2 sigma^2)) for x, y, z in [-half, half].
///
void
LevelSetCH::setKernel(const int size, const double sigma)
{
    m_kernelSize  = size;
    m_kernelSigma = sigma;
    if (m_kernelSize < 1)
    {
        LOG(WARNING) << "LevelSetCH kernel size must be positive, using 1";
        m_kernelSize = 1;
    }
    if (m_kernelSize % 2 == 0)
    {
        LOG(WARNING) << "LevelSetCH kernel size must be odd, increasing by 1";
        m_kernelSize++;
    }

    // Size the table and loops from the *adjusted* m_kernelSize. The raw `size`
    // is wrong here: with an even size the loops below would write
    // (size + 1)^3 weights into a size^3 buffer.
    const int kernelSize = m_kernelSize;
    const int halfSize   = kernelSize / 2;

    // Allocate before releasing the old table so a throwing new[] leaves the
    // object with a valid (old) kernel rather than a dangling pointer
    double* weights = new double[kernelSize * kernelSize * kernelSize];

    const double invDiv = 1.0 / (2.0 * sigma * sigma);
    int          i      = 0;
    for (int z = -halfSize; z <= halfSize; z++)
    {
        for (int y = -halfSize; y <= halfSize; y++)
        {
            for (int x = -halfSize; x <= halfSize; x++)
            {
                const double dist = Vec3i(x, y, z).cast<double>().norm();
                // NOTE(review): a true Gaussian uses dist^2; the linear distance is
                // kept here to preserve the existing (presumably tuned) falloff
                weights[i++] = std::exp(-dist * invDiv);
            }
        }
    }

    delete[] m_kernelWeights;
    m_kernelWeights = weights;
}


///
/// \brief Trilinearly interpolates one component of a voxel image.
/// \param structuredPt continuous voxel coordinates relative to voxel centers,
///        ie: (worldPos - (imageOrigin + spacing/2)) * invSpacing, so integer values
///        land exactly on voxel centers
/// \param imgPtr   raw pointer to the interleaved scalar buffer
/// \param dim      image dimensions in voxels
/// \param numComps number of interleaved components per voxel
/// \param comp     component to sample
/// \return the interpolated value. Points outside the grid return the value of
///         the nearest boundary cell (no extrapolation).
///
template<typename T>
static T
trilinearSample(const Vec3d& structuredPt, T* imgPtr, const Vec3i& dim, const int numComps, const int comp)
{
    const Vec3i maxIdx = dim - Vec3i(1, 1, 1);
    const Vec3i base   = structuredPt.cast<int>();

    // minima of voxel, clamped to bounds
    const Vec3i s1 = base.cwiseMax(0).cwiseMin(maxIdx);

    // maxima of voxel, clamped to bounds
    const Vec3i s2 = (base + Vec3i(1, 1, 1)).cwiseMax(0).cwiseMin(maxIdx);

    const size_t index000 = ImageData::getScalarIndex(s1.x(), s1.y(), s1.z(), dim, numComps) + comp;
    const size_t index100 = ImageData::getScalarIndex(s2.x(), s1.y(), s1.z(), dim, numComps) + comp;
    const size_t index110 = ImageData::getScalarIndex(s2.x(), s2.y(), s1.z(), dim, numComps) + comp;
    const size_t index010 = ImageData::getScalarIndex(s1.x(), s2.y(), s1.z(), dim, numComps) + comp;
    const size_t index001 = ImageData::getScalarIndex(s1.x(), s1.y(), s2.z(), dim, numComps) + comp;
    const size_t index101 = ImageData::getScalarIndex(s2.x(), s1.y(), s2.z(), dim, numComps) + comp;
    const size_t index111 = ImageData::getScalarIndex(s2.x(), s2.y(), s2.z(), dim, numComps) + comp;
    const size_t index011 = ImageData::getScalarIndex(s1.x(), s2.y(), s2.z(), dim, numComps) + comp;

    const double val000 = static_cast<double>(imgPtr[index000]);
    const double val100 = static_cast<double>(imgPtr[index100]);
    const double val110 = static_cast<double>(imgPtr[index110]);
    const double val010 = static_cast<double>(imgPtr[index010]);

    const double val001 = static_cast<double>(imgPtr[index001]);
    const double val101 = static_cast<double>(imgPtr[index101]);
    const double val111 = static_cast<double>(imgPtr[index111]);
    const double val011 = static_cast<double>(imgPtr[index011]);

    // Interpolants: fractional position of the point inside the [s1, s2] cell.
    // It must be measured from s1, the weight-0 end of each lerp below. Measuring
    // from s2 yields values in [-1, 0) and extrapolates. Clamping to [0, 1] makes
    // out-of-grid points take the nearest boundary value.
    const Vec3d t = (structuredPt - s1.cast<double>()).cwiseMax(0.0).cwiseMin(1.0);

    // Interpolate along x
    const double ax = val000 + (val100 - val000) * t[0];
    const double bx = val010 + (val110 - val010) * t[0];

    const double dx = val001 + (val101 - val001) * t[0];
    const double ex = val011 + (val111 - val011) * t[0];

    // Interpolate along y
    const double cy = ax + (bx - ax) * t[1];
    const double fy = dx + (ex - dx) * t[1];

    // Interpolate along z
    const double gz = cy + (fy - cy) * t[2];

    return static_cast<T>(gz);
}

void
LevelSetCH::processCollisionData()
{
    std::shared_ptr<LevelSetModel> lvlSetModel = m_lvlSetObj->getLevelSetModel();
    std::shared_ptr<ImageData>     grid = std::dynamic_pointer_cast<SignedDistanceField>(lvlSetModel->getModelGeometry())->getImage();

    if (grid == nullptr)
    {
        LOG(FATAL) << "Error: level set model geometry is not ImageData";
        return;
    }

    //const Vec3i& dim = grid->getDimensions();
    const Vec3d& invSpacing = grid->getInvSpacing();
    const Vec3d& origin     = grid->getOrigin();

    if (m_useProportionalForce)
    {
        // Apply impulses at points of contacts
        PositionDirectionCollisionData& pdColData = m_colData->PDColData;
        for (size_t i = 0; i < pdColData.getSize(); i++)
        {
            // If the point is in the mask, let it apply impulses
            if (m_ptIdMask.count(pdColData[i].nodeIdx) != 0)
            {
                const Vec3d& pos    = pdColData[i].posB;
                const Vec3d& normal = pdColData[i].dirAtoB;
                const Vec3i  coord  = (pos - origin).cwiseProduct(invSpacing).cast<int>();

                //interpolate                                
                auto imgScalarsPtr = std::dynamic_pointer_cast<DataArray<double>>(grid->getScalars());                
                double interpScalar = trilinearSample<double>(pos, imgScalarsPtr->getPointer(), grid->getDimensions(), 1, 0);
                interpScalar = 0.0 + (interpScalar - m_scalarMin) / ((m_scalarMax - m_scalarMin)) * (m_velocityScaling - 0.0);

                // Scale the applied impulse by the normal force
                const double fN = normal.normalized().dot(m_rigidObj->getRigidBody()->getForce()) / m_rigidObj->getRigidBody()->getForce().norm();
                const double S  = std::max(fN, 0.0) * interpScalar * m_toolVelScaling;

                const int halfSize = static_cast<int>(m_kernelSize * 0.5);
                int       j = 0;
                for (int z = -halfSize; z < halfSize + 1; z++)
                {
                    for (int y = -halfSize; y < halfSize + 1; y++)
                    {
                        for (int x = -halfSize; x < halfSize + 1; x++)
                        {
                            const Vec3i fCoord = coord + Vec3i(x, y, z);
                            lvlSetModel->addImpulse(fCoord, S * m_kernelWeights[j++]);
                        }
                    }
                }
            }
        }
    }
    else
    {
        // Apply impulses at points of contacts
        PositionDirectionCollisionData& pdColData = m_colData->PDColData;
        for (size_t i = 0; i < pdColData.getSize(); i++)
        {
            // If the point is in the mask, let it apply impulses
            if (m_ptIdMask.count(pdColData[i].nodeIdx) != 0)
            {
                const Vec3d& pos = pdColData[i].posB;
                //const Vec3d& normal = pdColData[i].dirAtoB;
                const Vec3i  coord = (pos - origin).cwiseProduct(invSpacing).cast<int>();
                const double S     = m_velocityScaling;

                const int halfSize = static_cast<int>(m_kernelSize * 0.5);
                int       j = 0;
                for (int z = -halfSize; z < halfSize + 1; z++)
                {
                    for (int y = -halfSize; y < halfSize + 1; y++)
                    {
                        for (int x = -halfSize; x < halfSize + 1; x++)
                        {
                            const Vec3i fCoord = coord + Vec3i(x, y, z);
                            lvlSetModel->addImpulse(fCoord, S * m_kernelWeights[j++]);
                        }
                    }
                }
            }
        }
    }
}

void
LevelSetCH::maskAllPoints()
{
    std::shared_ptr<PointSet> pointSet = std::dynamic_pointer_cast<PointSet>(m_rigidObj->getCollidingGeometry());
    for (int i = 0; i < static_cast<int>(pointSet->getNumVertices()); i++)
    {
        m_ptIdMask.insert(i);
    }
}
}
}