diff --git a/Source/Solvers/imstkConjugateGradient.cpp b/Source/Solvers/imstkConjugateGradient.cpp index 39f4865b9b5665295c3b30f69c51da04416af6d7..c72c81412ffb5d2751ddec2b4da6bcef3b73d5ce 100644 --- a/Source/Solvers/imstkConjugateGradient.cpp +++ b/Source/Solvers/imstkConjugateGradient.cpp @@ -21,6 +21,8 @@ #include "imstkConjugateGradient.h" +#include <iostream> + namespace imstk { ConjugateGradient::ConjugateGradient() @@ -38,19 +40,14 @@ ConjugateGradient::ConjugateGradient(const SparseMatrixd& A, const Vectord& rhs) void ConjugateGradient::applyLinearProjectionFilter(Vectord& x, const std::vector<LinearProjectionConstraint>& linProj, const bool setVal) { - Vec3d p; for (auto &localProjector : linProj) { const auto threeI = 3 * localProjector.getNodeId(); - p = Vec3d(x(threeI), x(threeI + 1), x(threeI + 2)); - - if (!setVal) - { - p = localProjector.getProjector()*p; - } - else + Vec3d p = localProjector.getProjector()*Vec3d(x(threeI), x(threeI + 1), x(threeI + 2)); + + if (setVal) { - p = (Mat3d::Identity() - localProjector.getProjector())*localProjector.getValue(); + p += (Mat3d::Identity() - localProjector.getProjector())*localProjector.getValue(); } x(threeI) = p.x(); @@ -68,7 +65,7 @@ ConjugateGradient::solve(Vectord& x) return; } - if (m_FixedLinearProjConstraints->size() == 0) + if (!(m_FixedLinearProjConstraints || m_DynamicLinearProjConstraints)) { x = m_cgSolver.solve(m_linearSystem->getRHSVector()); } @@ -86,12 +83,25 @@ ConjugateGradient::modifiedCGSolve(Vectord& x) // Set the initial guess to zero x.setZero(); - applyLinearProjectionFilter(x, *m_DynamicLinearProjConstraints, true); - applyLinearProjectionFilter(x, *m_FixedLinearProjConstraints, true); + if (m_DynamicLinearProjConstraints) + { + applyLinearProjectionFilter(x, *m_DynamicLinearProjConstraints, true); + } + + if (m_FixedLinearProjConstraints) + { + applyLinearProjectionFilter(x, *m_FixedLinearProjConstraints, true); + } auto res = b; - applyLinearProjectionFilter(res, *m_DynamicLinearProjConstraints, false); - applyLinearProjectionFilter(res, *m_FixedLinearProjConstraints, false); + if (m_DynamicLinearProjConstraints) + { + applyLinearProjectionFilter(res, *m_DynamicLinearProjConstraints, false); + } + if (m_FixedLinearProjConstraints) + { + applyLinearProjectionFilter(res, *m_FixedLinearProjConstraints, false); + } auto c = res; auto delta = res.dot(res); auto deltaPrev = delta; @@ -104,8 +114,14 @@ ConjugateGradient::modifiedCGSolve(Vectord& x) while (delta > eps) { q = A * c; - applyLinearProjectionFilter(q, *m_DynamicLinearProjConstraints, false); - applyLinearProjectionFilter(q, *m_FixedLinearProjConstraints, false); + if (m_DynamicLinearProjConstraints) + { + applyLinearProjectionFilter(q, *m_DynamicLinearProjConstraints, false); + } + if (m_FixedLinearProjConstraints) + { + applyLinearProjectionFilter(q, *m_FixedLinearProjConstraints, false); + } dotval = c.dot(q); if (dotval != 0.0) { @@ -122,12 +138,18 @@ ConjugateGradient::modifiedCGSolve(Vectord& x) delta = res.dot(res); c *= delta / deltaPrev; c += res; - applyLinearProjectionFilter(c, *m_DynamicLinearProjConstraints, false); - applyLinearProjectionFilter(c, *m_FixedLinearProjConstraints, false); + if (m_DynamicLinearProjConstraints) + { + applyLinearProjectionFilter(c, *m_DynamicLinearProjConstraints, false); + } + if (m_FixedLinearProjConstraints) + { + applyLinearProjectionFilter(c, *m_FixedLinearProjConstraints, false); + } if (++iterNum >= m_maxIterations) { - LOG(WARNING) << "ConjugateGradient::modifiedCGSolve - The solver did not converge after max. iterations"; + //LOG(WARNING) << "ConjugateGradient::modifiedCGSolve - The solver did not converge after max. iterations"; break; } } diff --git a/Source/Solvers/imstkConjugateGradient.h b/Source/Solvers/imstkConjugateGradient.h index cf16aafdbbc6013546b1e4a9f5fa2dfbaed85afc..544150ddc3df64678e772df7d6d7ba6d1bfca286 100644 --- a/Source/Solvers/imstkConjugateGradient.h +++ b/Source/Solvers/imstkConjugateGradient.h @@ -98,6 +98,7 @@ public: /// void applyLinearProjectionFilter(Vectord& x, const std::vector<LinearProjectionConstraint>& linProj, const bool setVal); + /// /// \brief Get the vector denoting the filter /// void setLinearProjectors(std::vector<LinearProjectionConstraint>* f)