/*=========================================================================

Library: iMSTK

Copyright (c) Kitware, Inc. & Center for Modeling, Simulation,
& Imaging in Medicine, Rensselaer Polytechnic Institute.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0.txt

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

=========================================================================*/

#include <algorithm>
#include <iostream>
#include <utility>

#include "imstkNewtonSolver.h"
#include "imstkIterativeLinearSolver.h"
#include "imstkConjugateGradientEigen.h"

namespace imstk
{

NewtonSolver::NewtonSolver():
    m_linearSolver(std::make_shared<ConjugateGradientEigen>()),
    m_forcingTerm(0.9),
    m_absoluteTolerance(1e-3),
    m_relativeTolerance(1e-6),
    m_gamma(0.9),
    m_etaMax(0.9),
    m_maxIterations(50),
    m_useArmijo(true) {}


///
/// \brief Runs inexact Newton iterations on the nonlinear system, updating \p x in place.
///
/// The stopping tolerance is m_absoluteTolerance + m_relativeTolerance * ||F(x0)||.
/// Iteration stops when the residual norm drops below it, becomes exactly zero,
/// or after m_maxIterations iterations. When m_forcingTerm > 0, the linear-solver
/// tolerance is adapted after each step through updateForcingTerm().
///
/// \param x Initial guess on entry; the last Newton iterate on exit.
///
void
NewtonSolver::solveGivenState(Vectord& x)
{
    if (!m_nonLinearSystem)
    {
        LOG(WARNING) << "NewtonMethod::solve - nonlinear system is not set to the nonlinear solver";
        return;
    }

    // Compute norms, set tolerances and other temporaries
    double       fnorm         = m_nonLinearSystem->evaluateF(x).norm();
    const double stopTolerance = m_absoluteTolerance + m_relativeTolerance * fnorm;

    m_linearSolver->setTolerance(stopTolerance);

    // Newton update. It is also the starting guess handed to the linear solver,
    // so later iterations start from the previous update.
    Vectord dx = x;

    for (size_t i = 0; i < m_maxIterations; ++i)
    {
        // Converged. An exactly zero residual also ends the iteration: the system
        // is solved, and continuing would divide by fnorm == 0 in the forcing-term
        // update when both tolerances are zero.
        if (fnorm < stopTolerance || fnorm <= 0.0)
        {
            return;
        }

        updateJacobian(x);
        m_linearSolver->solve(dx);
        m_updateIterate(-dx, x);

        // NOTE(review): m_useArmijo is not consulted here; the line search always runs.
        // NOTE(review): armijo() also receives dx and x. Confirm it does not apply the
        // -dx step a second time on top of the m_updateIterate call above.
        const double newNorm = armijo(dx, x, fnorm);

        if (m_forcingTerm > 0.0 && newNorm > stopTolerance)
        {
            const double ratio = newNorm / fnorm; // Ratio of successive residual norms

            // The lower safeguard 0.5 * stopTolerance / ||F|| must use the residual
            // norm of the *current* iterate, to avoid over-solving near convergence.
            // newNorm > stopTolerance here, so the divisor is positive.
            updateForcingTerm(ratio, stopTolerance, newNorm);

            // Reset tolerance in the linear solver according to the new forcing term
            // to avoid over solving of the system.
            m_linearSolver->setTolerance(m_forcingTerm);
        }

        fnorm = newNorm;
    }
}

///
/// \brief Performs a single Newton step on the nonlinear system's current unknowns.
///
/// Takes the unknowns from getUnknownVector() and assembles the Jacobian system at
/// them. It solves for the update du starting from a zero guess, subtracts du, and
/// passes the result to the system's m_FUpdate callback. Exactly one step is taken,
/// independent of m_maxIterations.
///
void
NewtonSolver::solve()
{
    if (!m_nonLinearSystem)
    {
        LOG(WARNING) << "NewtonMethod::solve - nonlinear system is not set to the nonlinear solver";
        return;
    }

    auto u = m_nonLinearSystem->getUnknownVector();

    // Zero initial guess for the linear solve of the Newton update.
    Vectord du = Vectord::Zero(u.size());

    updateJacobian(u);
    m_linearSolver->solve(du);
    u -= du;
    m_nonLinearSystem->m_FUpdate(u);
}

///
/// \brief Assembles the linear system for the Newton update at \p x and hands it to the linear solver.
///
/// Evaluates the residual callback m_F(x) and then the Jacobian callback m_dF(x).
/// If the Jacobian is empty (innerSize() == 0), it logs a warning and leaves the
/// linear solver's current system untouched.
///
/// \param x Point at which the residual and Jacobian are evaluated.
///
void
NewtonSolver::updateJacobian(const Vectord& x)
{
    if (m_nonLinearSystem == nullptr)
    {
        LOG(WARNING) << "NewtonMethod::updateJacobian - nonlinear system is not set to the nonlinear solver";
        return;
    }

    // Both callbacks run (residual first) before the size check.
    auto& rhs      = m_nonLinearSystem->m_F(x);
    auto& jacobian = m_nonLinearSystem->m_dF(x);

    const bool jacobianIsEmpty = (jacobian.innerSize() == 0);
    if (jacobianIsEmpty)
    {
        LOG(WARNING) << "NewtonMethod::updateJacobian - Size of matrix is 0!";
        return;
    }

    m_linearSolver->setSystem(std::make_shared<LinearSolverType::LinearSystemType>(jacobian, rhs));
}

///
/// \brief Updates m_forcingTerm (the linear-solver tolerance) using the safeguarded
/// Eisenstat-Walker rule eta = gamma * ratio^2.
///
/// \param ratio         Ratio of successive nonlinear residual norms.
/// \param stopTolerance Nonlinear stopping tolerance.
/// \param fnorm         Residual norm used for the lower bound 0.5 * stopTolerance / fnorm.
///
void
NewtonSolver::updateForcingTerm(const double ratio, const double stopTolerance, const double fnorm)
{
    double candidate = m_gamma * ratio * ratio;

    const double previousEta = m_forcingTerm;
    const double previousBound = m_gamma * (previousEta * previousEta);

    // Save guard to prevent the forcing term to become too small for far away iterates
    if (previousBound > .1)
    {
        // TODO: Log this
        candidate = std::max(candidate, previousBound);
    }

    const double capped     = std::min(candidate, m_etaMax);
    const double lowerBound = 0.5 * stopTolerance / fnorm;
    m_forcingTerm = std::max(capped, lowerBound);
}


///
/// \brief Replaces the linear solver used for the Newton update.
/// \param newLinearSolver Solver to use; the Newton solver takes shared ownership.
///
void
NewtonSolver::setLinearSolver(std::shared_ptr< NewtonSolver::LinearSolverType > newLinearSolver)
{
    // The parameter is a by-value sink. Moving avoids an extra atomic refcount increment/decrement.
    m_linearSolver = std::move(newLinearSolver);
}


std::shared_ptr<NewtonSolver::LinearSolverType>
NewtonSolver::getLinearSolver() const
{
    return m_linearSolver;
}


void
NewtonSolver::setAbsoluteTolerance(const double aTolerance)
{
    m_absoluteTolerance = aTolerance;
}

double
NewtonSolver::getAbsoluteTolerance() const
{
    return m_absoluteTolerance;
}

} // imstk