/*=========================================================================

   Library: iMSTK

   Copyright (c) Kitware, Inc. & Center for Modeling, Simulation,
   & Imaging in Medicine, Rensselaer Polytechnic Institute.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0.txt

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

=========================================================================*/

#include "imstkSPHModel.h"
#include "imstkParallelUtils.h"
#include <g3log/g3log.hpp>

#include <set>

namespace imstk
{
// Validate the user-supplied fluid parameters and derive the dependent quantities:
// squared particle radius, particle mass, squared/inverse rest density, and the kernel radius.
// Invalid or missing parameters are reported through LOG(FATAL).
void
SPHModelConfig::initialize()
{
    LOG_IF(FATAL, (std::abs(m_particleRadius) < Real(1e-8))) << "Particle radius must be set by user";
    // The check above uses the absolute value, so a negative radius would otherwise pass and
    // produce a negative particle mass and kernel radius below.
    LOG_IF(FATAL, (m_particleRadius < Real(0))) << "Invalid particle radius";
    LOG_IF(FATAL, (m_restDensity <= 0)) << "Invalid fluid rest density";
    LOG_IF(FATAL, (m_kernelOverParticleRadiusRatio <= 1.0)) << "Invalid kernel radius/particle radius ratio";

    m_particleRadiusSqr = m_particleRadius * m_particleRadius;
    // Mass of a particle occupying a cube of side 2r at rest density, times the user scale factor
    m_particleMass      = Real(std::pow(Real(2.0) * m_particleRadius, 3)) * m_restDensity * m_particleMassScale;
    m_restDensitySqr    = m_restDensity * m_restDensity;
    m_restDensityInv    = Real(1) / m_restDensity;

    // Kernel support radius is a fixed multiple (> 1) of the particle radius
    m_kernelRadius    = m_particleRadius * m_kernelOverParticleRadiusRatio;
    m_kernelRadiusSqr = m_kernelRadius * m_kernelRadius;
}

bool
SPHModel::initialize()
{
    LOG_IF(FATAL, (!m_geometry)) << "Model geometry has not been set! Cannot initialize without model geometry.";

    // Initialize model group
    SPHModel::initializeModelGroup(shared_from_this());

    // Initialize  positions and velocity of the particles
    this->m_initialState = std::make_shared<SPHKinematicState>();
    this->m_currentState = std::make_shared<SPHKinematicState>();

    // Set particle positions and zero default velocities
    /// \todo set particle data with given (non-zero) velocities
    this->m_initialState->setParticleData(m_geometry->getVertexPositions());
    this->m_currentState->setState(this->m_initialState);

    // Attach current state to simulation state
    m_simulationState.setKinematicState(this->m_currentState);

    // Initialize (allocate memory for) simulation data such as density, acceleration etc.
    m_simulationState.initializeData(SPHModel::getModelGroup(this).size());

    // Initialize simulation dependent parameters
    m_modelParameters->initialize();

    // Initialize SPH kernels
    m_kernels.initialize(m_modelParameters->m_kernelRadius);

    // Initialize neighbor searcher
    m_neighborSearcher = std::make_shared<NeighborSearch>(m_modelParameters->m_neighborSearchMethod,
                                                          m_modelParameters->m_kernelRadius);
    return true;
}

// Advance this SPH model by one time step.
// The call order is significant: neighbor data is refreshed before densities are computed,
// densities (optionally normalized) are cached into the neighbor info by collectNeighborDensity()
// before the pressure/surface-tension terms read them, and the step size is chosen right before
// velocities and positions are integrated.
void
SPHModel::advanceTimeStep()
{
    // Neighborhood: search data, neighbor lists, cached relative positions
    updateNeighborSearchData();
    findParticleNeighbors();
    computeNeighborRelativePositions();
    // Densities (normalization is a no-op unless m_normalizeDensity is set), then cache them per neighbor
    computeDensity();
    normalizeDensity();
    collectNeighborDensity();
    // Forces: pressure writes the acceleration buffer, surface tension adds to it
    computePressureAcceleration();
    computeSurfaceNormal();
    computeSurfaceTensionAcceleration();
    // Integration: pick dt, update velocities, diffuse them (viscosity), then move particles
    computeTimeStepSize();
    updateVelocity(getTimeStep());
    computeViscosity();
    moveParticles(getTimeStep());
}

void
SPHModel::computeTimeStepSize()
{
    m_dt = (this->m_timeStepSizeType == TimeSteppingType::fixed) ? m_defaultDt : computeCFLTimeStepSize();
}

// CFL-limited step size: dt = CFL * 2r / max_p ||v_p||.
// When the fastest particle moves slower than 1e-6 the maximum step is used instead.
// The result is then clamped, testing the upper bound m_maxTimestep first and the lower
// bound m_minTimestep second.
Real
SPHModel::computeCFLTimeStepSize()
{
    const auto& params   = m_modelParameters;
    const auto  maxSpeed = ParallelUtils::findMaxL2Norm(getState().getVelocities());

    Real dt = params->m_maxTimestep;
    if (maxSpeed > Real(1e-6))
    {
        dt = params->m_CFLFactor * (Real(2.0) * params->m_particleRadius / maxSpeed);
    }

    if (dt > params->m_maxTimestep)
    {
        return params->m_maxTimestep;
    }
    if (dt < params->m_minTimestep)
    {
        return params->m_minTimestep;
    }
    return dt;
}

// Hand this model's current particle positions to its neighbor searcher.
// Every model in the group later queries this searcher from findParticleNeighbors()
// (via getCollectedNeighbors), so this must run before that step.
void
SPHModel::updateNeighborSearchData()
{
    // Presumably builds the searcher's spatial lookup over these positions — see NeighborSearch
    m_neighborSearcher->collectNeighborSearchData(getState().getPositions());
}

// Fill this model's neighbor lists. For every model in the group (this one included), that
// model's neighbor searcher is queried for the neighbors of this model's particles among that
// model's particles; results are stored under the model's index in the group. When boundary
// particles contribute to density, boundary neighbor lists are filled the same way.
// NOTE(review): the boundary query uses the same searcher that updateNeighborSearchData() fed with
// fluid positions only; confirm getCollectedNeighbors() is valid for a different target set.
void
SPHModel::findParticleNeighbors()
{
    const auto& group        = SPHModel::getModelGroup(this);
    const auto  withBoundary = m_modelParameters->m_densityWithBoundary;

    for (size_t idx = 0, groupSize = group.size(); idx < groupSize; ++idx)
    {
        const auto& other    = group[idx];
        auto&       searcher = *other->m_neighborSearcher;

        // Neighbors among the other model's fluid particles
        searcher.getCollectedNeighbors(getState().getFluidNeighborLists(idx),
                                       getState().getPositions(),
                                       other->getState().getPositions());

        if (!withBoundary)
        {
            continue;
        }

        // Neighbors among the other model's boundary particles
        searcher.getCollectedNeighbors(getState().getBoundaryNeighborLists(idx),
                                       getState().getPositions(),
                                       other->getState().getBoundaryParticlePositions());
    }
}

// Cache, for every particle p and every model in the group, the relative position x_p - x_q of
// each neighbor q together with a placeholder density (this model's rest density). Fluid
// neighbors are stored first, followed by boundary neighbors when m_densityWithBoundary is set;
// later passes rely on that ordering. Fluid placeholders are overwritten by collectNeighborDensity().
// NOTE(review): fluid entries later receive kernel-sum densities without a mass factor (see
// computeDensity) while boundary entries keep m_restDensity; confirm both are meant to be in the
// same units.
void
SPHModel::computeNeighborRelativePositions()
{
    // Append one NeighborInfo per neighbor index, in list order
    const auto appendNeighbors = [&](const Vec3r& ppos, const std::vector<size_t>& neighbors,
                                     const StdVectorOfVec3r& positions,
                                     std::vector<NeighborInfo>& out) {
        for (size_t k = 0; k < neighbors.size(); ++k)
        {
            const Vec3r xpq = ppos - positions[neighbors[k]];
            out.push_back({ xpq, m_modelParameters->m_restDensity });
        }
    };

    const auto& group = SPHModel::getModelGroup(this);
    for (size_t idx = 0; idx < group.size(); ++idx)
    {
        const auto& other = group[idx];

        ParallelUtils::parallelFor(getState().getNumParticles(),
            [&](const size_t p) {
                const auto& ppos = getState().getPositions()[p];
                auto&       info = getState().getNeighborInfo(idx)[p];
                info.clear();

                appendNeighbors(ppos, getState().getFluidNeighborLists(idx)[p],
                                other->getState().getPositions(), info);

                if (m_modelParameters->m_densityWithBoundary)
                {
                    appendNeighbors(ppos, getState().getBoundaryNeighborLists(idx)[p],
                                    other->getState().getBoundaryParticlePositions(), info);
                }
            });
    }
}

// Copy each fluid neighbor's density into the matching NeighborInfo entry, right next to the
// cached relative position. Later passes read relative positions and densities together, and
// keeping them side by side was measured to be significantly faster. Boundary entries (stored
// after the fluid ones) keep their placeholder density.
void
SPHModel::collectNeighborDensity()
{
    const auto& group = SPHModel::getModelGroup(this);
    for (size_t idx = 0; idx < group.size(); ++idx)
    {
        const auto& densities = group[idx]->getState().getDensities();

        ParallelUtils::parallelFor(getState().getNumParticles(),
            [&](const size_t p) {
                const auto& fluidNeighbors = getState().getFluidNeighborLists(idx)[p];
                auto&       info           = getState().getNeighborInfo(idx)[p];

                size_t slot = 0;
                for (const auto q : fluidNeighbors)
                {
                    auto& entry = info[slot++];
                    entry.density = densities[q];
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
                    LOG_IF(FATAL, (entry.density < 1e-5)) << "Invalid density";
#endif
                }
            });
    }
}

void
SPHModel::computeDensity()
{
    const auto& modelGroup = SPHModel::getModelGroup(this);
    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p) {
            Real pdensity = 0;

            // Must loop for all modelIdx here
            // Each particle p loop for all neighbors in all models
            for (size_t modelIdx = 0; modelIdx < modelGroup.size(); ++modelIdx)
            {
                const auto& neighborInfo = getState().getNeighborInfo(modelIdx)[p];
                for (const auto& qInfo : neighborInfo)
                {
                    pdensity += m_kernels.W(qInfo.xpq);
                }
            }
            getState().getDensities()[p] = pdensity;
        });
}

