/*=========================================================================

   Library: iMSTK

   Copyright (c) Kitware, Inc. & Center for Modeling, Simulation,
   & Imaging in Medicine, Rensselaer Polytechnic Institute.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0.txt

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

=========================================================================*/

#include <algorithm>
#include <chrono>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "imstkTimer.h"
#include "imstkSimulationManager.h"
#include "imstkForceModelConfig.h"
#include "imstkDeformableObject.h"
#include "imstkBackwardEuler.h"
#include "imstkNonlinearSystem.h"
#include "imstkNewtonSolver.h"
#include "imstkGaussSeidel.h"
#include "imstkConjugateGradient.h"
#include "imstkTetrahedralMesh.h"
#include "imstkMeshIO.h"
#include "imstkLineMesh.h"
#include "imstkOneToOneMap.h"
#include "imstkHDAPIDeviceClient.h"
#include "imstkHDAPIDeviceServer.h"
#include "imstkDeviceTracker.h"
#include "imstkSceneObjectController.h"
#include "imstkVirtualCouplingCH.h"
#include "LineToPointSetCD.h"
#include "NeedleTissueCH.h"
#include "imstkAPIUtilities.h"
#include "imstkMath.h"
#include "LineSegmentToPlaneCD.h"
#include "imstkLogger.h"
#include "NeedlePlaneCH.h"

using namespace imstk;

// Global variables shared between the setup functions and the per-frame callbacks
std::shared_ptr<Scene> g_scene;
std::shared_ptr<SimulationManager> g_sdk;
std::shared_ptr<ConjugateGradient> g_linSolver;
std::shared_ptr<NewtonSolver> g_nlSolver;
// Constraints built from the boundary-condition node list (filled in createTissue)
std::vector<LinearProjectionConstraint> projList;

#ifdef iMSTK_USE_OPENHAPTICS
    std::shared_ptr<imstk::HDAPIDeviceClient> g_client;
    std::shared_ptr<imstk::HDAPIDeviceServer> g_server;
#endif
// Only created when built with OpenHaptics (see createNeedle); null otherwise
std::shared_ptr<imstk::DeviceTracker> g_deviceTracker;
std::shared_ptr<CollidingObject> g_needleObject;
std::shared_ptr<DeformableObject> g_deformableObj;
std::shared_ptr<SurfaceMesh> g_surfMesh;
std::shared_ptr<TetrahedralMesh> g_volTetMesh;
std::shared_ptr<OneToOneMap> g_oneToOneNodalMap;
std::shared_ptr<CollidingObject> g_planeObj;

std::shared_ptr<SceneObjectController> g_objController;

// Needle-tissue collision data and the needle constraint list shared by the
// needle-tissue handler and the linear solver
CollisionData colData;
std::vector<LinearProjectionConstraint> needleProjList;
// One flag per volume-mesh vertex: true when the vertex maps to a surface-mesh vertex
std::vector<bool> surfaceStatus;

// Needle-plane collision data
CollisionData colData2;

std::shared_ptr<LineSegmentToPlaneCD> g_CD;
std::shared_ptr<NeedleTissueInteraction> CHA_NTI;
std::shared_ptr<LineToPointSetCD> CD_NTI;


// Needle tissue interaction configuration options
namespace NTISimulationConfig
{
    // ---- Needle geometry and appearance ----
    // The needle runs along z, from +0.9 * length down to -0.1 * length.
    constexpr double needleLength = 80.;
    const Vec3d needleStartPoint(0., 0., 0.9*needleLength);
    const Vec3d needleEndPoint(0., 0., -0.1*needleLength);
    // NOTE(review): copied from Color::LightGray during static initialization; cross-TU
    // initialization order is unspecified, so confirm Color::LightGray is ready first.
    const Color needleColor = Color::LightGray;

    // ---- Haptic device ----
    const std::string phantomOmniName("Phantom1");

