// This file is part of the iMSTK project.
//
// Copyright (c) Kitware, Inc.
//
// Copyright (c) Center for Modeling, Simulation, and Imaging in Medicine,
//                        Rensselaer Polytechnic Institute
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef imstkLinearSolver_h
#define imstkLinearSolver_h

// STL includes
#include <memory>
#include <utility>

// imstk includes
#include "imstkLinearSystem.h"
#include "imstkMath.h"

#include "g3log/g3log.hpp"

namespace imstk
{

///
/// \brief Base class for linear solvers
///
/// Stores a shared handle to the linear system to be solved together with a
/// solver tolerance. Derived classes must implement isIterative() and solve().
///
/// \tparam MatrixType Matrix type forwarded to LinearSystem
///
template<typename MatrixType>
class LinearSolver
{
public:
    using LinearSystemType = LinearSystem<MatrixType>;

    ///
    /// \brief Identifies the kind of linear solver; see getType()
    ///
    enum class Type
    {
        ConjugateGradient,
        CholeskyFactorization,
        LUFactorization,
        GaussSeidel,
        SuccessiveOverRelaxation,
        Jacobi,
        GMRES,
        none
    };

public:
    ///
    /// \brief Default constructor. The solver starts without a linear system
    /// (null handle) and with a tolerance of MACHINE_PRECISION.
    ///
    LinearSolver() = default;

    ///
    /// \brief Virtual destructor so derived solvers can be destroyed through
    /// a pointer to this base class
    ///
    virtual ~LinearSolver() = default;

    ///
    /// \brief Replace the stored linear system of equations
    /// \param newSystem System to solve from now on; may be null, in which
    /// case the solver holds no system afterwards
    ///
    virtual void setSystem(std::shared_ptr<LinearSystemType> newSystem)
    {
        // Move-assignment drops the reference to the previously stored system,
        // so no explicit reset() is required, and moving the by-value argument
        // avoids an extra reference-count increment/decrement.
        m_linearSystem = std::move(newSystem);
    }

    ///
    /// \brief Return the stored linear system (null if none has been set)
    ///
    std::shared_ptr<LinearSystemType> getSystem() const { return m_linearSystem; }

    ///
    /// \brief Set solver tolerance
    ///
    void setTolerance(const double tolerance) { m_tolerance = tolerance; }

    ///
    /// \brief Get solver tolerance
    ///
    double getTolerance() const { return m_tolerance; }

    ///
    /// \brief Return the 2-norm of the residual, as computed by
    /// LinearSystem::computeResidual2Norm() on the stored system
    /// \param x Vector forwarded unchanged to the stored system
    /// \pre A system must have been set through setSystem(); the stored
    /// handle is dereferenced without a null check.
    ///
    virtual double getResidual2Norm(const Vectord& x) { return m_linearSystem->computeResidual2Norm(x); }

    ///
    /// \brief Returns true if the solver is iterative
    ///
    virtual bool isIterative() const = 0;

    ///
    /// \brief Returns the type of the linear solver. The base implementation
    /// reports Type::none.
    ///
    virtual Type getType() const { return Type::none; }

    ///
    /// \brief Print solver information.
    ///
    virtual void print() const { LOG(INFO) << "Linear solver base\n"; }

    ///
    /// \brief Solve the linear system
    /// \param x Vector passed by non-const reference; concrete solvers are
    /// presumably expected to write the solution into it (not enforced by
    /// this base class)
    ///
    virtual void solve(Vectord& x) = 0;

protected:
    double m_tolerance = MACHINE_PRECISION;             ///< Solver tolerance; defaults to MACHINE_PRECISION
    std::shared_ptr<LinearSystemType> m_linearSystem;   ///< Linear system of equations (null until set)
};

} // imstk

#endif // imstkLinearSolver_h