void
SPHModel::normalizeDensity()
{
    if (!m_modelParameters->m_normalizeDensity)
    {
        return;
    }

    getState().getNormalizedDensities().resize(getState().getNumParticles());
    const auto& modelGroup = SPHModel::getModelGroup(this);

    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p) {
            Real tmp = 0;
            for (size_t modelIdx = 0; modelIdx < modelGroup.size(); ++modelIdx)
            {
                auto& neighborInfo = getState().getNeighborInfo(modelIdx)[p];
                const auto& fluidNeighborList = getState().getFluidNeighborLists(modelIdx)[p];
                const auto& model = modelGroup[modelIdx];

                for (size_t i = 0; i < fluidNeighborList.size(); ++i)
                {
                    const auto& qInfo = neighborInfo[i];

                    // because we're not done with density computation, qInfo does not contain desity of particle q yet
                    const auto q        = fluidNeighborList[i];
                    const auto qdensity = model->getState().getDensities()[q];
                    tmp += m_kernels.W(qInfo.xpq) / qdensity;
                }

                if (m_modelParameters->m_densityWithBoundary)
                {
                    const auto& BDNeighborList = getState().getBoundaryNeighborLists(modelIdx)[p];
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
                    LOG_IF(FATAL, (fluidNeighborList.size() + BDNeighborList.size() != neighborInfo.size()))
                    << "Invalid neighborInfo computation";
#endif
                    for (size_t i = fluidNeighborList.size(); i < neighborInfo.size(); ++i)
                    {
                        const auto& qInfo = neighborInfo[i];
                        tmp += m_kernels.W(qInfo.xpq);  // density of boundary particle is set to rest density
                    }
                }
            }
            if (tmp > 0)
            {
                getState().getNormalizedDensities()[p] = getState().getDensities()[p] / (tmp /* * m_modelParameters->m_particleMass*/);
            }
        });

    // put normalized densities to densities
    std::swap(getState().getDensities(), getState().getNormalizedDensities());
}