    // ---- Tissue input files ----
    // Alternative data set (disabled):
    //   iMSTK_DATA_ROOT "/asianDragon/asianDragon.veg"
    //   iMSTK_DATA_ROOT "/asianDragon/asianDragon.config"
    //   iMSTK_DATA_ROOT "/asianDragon/asianDragon.bou"
    const std::string kidneyMeshFilename(iMSTK_DATA_ROOT "/tissue/tissue.veg");
    const std::string kidneyConfigFilename(iMSTK_DATA_ROOT "/tissue/tissue.config");
    const std::string kidneyBCFilename(iMSTK_DATA_ROOT "/tissue/tissue.bou");

    // ---- Tissue placement, solver and force settings ----
    const Vec3d centeringTransform(-30., -50., 0.);
    constexpr double geoScalingFactor = 20.;
    constexpr double solverTolerance = 1.0e-6;
    constexpr double forceScalingFactor = 1.0e-1*1.2;

    constexpr double timeStep = 0.04;

    // ---- Viewer background gradient ----
    const Vec3d bgColor1(0.3285, 0.3285, 0.6525);
    const Vec3d bgColor2(0.1152, 0.1152, 0.2289);

    // ---- Feature switches ----
    constexpr bool dispayFPS = false;
    constexpr bool renderDebugInfo = false;

    constexpr bool logPosVel = false;
}

// Dara logger class.
class AnesthesiaSimLogger : public imstk::Logger
{
public:
    AnesthesiaSimLogger(std::string name) : Logger(name), m_loggingEnabled(false) {}
    virtual ~AnesthesiaSimLogger() {}

    //Handle logging boolean
    bool isLoggingEnabled() const
    {
        return m_loggingEnabled;
    }

    void enableLogging()
    {
        m_loggingEnabled = true;
        m_totalLoggingTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock().now().time_since_epoch()).count();
    }

    void disableLogging()
    {
        if (!m_loggingEnabled)
        {
            return;
        }
        auto tPresent = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock().now().time_since_epoch()).count();
        auto prev = m_totalLoggingTime;
        m_totalLoggingTime = tPresent - m_totalLoggingTime;

        m_loggingEnabled = false;
    }

    //Log position and velocities
    inline void logPosAndVel(const imstk::Vec3d position, const imstk::Vec3d velocity)
    {
        std::stringstream ss;
        ss << "P " << position[0] << " " << position[1] << " " << position[2] << " V " << velocity[0] << " " << velocity[1] << " " << velocity[2] << std::endl;
        std::string message = ss.str();

        if (m_loggingEnabled && this->readyForLoggingWithFrequency())
        {
            this->log(message, false);
            this->updateLogTime();
        }
    }

    //Get complete logging time
    long long getTotalLoggingTime() { return m_totalLoggingTime; }

private:
    long long m_totalLoggingTime;
    bool m_loggingEnabled;
};

// Position/velocity logger; created in main() only when NTISimulationConfig::logPosVel is true.
std::shared_ptr<AnesthesiaSimLogger> g_logger;

