// This file is part of the iMSTK project.
//
// Copyright (c) Kitware, Inc.
//
// Copyright (c) Center for Modeling, Simulation, and Imaging in Medicine,
//                        Rensselaer Polytechnic Institute
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "imstkConjugateGradientEigen.h"

namespace imstk
{

ConjugateGradientEigen::ConjugateGradientEigen(const SparseMatrixd& A, const Vectord& rhs) : m_cgSolver(A)
{
    m_linearSystem = std::make_shared<LinearSystem<SparseMatrixd>>(A, rhs);
    m_maxIterations = rhs.size();
    m_tolerance = 1.0e-5;

    m_cgSolver.setMaxIterations(m_maxIterations);
    m_cgSolver.setTolerance(m_tolerance);
    m_cgSolver.compute(A);
}

void
ConjugateGradientEigen::solve(Vectord& x)
{
    if(!m_linearSystem)
    {
        LOG(WARNING) << "ConjugateGradientEigen::solve : Linear system is not set\n";
        return;
    }

    x = m_cgSolver.solve(m_linearSystem->getRHSVector());
}

void
ConjugateGradientEigen::setTolerance(const double epsilon)
{
    IterativeLinearSolver::setTolerance(epsilon);
    m_cgSolver.setTolerance(epsilon);
}

void
ConjugateGradientEigen::setMaxNumIterations(const size_t maxIter)
{
    IterativeLinearSolver::setMaxNumIterations(maxIter);
    m_cgSolver.setMaxIterations(maxIter);
}

void
ConjugateGradientEigen::setSystem(std::shared_ptr<LinearSystem<SparseMatrixd>> newSystem)
{
    LinearSolver<SparseMatrixd>::setSystem(newSystem);
    m_cgSolver.compute(m_linearSystem->getMatrix());
}

void
ConjugateGradientEigen::print() const
{
    LOG(INFO) << "Solver: Conjugate gradient based on Eigen";
    LOG(INFO) << "Tolerance: " << m_tolerance;
    LOG(INFO) << "max. iterations: " << m_maxIterations;
}

void
ConjugateGradientEigen::solve(Vectord& x, const double tolerance)
{
    setTolerance(tolerance);
    solve(x);
}

} // imstk