// Pressure acceleration (overwrites the acceleration buffer):
//   a_p = -(k / m) * sum_q ( P_p / rho_p^2 + P_q / rho_q^2 ) gradW(x_pq)
// with k = m_pressureStiffnessConstant, m = this model's particle mass, and per-model pressure
//   P = max(0, (m_model * rho / rho0_model)^7 - 1).
// The sum runs over all cached neighbors (fluid and, if enabled, boundary) of every model in the group.
void
SPHModel::computePressureAcceleration()
{
    const auto pressureOf = [](const Real density, const Real mass, const Real restDensity) -> Real {
        const Real error = std::pow(mass * density / restDensity, 7) - Real(1);
        // Negative (attractive) pressure destabilizes the fluid, so it is clamped to zero
        return (error > Real(0)) ? error : Real(0);
    };

    const auto& group = SPHModel::getModelGroup(this);
    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p) {
            const auto rhoP  = getState().getDensities()[p];
            const auto pP    = pressureOf(rhoP, m_modelParameters->m_particleMass, m_modelParameters->m_restDensity);
            const auto termP = pP / (rhoP * rhoP);

            Vec3r accel(0, 0, 0);
            for (size_t idx = 0; idx < group.size(); ++idx)
            {
                const auto& params = group[idx]->getParameters();
                for (const auto& qInfo : getState().getNeighborInfo(idx)[p])
                {
                    const auto rhoQ = qInfo.density;
                    const auto pQ   = pressureOf(rhoQ, params->m_particleMass, params->m_restDensity);
                    accel -= (termP + pQ / (rhoQ * rhoQ)) * m_kernels.gradW(qInfo.xpq);
                }
            }
            accel *= m_modelParameters->m_pressureStiffnessConstant / m_modelParameters->m_particleMass;

            getState().getAccelerations()[p] = accel;
        });
}