// Create a needle object that is controlled by an external device
bool
createNeedle()
{
    std::vector<LineMesh::LineArray> lines;
    StdVectorOfVec3d points;
    std::vector<Color> colors;

    // Create needle mesh data
    points.push_back(NTISimulationConfig::needleStartPoint);
    points.push_back(NTISimulationConfig::needleEndPoint);    
    lines.push_back(LineMesh::LineArray({ 0, 1 }));
    
    colors.push_back(NTISimulationConfig::needleColor);
    colors.push_back(NTISimulationConfig::needleColor);

    // Construct line mesh
    auto lineMesh = std::make_shared<LineMesh>();
    lineMesh->initialize(points, lines);
    lineMesh->setVertexColors(colors);


    std::vector<LineMesh::LineArray> lines2;
    StdVectorOfVec3d points2;
    std::vector<Color> colors2;

    // Create needle mesh data
    points2.push_back(NTISimulationConfig::needleStartPoint);
    points2.push_back(NTISimulationConfig::needleEndPoint);
    lines2.push_back(LineMesh::LineArray({ 0, 1 }));
    auto lineMeshCol = std::make_shared<LineMesh>();
    lineMeshCol->initialize(points2, lines2);

    auto lineMeshMaterial = std::make_shared<RenderMaterial>();
    lineMeshMaterial->setLineWidth(6);
    lineMeshMaterial->setColor(imstk::Color::Blue);
    
    auto lineMeshDisplayModel = std::make_shared<VisualModel>(lineMesh);
    lineMeshDisplayModel->setRenderMaterial(lineMeshMaterial);

    g_needleObject = std::make_shared<CollidingObject>("needleMesh");
    g_needleObject->addVisualModel(lineMeshDisplayModel);
    //g_needleObject->setVisualGeometry(lineMesh);
    g_needleObject->setCollidingGeometry(lineMeshCol);

    g_scene->addSceneObject(g_needleObject);

    // add user control to the needle
#ifdef iMSTK_USE_OPENHAPTICS
    // Device clients
    g_client = std::make_shared<imstk::HDAPIDeviceClient>(NTISimulationConfig::phantomOmniName);

    // Device Server
    g_server = std::make_shared<imstk::HDAPIDeviceServer>();
    g_server->addDeviceClient(g_client);
    g_sdk->addModule(g_server);

    // Device tracker
    g_deviceTracker = std::make_shared<imstk::DeviceTracker>(g_client);

    auto needleController = std::make_shared<imstk::SceneObjectController>(g_needleObject, g_deviceTracker);
    g_scene->addObjectController(needleController);

#endif

    return 1;
}

// Read whitespace-separated integer node indices (boundary-condition nodes) from fileName,
// append them to bcNodeList, and sort the whole list ascending.
// Reading stops at end of file or at the first token that is not an integer.
// @param bcNodeList  receives the node indices (existing entries are kept and sorted too)
// @param fileName    defaults to NTISimulationConfig::kidneyBCFilename
// @return false if the file cannot be opened, true otherwise
bool
loadBoundaryConditions(std::vector<int>& bcNodeList,
                       const std::string& fileName = NTISimulationConfig::kidneyBCFilename)
{
    std::ifstream bcFile(fileName);
    if (!bcFile)
    {
        return false;
    }

    // operator>> fails (ending the loop) on EOF *and* on non-numeric input, so a stray
    // token can no longer cause an endless loop.
    int nodeNum = 0;
    while (bcFile >> nodeNum)
    {
        bcNodeList.push_back(nodeNum);
    }

    std::sort(bcNodeList.begin(), bcNodeList.end());
    return true;
}

// Add a 200-wide plane positioned at y = -50 to the scene. The same Plane geometry serves
// as both the visual and the colliding geometry; the object is stored in g_planeObj.
void
createPlane()
{
    auto floorGeometry = std::make_shared<Plane>();
    floorGeometry->setWidth(200);
    floorGeometry->setPosition(0.0, -50, 0.0);

    auto floorObject = std::make_shared<CollidingObject>("Plane");
    floorObject->setVisualGeometry(floorGeometry);
    floorObject->setCollidingGeometry(floorGeometry);

    g_planeObj = floorObject;
    g_scene->addSceneObject(g_planeObj);
}

// create virtual coupling with sphere and a plane
void 
createVirtualCouplingSphere()
{
    // Create a virtual coupling object
    auto visualGeom = std::make_shared<Sphere>();
    visualGeom->setRadius(2);
    auto collidingGeom = std::make_shared<Sphere>();
    collidingGeom->setRadius(2);
    auto SphereObj = std::make_shared<CollidingObject>("VirtualCouplingObject");
    SphereObj->setCollidingGeometry(collidingGeom);

    auto material = std::make_shared<RenderMaterial>();
    auto visualModel = std::make_shared<VisualModel>(visualGeom);
    visualModel->setRenderMaterial(material);
    SphereObj->addVisualModel(visualModel);

    // Add virtual coupling object (with visual, colliding, and physics geometry) in the scene.
    g_scene->addSceneObject(SphereObj);

    // Create and add virtual coupling object controller in the scene
    g_objController = std::make_shared<SceneObjectController>(SphereObj, g_deviceTracker);
    g_scene->addObjectController(g_objController);

    // Create a collision graph
    auto graph = g_scene->getCollisionGraph();
    auto pair = graph->addInteractionPair(g_planeObj, SphereObj,
        CollisionDetection::Type::UnidirectionalPlaneToSphere,
        CollisionHandling::Type::None,
        CollisionHandling::Type::VirtualCoupling);

    // Customize collision handling algorithm
    auto colHandlingAlgo = std::dynamic_pointer_cast<VirtualCouplingCH>(pair->getCollisionHandlingB());
    colHandlingAlgo->setStiffness(5e-01);
    colHandlingAlgo->setDamping(0.005);
}

