diff --git a/Base/Core/imstkMath.h b/Base/Core/imstkMath.h index 30f1610c8a1ec4e980da33a6f3a20447ef01b576..4be8dd825d013972e61d5262d4541a04a1c48284 100644 --- a/Base/Core/imstkMath.h +++ b/Base/Core/imstkMath.h @@ -61,6 +61,15 @@ using Mat3d = Eigen::Matrix3d; using Mat4f = Eigen::Matrix4f; using Mat4d = Eigen::Matrix4d; +/// A dynamic size matrix of floats +using Matrixf = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>; + +/// A dynamic size matrix of doubles +using Matrixd = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>; + +// A dynamic size sparse matrix of doubles +using SparseMatrixf = Eigen::SparseMatrix < float, Eigen::RowMajor >; + // A dynamic size sparse matrix of doubles using SparseMatrixd = Eigen::SparseMatrix < double, Eigen::RowMajor > ; diff --git a/Base/Solvers/imstkConjugateGradient.cpp b/Base/Solvers/imstkConjugateGradient.cpp index 84ff5ded1978fc8498902594955b6ae0c67e4d93..d47b0ebc0a234955e98c6edd085be7685004c8a4 100644 --- a/Base/Solvers/imstkConjugateGradient.cpp +++ b/Base/Solvers/imstkConjugateGradient.cpp @@ -33,14 +33,12 @@ ConjugateGradient::ConjugateGradient(const SparseMatrixd& A, const Vectord& rhs) cgSolver.compute(A); } - void ConjugateGradient::iterate(Vectord& , bool) { // Nothing to do } - void ConjugateGradient::solve(Vectord& x) { @@ -52,22 +50,12 @@ ConjugateGradient::solve(Vectord& x) x = cgSolver.solve(m_linearSystem->getRHSVector()); } - -void -ConjugateGradient::solve(Vectord& x, double tolerance) -{ - setTolerance(tolerance); - solve(x); -} - - double ConjugateGradient::getResidual(const Vectord& ) { return cgSolver.error(); } - void ConjugateGradient::setTolerance(const double epsilon) { @@ -75,7 +63,6 @@ ConjugateGradient::setTolerance(const double epsilon) cgSolver.setTolerance(epsilon); } - void ConjugateGradient::setMaxNumIterations(const size_t maxIter) { @@ -83,7 +70,6 @@ ConjugateGradient::setMaxNumIterations(const size_t maxIter) cgSolver.setMaxIterations(maxIter); } - void ConjugateGradient::setSystem(std::shared_ptr<LinearSystem<SparseMatrixd>> newSystem) { @@ -91,7 +77,6 @@ ConjugateGradient::setSystem(std::shared_ptr<LinearSystem<SparseMatrixd>> newSys this->cgSolver.compute(this->m_linearSystem->getMatrix()); } - void ConjugateGradient::print() { @@ -102,4 +87,11 @@ ConjugateGradient::print() LOG(INFO) << "max. iterations: " << m_maxIterations; } +void +ConjugateGradient::solve(Vectord& x, const double tolerance) +{ + setTolerance(tolerance); + solve(x); +} + } diff --git a/Base/Solvers/imstkConjugateGradient.h b/Base/Solvers/imstkConjugateGradient.h index e5f7c7506317ad87e238af847fcc9a9073dec824..b3678bfd95227b894ce9ef3f3b6d90c9aef34897 100644 --- a/Base/Solvers/imstkConjugateGradient.h +++ b/Base/Solvers/imstkConjugateGradient.h @@ -65,7 +65,7 @@ public: /// \brief Solve the linear system using Conjugate gradient iterations to a /// specified tolerance. /// - void solve(Vectord& x, double tolerance); + void solve(Vectord& x, const double tolerance); /// /// \brief Return the error calculated by the solver. diff --git a/Base/Solvers/imstkDirectLinearSolver.cpp b/Base/Solvers/imstkDirectLinearSolver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dc46366c017a85bc929b886ed6de31a4b7145f75 --- /dev/null +++ b/Base/Solvers/imstkDirectLinearSolver.cpp @@ -0,0 +1,95 @@ +// This file is part of the iMSTK project. +// +// Copyright (c) Kitware, Inc. +// +// Copyright (c) Center for Modeling, Simulation, and Imaging in Medicine, +// Rensselaer Polytechnic Institute +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "imstkDirectLinearSolver.h" +#include "imstkLinearSystem.h" +#include <iostream> + +namespace imstk +{ + +DirectLinearSolver<Matrixd>:: +DirectLinearSolver(const Matrixd &matrix, const Vectord &b) +{ + this->m_linearSystem = std::make_shared<LinearSystem<Matrixd>>(matrix, b); + this->solver.compute(matrix); +} + +void DirectLinearSolver<Matrixd>:: +setSystem(std::shared_ptr<LinearSystem<Matrixd>> newSystem) +{ + LinearSolver<Matrixd>::setSystem(newSystem); + this->solver.compute(this->m_linearSystem->getMatrix()); +} + +DirectLinearSolver<SparseMatrixd>:: +DirectLinearSolver(const SparseMatrixd &matrix, const Vectord &b) +{ + this->m_linearSystem = std::make_shared<LinearSystem<SparseMatrixd>>(matrix, b); + this->solver.compute(matrix); +} + +void DirectLinearSolver<SparseMatrixd>:: +setSystem(std::shared_ptr<LinearSystem<SparseMatrixd>> newSystem) +{ + LinearSolver<SparseMatrixd>::setSystem(newSystem); + this->solver.compute(this->m_linearSystem->getMatrix()); +} + +void +DirectLinearSolver<SparseMatrixd>::solve(const Vectord &rhs, Vectord &x) +{ + x = this->solver.solve(rhs); +} + +void +DirectLinearSolver<SparseMatrixd>::solve(Vectord &x) +{ + if (!this->m_linearSystem) + { + return; + } + + x.setZero(); + + auto b = this->m_linearSystem->getRHSVector(); + x = this->solver.solve(b); +} + +void +DirectLinearSolver<Matrixd>::solve(const Vectord &rhs, Vectord &x) +{ + x = this->solver.solve(rhs); +} + +void +DirectLinearSolver<Matrixd>::solve(Vectord &x) +{ + if (!this->m_linearSystem) + { + // TODO: Log this + return; + } + + x.setZero(); + + auto b = this->m_linearSystem->getRHSVector(); + x = this->solver.solve(b); +} +} diff --git a/Base/Solvers/imstkDirectLinearSolver.h b/Base/Solvers/imstkDirectLinearSolver.h new file mode 100644 index 0000000000000000000000000000000000000000..c4cf76dc49820dc29495171f98ca1c29347ee29a --- /dev/null +++ b/Base/Solvers/imstkDirectLinearSolver.h @@ -0,0 +1,113 @@ +// This file is part of the iMSTK project. +// +// Copyright (c) Kitware, Inc. +// +// Copyright (c) Center for Modeling, Simulation, and Imaging in Medicine, +// Rensselaer Polytechnic Institute +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef imstkDirectLinearSolver_h +#define imstkDirectLinearSolver_h + +#include <Eigen/Sparse> +#include<Eigen/SparseLU> + +// iMSTK includes +#include "imstkLinearSolver.h" +#include "imstkMath.h" + +namespace imstk +{ + +template<typename MatrixType> class DirectLinearSolver; + +/// +/// \brief Dense direct solvers. Solves a dense system of equations using Cholesky +/// decomposition. +/// +template<> +class DirectLinearSolver<Matrixd> : public LinearSolver<Matrixd> +{ +public: + /// + /// \brief Default constructor/destructor. + /// + DirectLinearSolver() = delete; + ~DirectLinearSolver() = default; + + /// + /// \brief Constructor + /// + DirectLinearSolver(const Matrixd &A, const Vectord &b); + + /// + /// \brief Solve the system of equations. + /// + void solve(Vectord &x) override; + + /// + /// \brief Solve the system of equations for arbitrary right hand side vector. + /// + void solve(const Vectord &rhs, Vectord &x); + + /// + /// \brief Sets the system. System of linear equations. + /// + void setSystem(std::shared_ptr<LinearSystemType> newSystem) override; + +private: + Eigen::LDLT<Matrixd> solver; +}; + +/// +/// \brief Sparse direct solvers. Solves a sparse system of equations using a sparse LU +/// decomposition. +/// +template<> +class DirectLinearSolver<SparseMatrixd> : public LinearSolver<SparseMatrixd> +{ +public: + /// + /// \brief Default constructor/destructor + /// + DirectLinearSolver() = default; + ~DirectLinearSolver() = default; + + /// + /// \brief Constructor + /// + DirectLinearSolver(const SparseMatrixd &matrix, const Vectord &b); + + /// + /// \brief Sets the system. System of linear equations. + /// + void setSystem(std::shared_ptr<LinearSystemType> newSystem) override; + + /// + /// \brief Solve the system of equations + /// + void solve(Vectord &x) override; + + /// + /// \brief Solve the system of equations for arbitrary right hand side vector. + /// + void solve(const Vectord &rhs, Vectord &x); + +private: + Eigen::SparseLU<SparseMatrixd, Eigen::COLAMDOrdering<MatrixType::Index>> solver;//? +}; + +} + +#endif // SOLVERS_DIRECT_LINEAR_SOLVER diff --git a/Base/Solvers/imstkLinearSolver.cpp b/Base/Solvers/imstkLinearSolver.cpp index 32e48bcfda56896edba95906ba74a3533d1cb1d3..c3d4353119a36b9ef954240f4ec7b0dbd1a90d6d 100644 --- a/Base/Solvers/imstkLinearSolver.cpp +++ b/Base/Solvers/imstkLinearSolver.cpp @@ -32,15 +32,15 @@ template<typename SystemMatrixType> void LinearSolver<SystemMatrixType>::setSystem(std::shared_ptr<LinearSystem<SystemMatrixType>> newSystem) { - this->m_linearSystem.reset(); - this->m_linearSystem = newSystem; + m_linearSystem.reset(); + m_linearSystem = newSystem; } template<typename SystemMatrixType> std::shared_ptr<LinearSystem<SystemMatrixType>> LinearSolver<SystemMatrixType>::getSystem() const { - return this->m_linearSystem; + return m_linearSystem; }