// Compute surface tension using Akinci et at. 2013 model (Versatile Surface Tension and Adhesion for SPH Fluids)
void
SPHModel::computeSurfaceNormal()
{
    const auto& modelGroup = SPHModel::getModelGroup(this);

    // Firstly compute surface normal for all particles
    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p) {
            Vec3r n(0, 0, 0);
            for (size_t modelIdx = 0; modelIdx < modelGroup.size(); ++modelIdx)
            {
                const auto& neighborInfo = getState().getNeighborInfo(modelIdx)[p];
                for (size_t i = 0; i < neighborInfo.size(); ++i)
                {
                    const auto& qInfo   = neighborInfo[i];
                    const auto r        = qInfo.xpq;
                    const auto qdensity = qInfo.density;
                    n += (Real(1.0) / qdensity) * m_kernels.gradW(r);
                }
            }
            n *= m_modelParameters->m_kernelRadius;
            getState().getNormals()[p] = n;
        });
}

// Surface tension acceleration from the Akinci et al. 2013 model (Versatile Surface Tension and
// Adhesion for SPH Fluids), summed over fluid neighbors q of every model in the group:
//   a_p += -gamma * sum_q K_pq * ( m_q * C(x_pq) * x_pq / |x_pq|  +  (n_p - n_q) )
// with correction factor K_pq = (rho0_p + rho0_q) / (rho_p * m_p + rho_q * m_q) and
// gamma = m_surfaceTensionConstant. The cohesion term is skipped for coincident particles.
// Adds to (does not overwrite) the acceleration buffer.
void
SPHModel::computeSurfaceTensionAcceleration()
{
    const auto& group = SPHModel::getModelGroup(this);

    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p) {
            const auto nP   = getState().getNormals()[p];
            const auto rhoP = getState().getDensities()[p];
            Vec3r accel(0, 0, 0);

            for (size_t idx = 0; idx < group.size(); ++idx)
            {
                const auto& other     = group[idx];
                const auto& neighbors = getState().getFluidNeighborLists(idx)[p];
                const auto& info      = getState().getNeighborInfo(idx)[p];

                for (size_t k = 0; k < neighbors.size(); ++k)
                {
                    const auto& qInfo = info[k];
                    const auto  rhoQ  = qInfo.density;

                    const auto K_pq = (m_modelParameters->m_restDensity + other->getParameters()->m_restDensity) /
                                      (rhoP * m_modelParameters->m_particleMass +
                                       rhoQ * other->getParameters()->m_particleMass);

                    // Cohesion: direction is undefined for (nearly) coincident particles
                    const auto xpq     = qInfo.xpq;
                    const auto distSqr = xpq.squaredNorm();
                    if (distSqr > Real(1e-20))
                    {
                        accel -= K_pq * other->getParameters()->m_particleMass *
                                 (xpq / std::sqrt(distSqr)) * m_kernels.cohesionW(xpq);
                    }

                    // Curvature: pull normals of neighboring particles into alignment
                    const auto nQ = other->getState().getNormals()[neighbors[k]];
                    accel -= K_pq * (nP - nQ);
                }
            }
            getState().getAccelerations()[p] += accel * m_modelParameters->m_surfaceTensionConstant;
        });
}

// Integrate velocities over `timestep`: v_p += (gravity + a_p) * timestep for every particle.
void
SPHModel::updateVelocity(Real timestep)
{
    const auto& gravity = m_modelParameters->m_gravity;
    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p) {
            auto& velocity = getState().getVelocities()[p];
            velocity += (gravity + getState().getAccelerations()[p]) * timestep;
        });
}