// create virtual coupling with sphere and a plane
bool
createVirtualCouplingWithNeedle()
{        
    colData2.NeedleColData.resize(0);
    colData2.NeedleColData.reserve(10);

    if (g_deviceTracker)
    {
        // create collision detection
        g_CD = std::make_shared<LineSegmentToPlaneCD>(std::dynamic_pointer_cast<Plane>(g_planeObj->getCollidingGeometry()),
            g_deviceTracker,
            NTISimulationConfig::needleStartPoint,
            NTISimulationConfig::needleEndPoint,
            colData2);

        // collision handling
        auto CHA = std::make_shared<NeedlePlaneCH>(CollisionHandling::Side::A, colData2, g_needleObject);
        CHA->setStiffness(0.4e-0);
        CHA->setDamping(0.1);

        // Add the interaction pair to the scene
        g_scene->getCollisionGraph()->addInteractionPair(g_planeObj, g_needleObject, g_CD, nullptr, CHA);
        
        return 1;
    }
   
    return 0;
   
}

bool 
createTissue()
{
    // Load a kidney mesh
    auto tetMesh = MeshIO::read(NTISimulationConfig::kidneyMeshFilename);
    if (!tetMesh)
    {
        LOG(WARNING) << "Could not read mesh from file: " << NTISimulationConfig::kidneyConfigFilename;
        return 0;
    }

    // Extract the surface mesh from the tetrahedral mesh
    g_surfMesh = std::make_shared<SurfaceMesh>();
    g_volTetMesh = std::dynamic_pointer_cast<TetrahedralMesh>(tetMesh);

    g_volTetMesh->scale(NTISimulationConfig::geoScalingFactor, Geometry::TransformType::ApplyToData);
    g_volTetMesh->translate(NTISimulationConfig::centeringTransform, Geometry::TransformType::ApplyToData);
    //volTetMesh->rotate(Vec3d(0, 0, 1.), PI, Geometry::TransformType::ApplyToData);
    if (!g_volTetMesh)
    {
        LOG(WARNING) << "Dynamic pointer cast from PointSet to TetrahedralMesh failed!";
        return 0;
    }
    g_volTetMesh->extractSurfaceMesh(g_surfMesh, true);
    g_surfMesh->flipNormals();

    // Construct one to one nodal map based on the above meshes
    g_oneToOneNodalMap = std::make_shared<OneToOneMap>();
    g_oneToOneNodalMap->setMaster(tetMesh);
    g_oneToOneNodalMap->setSlave(g_surfMesh);

    // Compute the map
    g_oneToOneNodalMap->compute();

    // Configure dynamic model
    auto dynaModel = std::make_shared<FEMDeformableBodyModel>();
    dynaModel->configure(NTISimulationConfig::kidneyConfigFilename);
    dynaModel->setTimeStepSizeType(TimeSteppingType::realTime);
    dynaModel->setModelGeometry(g_volTetMesh);
    auto timeIntegrator = std::make_shared<BackwardEuler>(NTISimulationConfig::timeStep);
    dynaModel->setTimeIntegrator(timeIntegrator);

    auto material = std::make_shared<RenderMaterial>();
    material->setDisplayMode(RenderMaterial::DisplayMode::WIREFRAME_SURFACE);
    material->setColor(Color::DarkGray);    
    material->setLineWidth(2.0);
 
    auto surfMeshVisualModel = std::make_shared<VisualModel>(g_surfMesh);
    surfMeshVisualModel->setRenderMaterial(material);

    /*auto volTetMeshVisualModel = std::make_shared<VisualModel>(g_volTetMesh);
    volTetMeshVisualModel->setRenderMaterial(material);*/

    // Scene Object
    g_deformableObj = std::make_shared<DeformableObject>("Tissue");    
    g_deformableObj->addVisualModel(surfMeshVisualModel);
    g_deformableObj->setPhysicsGeometry(g_volTetMesh);
    g_deformableObj->setPhysicsToVisualMap(g_oneToOneNodalMap); //assign the computed map
    g_deformableObj->setDynamicalModel(dynaModel);
    g_scene->addSceneObject(g_deformableObj);

    // create a nonlinear system
    auto nlSystem = std::make_shared<NonLinearSystem>(
        dynaModel->getFunction(),
        dynaModel->getFunctionGradient());

    // Add boundary conditions
    std::vector<int> nodeBCList;
    nodeBCList.reserve(400);
    loadBoundaryConditions(nodeBCList);
    for (auto& i : nodeBCList)
    {
        projList.push_back(LinearProjectionConstraint(i, true));
    }

    nlSystem->setUnknownVector(dynaModel->getUnknownVec());
    nlSystem->setUpdateFunction(dynaModel->getUpdateFunction());
    nlSystem->setUpdatePreviousStatesFunction(dynaModel->getUpdatePrevStateFunction());

    // create a linear solver
    g_linSolver = std::make_shared<ConjugateGradient>();
    g_linSolver->setTolerance(NTISimulationConfig::solverTolerance);
    g_linSolver->setLinearProjectors(&projList);

    // create a non-linear solver and add to the scene
    g_nlSolver = std::make_shared<NewtonSolver>();
    g_nlSolver->setLinearSolver(g_linSolver);
    g_nlSolver->setSystem(nlSystem);
    g_scene->addNonlinearSolver(g_nlSolver);

    return 1;
}

