Commit 44120404 authored by Sreekanth Arikatla's avatar Sreekanth Arikatla
Browse files

WIP: refactor solver classes

parent 85f969ab
......@@ -5,6 +5,7 @@ include(imstkAddLibrary)
imstk_add_library( Solvers
DEPENDS
Core
SceneElements
)
#-----------------------------------------------------------------------------
......
......@@ -17,80 +17,65 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "imstkConjugateGradient.h"
#include "imstkConjugateGradientEigen.h"
namespace imstk
{
ConjugateGradient::ConjugateGradient(const SparseMatrixd& A, const Vectord& rhs) : cgSolver(A)
ConjugateGradientEigen::ConjugateGradientEigen(const SparseMatrixd& A, const Vectord& rhs) : m_cgSolver(A)
{
m_linearSystem = std::make_shared<LinearSystem<SparseMatrixd>>(A, rhs);
m_maxIterations = rhs.size();
m_tolerance = 1.0e-6;
m_tolerance = 1.0e-5;
cgSolver.setMaxIterations(m_maxIterations);
cgSolver.setTolerance(m_tolerance);
cgSolver.compute(A);
m_cgSolver.setMaxIterations(m_maxIterations);
m_cgSolver.setTolerance(m_tolerance);
m_cgSolver.compute(A);
}
void
ConjugateGradient::iterate(Vectord& , bool)
ConjugateGradientEigen::solve(Vectord& x)
{
// Nothing to do
}
void
ConjugateGradient::solve(Vectord& x)
{
if(!this->m_linearSystem)
if(!m_linearSystem)
{
// TODO: Log this
LOG(WARNING) << "ConjugateGradientEigen::solve : Linear system is not set\n";
return;
}
cgSolver.setMaxIterations(1000);
cgSolver.setTolerance(1.0e-5);
x = cgSolver.solve(m_linearSystem->getRHSVector());
}
double
ConjugateGradient::getResidual(const Vectord& )
{
return cgSolver.error();
x = m_cgSolver.solve(m_linearSystem->getRHSVector());
}
void
ConjugateGradient::setTolerance(const double epsilon)
ConjugateGradientEigen::setTolerance(const double epsilon)
{
IterativeLinearSolver::setTolerance(epsilon);
cgSolver.setTolerance(epsilon);
m_cgSolver.setTolerance(epsilon);
}
void
ConjugateGradient::setMaxNumIterations(const size_t maxIter)
ConjugateGradientEigen::setMaxNumIterations(const size_t maxIter)
{
IterativeLinearSolver::setMaxNumIterations(maxIter);
cgSolver.setMaxIterations(maxIter);
m_cgSolver.setMaxIterations(maxIter);
}
void
ConjugateGradient::setSystem(std::shared_ptr<LinearSystem<SparseMatrixd>> newSystem)
ConjugateGradientEigen::setSystem(std::shared_ptr<LinearSystem<SparseMatrixd>> newSystem)
{
LinearSolver<SparseMatrixd>::setSystem(newSystem);
this->cgSolver.compute(this->m_linearSystem->getMatrix());
m_cgSolver.compute(m_linearSystem->getMatrix());
}
void
ConjugateGradient::print() const
ConjugateGradientEigen::print() const
{
IterativeLinearSolver::print();
LOG(INFO) << "Solver: Conjugate gradient";
LOG(INFO) << "Solver: Conjugate gradient based on Eigen";
LOG(INFO) << "Tolerance: " << m_tolerance;
LOG(INFO) << "max. iterations: " << m_maxIterations;
}
void
ConjugateGradient::solve(Vectord& x, const double tolerance)
ConjugateGradientEigen::solve(Vectord& x, const double tolerance)
{
setTolerance(tolerance);
solve(x);
......
......@@ -17,8 +17,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef imstkConjugateGradient_h
#define imstkConjugateGradient_h
#ifndef imstkConjugateGradientEigen_h
#define imstkConjugateGradientEigen_h
#include <memory>
......@@ -33,49 +33,30 @@ namespace imstk
{
///
/// \brief Conjugate gradient sparse linear solver for SPD matrices
/// \brief Conjugate gradient sparse linear solver for SPD matrices using Eigen
///
class ConjugateGradient : public IterativeLinearSolver
class ConjugateGradientEigen : public IterativeLinearSolver<SparseMatrixd>
{
public:
///
/// \brief Constructors/Destructor
///
ConjugateGradient() = default;
~ConjugateGradient() = default;
ConjugateGradientEigen() = default;
ConjugateGradientEigen(const SparseMatrixd &A, const Vectord& rhs);
~ConjugateGradientEigen() = default;
///
/// \brief Constructor
///
ConjugateGradient(const SparseMatrixd &A, const Vectord& rhs);
ConjugateGradient(const ConjugateGradient &) = delete;
ConjugateGradient &operator=(const ConjugateGradient &) = delete;
///
/// \brief Do one iteration of the method.
///
void iterate(Vectord& x, bool updateResidual = true) override;
///
/// \brief Solve the system of equations.
///
void solve(Vectord& x) override;
///
/// \brief Solve the linear system using Conjugate gradient iterations to a
/// specified tolerance.
///
void solve(Vectord& x, const double tolerance);
ConjugateGradientEigen(const ConjugateGradientEigen &) = delete;
ConjugateGradientEigen &operator=(const ConjugateGradientEigen &) = delete;
///
/// \brief Return the error calculated by the solver.
/// \brief Sets the system. System of linear equations.
///
double getResidual(const Vectord& x) override;
void setSystem(std::shared_ptr<LinearSystemType> newSystem) override;
///
/// \brief Sets the system. System of linear equations.
/// \brief Returns the type of the linear solver
///
void setSystem(std::shared_ptr<LinearSystemType> newSystem) override;
Type getType() const final { return Type::ConjugateGradient; };
///
/// \brief set/get the maximum number of iterations for the iterative solver.
......@@ -92,11 +73,21 @@ public:
///
void print() const override;
///
/// \brief Solve the system of equations.
///
void solve(Vectord& x) override;
///
/// \brief Solve the linear system using Conjugate gradient iterations to a
/// specified tolerance.
///
void solve(Vectord& x, const double tolerance);
private:
///> Pointer to the Eigen's Conjugate gradient solver
Eigen::ConjugateGradient<SparseMatrixd> cgSolver;
Eigen::ConjugateGradient<SparseMatrixd> m_cgSolver;///> Eigen's Conjugate gradient solver
};
} // imstk
#endif // imstkConjugateGradient_h
#endif // imstkConjugateGradientEigen_h
......@@ -19,76 +19,74 @@
#include "imstkDirectLinearSolver.h"
#include "imstkLinearSystem.h"
#include <iostream>
namespace imstk
{
DirectLinearSolver<Matrixd>::
DirectLinearSolver(const Matrixd &matrix, const Vectord &b)
DirectLinearSolver<Matrixd>::DirectLinearSolver(const Matrixd &matrix, const Vectord &b)
{
this->m_linearSystem = std::make_shared<LinearSystem<Matrixd>>(matrix, b);
this->solver.compute(matrix);
}
void DirectLinearSolver<Matrixd>::
setSystem(std::shared_ptr<LinearSystem<Matrixd>> newSystem)
void
DirectLinearSolver<Matrixd>::setSystem(std::shared_ptr<LinearSystem<Matrixd>> newSystem)
{
LinearSolver<Matrixd>::setSystem(newSystem);
this->solver.compute(this->m_linearSystem->getMatrix());
}
DirectLinearSolver<SparseMatrixd>::
DirectLinearSolver(const SparseMatrixd &matrix, const Vectord &b)
{
this->m_linearSystem = std::make_shared<LinearSystem<SparseMatrixd>>(matrix, b);
this->solver.compute(matrix);
}
void DirectLinearSolver<SparseMatrixd>::
setSystem(std::shared_ptr<LinearSystem<SparseMatrixd>> newSystem)
{
LinearSolver<SparseMatrixd>::setSystem(newSystem);
this->solver.compute(this->m_linearSystem->getMatrix());
}
void
DirectLinearSolver<SparseMatrixd>::solve(const Vectord &rhs, Vectord &x)
DirectLinearSolver<Matrixd>::solve(const Vectord &rhs, Vectord &x)
{
x = this->solver.solve(rhs);
x = this->solver.solve(rhs);// CHECK
}
void
DirectLinearSolver<SparseMatrixd>::solve(Vectord &x)
DirectLinearSolver<Matrixd>::solve(Vectord &x)
{
if (!this->m_linearSystem)
{
LOG(WARNING) << "DirectLinearSolver<Matrixd>::solve : Linear system non-existent!\n";
return;
}
x.setZero();
auto b = this->m_linearSystem->getRHSVector();
x = this->solver.solve(b);
}
//-------------------------------------------------
DirectLinearSolver<SparseMatrixd>::DirectLinearSolver(const SparseMatrixd &matrix, const Vectord &b)
{
this->m_linearSystem = std::make_shared<LinearSystem<SparseMatrixd>>(matrix, b);
this->solver.compute(matrix);
}
void
DirectLinearSolver<Matrixd>::solve(const Vectord &rhs, Vectord &x)
DirectLinearSolver<SparseMatrixd>::setSystem(std::shared_ptr<LinearSystem<SparseMatrixd>> newSystem)
{
LinearSolver<SparseMatrixd>::setSystem(newSystem);
this->solver.compute(this->m_linearSystem->getMatrix());
}
void
DirectLinearSolver<SparseMatrixd>::solve(const Vectord &rhs, Vectord &x)
{
x = this->solver.solve(rhs);
x = this->solver.solve(rhs);// CHECK
}
void
DirectLinearSolver<Matrixd>::solve(Vectord &x)
DirectLinearSolver<SparseMatrixd>::solve(Vectord &x)
{
if (!this->m_linearSystem)
{
// TODO: Log this
LOG(WARNING) << "DirectLinearSolver<SparseMatrixd>::solve(Vectord &x) : Linear system non-existent!\n";
return;
}
x.setZero();
auto b = this->m_linearSystem->getRHSVector();
x = this->solver.solve(b);
}
......
......@@ -21,11 +21,10 @@
#define imstkDirectLinearSolver_h
#include <Eigen/Sparse>
#include<Eigen/SparseLU>
#include <Eigen/SparseLU>
// iMSTK includes
#include "imstkLinearSolver.h"
#include "imstkMath.h"
namespace imstk
{
......@@ -34,7 +33,7 @@ template<typename MatrixType> class DirectLinearSolver;
///
/// \brief Dense direct solvers. Solves a dense system of equations using Cholesky
/// decomposition.
/// decomposition.
///
template<>
class DirectLinearSolver<Matrixd> : public LinearSolver<Matrixd>
......@@ -43,39 +42,37 @@ public:
///
/// \brief Default constructor/destructor.
///
DirectLinearSolver() = delete;
~DirectLinearSolver(){};
DirectLinearSolver() = default;
DirectLinearSolver(const Matrixd &matrix, const Vectord &b);
~DirectLinearSolver() = default;
///
/// \brief Constructor
/// \brief Sets the system. System of linear equations.
///
DirectLinearSolver(const Matrixd &A, const Vectord &b);
void setSystem(std::shared_ptr<LinearSystemType> newSystem) override;
///
/// \brief Solve the system of equations.
/// \brief Returns true if the solver is iterative
///
void solve(Vectord &x) override;
bool isIterative() const { return false; };
///
/// \brief Solve the system of equations for arbitrary right hand side vector.
/// \brief Returns the type of the linear solver
///
void solve(const Vectord &rhs, Vectord &x);
Type getType() const final { return Type::CholeskyFactorization; };
///
/// \brief Sets the system. System of linear equations.
/// \brief Solve the system of equations.
///
void setSystem(std::shared_ptr<LinearSystemType> newSystem) override;
void solve(Vectord &x) override;
///
/// \brief Returns true if the solver is iterative
/// \brief Solve the system of equations for arbitrary right hand side vector.
///
bool isIterative() const
{
return false;
};
void solve(const Vectord &rhs, Vectord &x);
private:
Eigen::LDLT<Matrixd> solver;
Eigen::LDLT<Matrixd> solver; ///> Eigen's direct linear solver based on cholesky decomposition
};
///
......@@ -90,17 +87,24 @@ public:
/// \brief Default constructor/destructor
///
DirectLinearSolver() = default;
DirectLinearSolver(const SparseMatrixd &matrix, const Vectord &b);
~DirectLinearSolver() = default;
///
/// \brief Constructor
/// \brief Sets the system. System of linear equations.
///
DirectLinearSolver(const SparseMatrixd &matrix, const Vectord &b);
void setSystem(std::shared_ptr<LinearSystemType> newSystem) override;
///
/// \brief Sets the system. System of linear equations.
/// \brief Returns true if the solver is iterative
///
void setSystem(std::shared_ptr<LinearSystemType> newSystem) override;
bool isIterative() const { return false; };
///
/// \brief Returns the type of the linear solver
///
Type getType() const final { return Type::LUFactorization; };
///
/// \brief Solve the system of equations
......@@ -113,7 +117,7 @@ public:
void solve(const Vectord &rhs, Vectord &x);
private:
Eigen::SparseLU<SparseMatrixd, Eigen::COLAMDOrdering<MatrixType::Index>> solver;//?
Eigen::SparseLU<SparseMatrixd, Eigen::COLAMDOrdering<SparseMatrixd::Index>> solver; ///> Eigen LU solver
};
} // imstk
......
......@@ -21,74 +21,4 @@
namespace imstk
{
IterativeLinearSolver::IterativeLinearSolver()
: m_maxIterations(100)
{
}
void
IterativeLinearSolver::setMaxNumIterations(const size_t maxIter)
{
this->m_maxIterations = maxIter;
}
size_t
IterativeLinearSolver::getMaxNumIterations() const
{
return this->m_maxIterations;
}
const Vectord&
IterativeLinearSolver::getResidualVector()
{
return this->m_residual;
}
const Vectord&
IterativeLinearSolver::getResidualVector(const Vectord& x)
{
this->m_linearSystem->computeResidual(x, this->m_residual);
return this->m_residual;
}
double
IterativeLinearSolver::getResidual(const Vectord& x)
{
this->m_linearSystem->computeResidual(x, this->m_residual);
return this->m_residual.squaredNorm();
}
void
IterativeLinearSolver::print() const
{
// Print Type
LinearSolver::print();
LOG(INFO) << "Solver type (direct/iterative): Iterative";
}
void
IterativeLinearSolver::solve(Vectord& x)
{
if (!this->m_linearSystem)
{
LOG(WARNING) << "IterativeLinearSolver::solve: The linear system should be assigned before solving!";
return;
}
auto epsilon = m_tolerance * m_tolerance;
m_linearSystem->computeResidual(x, m_residual);
for (size_t i = 0; i < m_maxIterations; ++i)
{
if (m_residual.squaredNorm() < epsilon)
{
return;
}
this->iterate(x);
}
}
} //imstk
\ No newline at end of file
......@@ -17,8 +17,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef imstkIterativeLinearSystem_h
#define imstkIterativeLinearSystem_h
#ifndef imstkIterativeLinearSolver_h
#define imstkIterativeLinearSolver_h
// iMSTK includes
#include "imstkLinearSolver.h"
......@@ -29,60 +29,39 @@ namespace imstk
///
/// \brief Base class for iterative linear solvers.
///
class IterativeLinearSolver : public LinearSolver<SparseMatrixd>
template<typename T>
class IterativeLinearSolver : public LinearSolver<T>
{
public:
protected:
///
/// \brief Default constructor/destructor
///
IterativeLinearSolver();
virtual ~IterativeLinearSolver() {};
IterativeLinearSolver() : m_maxIterations(100) {};
///
/// \brief Do one iteration of the method.
///
virtual void iterate(Vectord &x, bool updateResidual = true) = 0;
public:
virtual ~IterativeLinearSolver() {};
///
/// \brief set/get the maximum number of iterations for the iterative solver.
///
virtual void setMaxNumIterations(const size_t maxIter);
virtual size_t getMaxNumIterations() const;
///
/// \brief Return residual vector
///
virtual const Vectord& getResidualVector();
virtual const Vectord& getResidualVector(const Vectord& x);
///
/// \brief Return residue in 2-norm
///
virtual double getResidual(const Vectord &x);
virtual void setMaxNumIterations(const size_t maxIter) { m_maxIterations = maxIter; }
virtual size_t getMaxNumIterations() const { return m_maxIterations; }
///
/// \brief Print solver information.
///
void print() const override;
///
/// \brief Solve the linear system using Gauss-Seidel iterations.
///
virtual void solve(Vectord &x) override;
void print() const override { LOG(INFO) << "Solver type (direct/iterative): Iterative"; }
///
/// \brief Returns true if the solver is iterative
///
bool isIterative() const
{
return true;
};
bool isIterative() const { return true; };
protected:
size_t m_maxIterations; ///> Maximum number of iterations to be performed.
Vectord m_residual; ///> Storage for residual vector.
size_t m_maxIterations; ///> Maximum number of iterations to be performed
Vectord m_residual; ///> Storage for residual vector
};
} //imstk
#endif // imstkIterativeLinearSystem_h
#endif // imstkIterativeLinearSolver_h
......@@ -32,16 +32,16 @@ namespace imstk
///
/// \brief Base class for linear solvers
///
template<typename SystemMatrixType>
template<typename MatrixType>
class LinearSolver
{
public:
using MatrixType = SystemMatrixType;
using LinearSystemType = LinearSystem < MatrixType > ;
using LinearSystemType = LinearSystem<MatrixType> ;
enum class Type
{
ConjugateGradient,
CholeskyFactorization,
LUFactorization,
GaussSeidel,
SuccessiveOverRelaxation,
......@@ -57,11 +57,6 @@ public:
LinearSolver() : m_linearSystem(nullptr){};
virtual ~LinearSolver(){};
///