void
SPHModel::computeViscosity()
{
    const auto& modelGroup = SPHModel::getModelGroup(this);

    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p) {
            const auto& pvel = getState().getVelocities()[p];
            Vec3r diffuseFluid(0, 0, 0);
            Vec3r diffuseBoundary(0, 0, 0);

            for (size_t modelIdx = 0; modelIdx < modelGroup.size(); ++modelIdx)
            {
                const auto& neighborInfo = getState().getNeighborInfo(modelIdx)[p];
                const auto& fluidNeighborList = getState().getFluidNeighborLists(modelIdx)[p];
                const auto& model = modelGroup[modelIdx];

                for (size_t i = 0; i < fluidNeighborList.size(); ++i)
                {
                    const auto q        = fluidNeighborList[i];
                    const auto& qvel    = model->getState().getVelocities()[q];
                    const auto& qInfo   = neighborInfo[i];
                    const auto r        = qInfo.xpq;
                    const auto qdensity = qInfo.density;
                    diffuseFluid       += (Real(1.0) / qdensity) * m_kernels.W(r) * (qvel - pvel);
                }

                if (m_modelParameters->m_densityWithBoundary)
                {
                    for (size_t i = fluidNeighborList.size(); i < neighborInfo.size(); ++i)
                    {
                        const auto& qInfo   = neighborInfo[i];
                        const auto r        = qInfo.xpq;
                        diffuseBoundary    -= m_modelParameters->m_restDensityInv * m_kernels.W(r) * pvel;
                    }
                }
            }

            diffuseFluid *= m_modelParameters->m_fluidViscosityConstant;
            diffuseBoundary *= m_modelParameters->m_boundaryViscosityConstant;
            getState().getDiffuseVelocities()[p] = (diffuseFluid + diffuseBoundary);
        });

    // Add diffused velocity back to velocity, causing viscosity
    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p)
        {
            getState().getVelocities()[p] += getState().getDiffuseVelocities()[p];
        });
}

// Advance particle positions with the current velocities: x_p += v_p * timestep.
void
SPHModel::moveParticles(Real timestep)
{
    ParallelUtils::parallelFor(getState().getNumParticles(),
        [&](const size_t p) {
            auto& position = getState().getPositions()[p];
            position += getState().getVelocities()[p] * timestep;
        });
}

// Static members
// Registry mapping each SPH model (keyed by raw pointer) to the group of models it interacts with.
// Entries are created by initializeModelGroup() and setupModelGroup(); nothing in this file erases
// them, and the groups hold shared_ptrs to their models, so a registered model stays alive at least
// as long as this static map.
// NOTE(review): consider an unregister path (or weak_ptr groups) if models are created/destroyed at runtime.
std::unordered_map<SPHModel*, SPHModel::SPHModelGroup> SPHModel::s_mSPHModelGroups;

void
SPHModel::initializeModelGroup(const std::shared_ptr<SPHModel>& sphModel)
{
    if (s_mSPHModelGroups.find(sphModel.get()) != s_mSPHModelGroups.end())
    {
        LOG(WARNING) << "SPH Model has previously been initialized";
        return;
    }
    s_mSPHModelGroups[sphModel.get()] = { sphModel };
}

// Return the group of models that `sphModel` interacts with (the model itself included).
// The registration check runs in every build configuration. Dereferencing end() for an
// unregistered model is undefined behavior, and the check costs only one comparison per call.
const SPHModel::SPHModelGroup&
SPHModel::getModelGroup(SPHModel* const sphModel)
{
    const auto it = s_mSPHModelGroups.find(sphModel);
    LOG_IF(FATAL, (it == s_mSPHModelGroups.end())) << "Model group for the given SPH model has not been initialized";
    return it->second;
}

void
SPHModel::setupModelGroup(const std::vector<std::shared_ptr<SPHModel>>& models)
{
    // If there is just 1 SPH model, don't need to do anything
    if (models.size() < 2)
    {
        return;
    }

    // Build a set of all models from the group
    std::set<std::shared_ptr<SPHModel>> modelSet;
    for (const auto& model: models)
    {
        modelSet.insert(model);
    }

    // Union all model groups of all models
    for (const auto& model: models)
    {
        const auto currentGroup = s_mSPHModelGroups[model.get()];
        modelSet.insert(currentGroup.begin(), currentGroup.end());
    }

    // Convert set of models to vector of models (model group)
    SPHModelGroup modelGroup;
    modelGroup.insert(modelGroup.end(), modelSet.begin(), modelSet.end());

    // Set the same model group to all models in the group
    for (const auto& model: models)
    {
        s_mSPHModelGroups[model.get()] = modelGroup;
    }
}
} // end namespace imstk