bool 
createNeedleTissueInteraction()
{
    // Add collision detection and handling
    g_linSolver->setDynamicLinearProjectors(&needleProjList);
    
    colData.NeedleColData.resize(0);
    colData.NeedleColData.reserve(400);

    if (g_deviceTracker)
    {
        // create collision detection
        CD_NTI = std::make_shared<LineToPointSetCD>(std::dynamic_pointer_cast<PointSet>(g_volTetMesh),
            g_deviceTracker,
            NTISimulationConfig::needleStartPoint,
            NTISimulationConfig::needleEndPoint,
            colData);

        surfaceStatus = std::vector<bool>(g_volTetMesh->getNumVertices(), false);
        for (unsigned int i = 0; i < g_surfMesh->getNumVertices(); ++i)
        {
            surfaceStatus[g_oneToOneNodalMap->getMapIdx(i)] = true;
        }
        CD_NTI->setSurfaceNodeList(surfaceStatus);

        // collision handling
        CHA_NTI = std::make_shared<NeedleTissueInteraction>(CollisionHandling::Side::A,
            colData,
            &needleProjList,
            g_needleObject,
            g_deformableObj);

        CHA_NTI->setScalingFactor(NTISimulationConfig::forceScalingFactor);

        // Add the interaction pair to the scene
        g_scene->getCollisionGraph()->addInteractionPair(std::dynamic_pointer_cast<CollidingObject>(g_deformableObj),
            g_needleObject,
            CD_NTI,
            nullptr,
            CHA_NTI);

        // Rotate the dragon every frame
        auto communicateNeedleAxisAndState = [](Module* module)
        {
            CHA_NTI->updateNeedleAxis(CD_NTI->getNeedleAxis());
            CHA_NTI->setNeedleState(CD_NTI->getNeedleState());
        };
        g_sdk->getSceneManager(g_scene)->setPostUpdateCallback(communicateNeedleAxisAndState);
    }
    else
    {
        return 0;
    }

    return 1;
}

int main()
{
    // Create simulation manager and Scene
    g_sdk = std::make_shared<SimulationManager>();
    g_scene = g_sdk->createNewScene("Anesthesia Simulator");

    createTissue();
    createNeedle();
    createNeedleTissueInteraction();

    // optionally render debug info
    /*if (NTISimulationConfig::renderDebugInfo)
    {
        auto constrainedNodesDisplay = std::make_shared<PointSet>();
        StdVectorOfVec3d dbgPointList(300, Vec3d(0., 0., 0.));
        constrainedNodesDisplay->initialize(dbgPointList);

        auto dbgMaterial = std::make_shared<RenderMaterial>();
        dbgMaterial->setDisplayMode(RenderMaterial::DisplayMode::POINTS);
        dbgMaterial->setPointSize(6.0);
        dbgMaterial->setColor(imstk::Color::Pink);

        auto constrainedNodesDisplayModel = std::make_shared<VisualModel>(constrainedNodesDisplay);
        constrainedNodesDisplayModel->setRenderMaterial(dbgMaterial);

        auto constrainedNodesObj = std::make_shared<VisualObject>("debugDisplayObj");
        constrainedNodesObj->setVisualGeometry(constrainedNodesDisplay);
        g_scene->addSceneObject(constrainedNodesObj);

        auto udpdateConstrNodes = [&](Module* module)
        {
            if (colData.NeedleColData.size() != 0)
            {
                for (auto& c : colData.NeedleColData)
                {
                    constrainedNodesDisplay->setVertexPosition(c.nodeId, g_volTetMesh->getVertexPosition(c.nodeId));
                }

                for (int i = colData.NeedleColData.size(); i < dbgPointList.size(); ++i)
                {
                    constrainedNodesDisplay->setVertexPosition(i, Vec3d(0., 0., 0.));
                }
            }
        };
        g_sdk->getSceneManager(g_scene)->setPostUpdateCallback(udpdateConstrNodes);
    }*/

    createPlane();
    createVirtualCouplingWithNeedle();
    auto translateVisualMesh = [](Module* module)
    {
        auto collData = g_CD->getCollisionData().PDColData;
        Vec3d t(0., 0., 0.);
        if (collData.size() > 0)
        {
            t = collData[0].dirAtoB;
        }
        g_needleObject->getVisualGeometry()->setTranslation(g_deviceTracker->getPosition() + t);
        g_needleObject->getVisualGeometry()->setRotation(g_deviceTracker->getRotation());
    };
    g_sdk->getSceneManager(g_scene)->setPostUpdateCallback(translateVisualMesh);

    // Set up Light
    auto light = std::make_shared<DirectionalLight>("light1");    
    light->setFocalPoint(Vec3d(5, -8, -5));
    light->setIntensity(2);
    g_scene->addLight(light);

    // Set up Camera
    auto camera = g_scene->getCamera();
    camera->setPosition(imstk::Vec3d(0., 30., 100.));
    camera->setFocalPoint(imstk::Vec3d(0., 0., 0.));

    // optionally print frame rate
    if (NTISimulationConfig::dispayFPS)
    {
        apiutils::printUPS(g_sdk->getSceneManager(g_scene), std::make_shared<UPSCounter>());
    }

    if (NTISimulationConfig::logPosVel)
    {
        g_logger = std::make_shared<AnesthesiaSimLogger>("posVelLog.txt");
        g_logger->enableLogging();
        auto updateLogger = [](Module* module)
        {
            // Log position & velocity
            auto p = g_deviceTracker->getPosition();
            auto v = g_deviceTracker->getDeviceClient()->getVelocity();
            v = g_deviceTracker->getRotationOffset() * v * g_deviceTracker->getTranslationScaling();
            g_logger->logPosAndVel(p, v);
        };
        g_sdk->getSceneManager(g_scene)->setPostUpdateCallback(updateLogger);
    }

    // Run the simulation
    g_sdk->setActiveScene(g_scene);
    g_scene->getCamera()->setPosition(0, 2.0, 150.0);
    g_sdk->getViewer()->setBackgroundColors(NTISimulationConfig::bgColor1, NTISimulationConfig::bgColor2, true);
    g_sdk->startSimulation(SimulationStatus::RUNNING);

    if (NTISimulationConfig::logPosVel)
    {
        g_logger->disableLogging();
        g_logger->shutdown();
        g_logger.reset();
    }

    return 0;
}