Commit 757fe326 authored by Nghia Truong's avatar Nghia Truong
Browse files

ENH: Implement OctreeBasedCD for collision detection using octree

parent 7cf4b857
/*=========================================================================
Library: iMSTK
Copyright (c) Kitware, Inc. & Center for Modeling, Simulation,
& Imaging in Medicine, Rensselaer Polytechnic Institute.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0.txt
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=========================================================================*/
#include "imstkOctreeBasedCD.h"
#include "imstkGeometry.h"
#include "imstkPointSet.h"
#include "imstkSurfaceMesh.h"
// Collision detection headers
#include "imstkNarrowPhaseCD.h"
#include "imstkCollisionData.h"
#include "imstkCollisionUtils.h"
namespace imstk
{
void
OctreeBasedCD::clear()
{
LooseOctree::clear();
m_mCollisionPair2AssociatedData.clear();
m_vCollidingGeomPairs.clear();
m_sCollidingPrimitiveTypes = 0u;
}
void
OctreeBasedCD::addCollisionPair(const std::shared_ptr<Geometry>& geom1, const std::shared_ptr<Geometry>& geom2,
const CollisionDetection::Type collisionType)
{
// Collision pairs are encoded as 64 bit unsigned integer
// The first 32 bit is obj1Idx, following by 32 bit of obj2Idx
const auto objIdx1 = geom1->getGlobalIndex();
const auto objIdx2 = geom2->getGlobalIndex();
const auto collisionPair = computeCollisionPairHash(objIdx1, objIdx2);
LOG_IF(FATAL, (m_mCollisionPair2AssociatedData.find(collisionPair) != m_mCollisionPair2AssociatedData.end()))
<< "Collision pair has previously been added";
m_mCollisionPair2AssociatedData[collisionPair] = { collisionType, std::make_shared<CollisionData>() };
m_vCollidingGeomPairs.push_back({ geom1.get(), geom2.get() });
const auto geomType1 = geom1->getType();
const auto geomType2 = geom2->getType();
if (geomType1 == Geometry::Type::PointSet
|| geomType2 == Geometry::Type::PointSet)
{
const uint32_t mask = 1 << static_cast<int>(OctreePrimitiveType::Point);
m_sCollidingPrimitiveTypes |= mask;
}
if (geomType1 == Geometry::Type::SurfaceMesh
|| geomType2 == Geometry::Type::SurfaceMesh)
{
const uint32_t mask = 1 << static_cast<int>(OctreePrimitiveType::Triangle);
m_sCollidingPrimitiveTypes |= mask;
}
if (geomType1 != Geometry::Type::PointSet
|| geomType2 != Geometry::Type::PointSet
|| geomType1 != Geometry::Type::SurfaceMesh
|| geomType2 != Geometry::Type::SurfaceMesh)
{
const uint32_t mask = 1 << static_cast<int>(OctreePrimitiveType::AnalyticalGeometry);
m_sCollidingPrimitiveTypes |= mask;
}
LOG(INFO) << m_Name << ":: Add collision pair between objects '"
<< geom1->getName() << "' (ID = " << objIdx1 << ") and '"
<< geom2->getName() << "' (ID = " << objIdx2 << ")";
}
const std::shared_ptr<CollisionData>&
OctreeBasedCD::getCollisionPairData(const uint32_t geomIdx1, const uint32_t geomIdx2)
{
const auto collisionPair = computeCollisionPairHash(geomIdx1, geomIdx2);
const auto it = m_mCollisionPair2AssociatedData.find(collisionPair);
LOG_IF(FATAL, (it == m_mCollisionPair2AssociatedData.end())) << "Collision pair does not exist";
return it->second.m_CollisionData;
}
void
OctreeBasedCD::detectCollision()
{
for (auto& kv : m_mCollisionPair2AssociatedData)
{
// Clear all collision data
kv.second.m_CollisionData->clearAll();
}
// Clear invalid flags for point-mesh collision pairs
m_mInvalidPointMeshCollisions.clear();
for (int type = 0; type < OctreePrimitiveType::NumPrimitiveTypes; ++type)
{
const auto& vPrimitivePtrs = m_vPrimitivePtrs[type];
if (vPrimitivePtrs.size() > 0 && hasCollidingPrimitive(type))
{
ParallelUtils::parallelFor(vPrimitivePtrs.size(),
[&](const size_t idx)
{
const auto pPrimitive = vPrimitivePtrs[idx];
if (type == OctreePrimitiveType::Point)
{
checkPointWithSubtree(m_pRootNode, pPrimitive, pPrimitive->m_GeomIdx);
}
else
{
const auto& lowerCorner = pPrimitive->m_LowerCorner;
const auto& upperCorner = pPrimitive->m_UpperCorner;
const Vec3r center(
(lowerCorner[0] + upperCorner[0]) * 0.5,
(lowerCorner[1] + upperCorner[1]) * 0.5,
(lowerCorner[2] + upperCorner[2]) * 0.5);
checkNonPointWithSubtree(m_pRootNode, pPrimitive, pPrimitive->m_GeomIdx,
lowerCorner, upperCorner, static_cast<OctreePrimitiveType>(type));
}
});
}
}
// Remove all invalid collision between point-mesh
for (auto& geoPair: m_vCollidingGeomPairs)
{
if (geoPair.first->getType() != Geometry::Type::PointSet)
{
continue;
}
auto& collisionData = getCollisionPairData(geoPair.first->getGlobalIndex(), geoPair.second->getGlobalIndex());
if (collisionData->VTColData.getSize() == 0)
{
continue;
}
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
LOG_IF(FATAL, (geoPair.second->getType() != Geometry::Type::SurfaceMesh))
<< "Incorrectly detected invalid collision between point and geometry that is not a surface mesh";
#endif
const auto geomIdxPointSet = geoPair.first->getGlobalIndex();
const auto geomIdxMesh = geoPair.second->getGlobalIndex();
size_t writeIdx = 0;
for (size_t readIdx = 0; readIdx < collisionData->VTColData.getSize(); ++readIdx)
{
const auto& vt = collisionData->VTColData[readIdx];
if (pointStillColliding(vt.vertexIdx, geomIdxPointSet, geomIdxMesh))
{
if (readIdx != writeIdx)
{
collisionData->VTColData.setElement(writeIdx, collisionData->VTColData[readIdx]);
}
++writeIdx;
}
}
collisionData->VTColData.resize(writeIdx);
}
}
void
OctreeBasedCD::checkPointWithSubtree(OctreeNode* const pNode, OctreePrimitive* const pPrimitive, const uint32_t geomIdx)
{
if (!pNode->looselyContains(pPrimitive->m_Position))
{
return;
}
if (!pNode->isLeaf())
{
for (uint32_t childIdx = 0; childIdx < 8u; childIdx++)
{
OctreeNode* const pChildNode = &pNode->m_pChildren->m_Nodes[childIdx];
checkPointWithSubtree(pChildNode, pPrimitive, geomIdx);
}
}
for (int type = 0; type < OctreePrimitiveType::NumPrimitiveTypes; ++type)
{
// Points do not collide with points
if (type == OctreePrimitiveType::Point)
{
continue;
}
auto pIter = pNode->m_pPrimitiveListHeads[type];
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
uint32_t count = 0;
#endif
while (pIter)
{
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
++count;
#endif
if (pPrimitive != pIter)
{
const auto geomIdxOther = pIter->m_GeomIdx;
if (pointStillColliding(pPrimitive->m_Idx, geomIdx, geomIdxOther))
{
const auto collisionPair = computeCollisionPairHash(geomIdx, geomIdxOther);
const auto& collisionAssociatedData = getCollisionPairAssociatedData(collisionPair);
if (collisionAssociatedData.m_CollisionData != nullptr)
{
checkPointWithPrimitive(pPrimitive, pIter, collisionAssociatedData);
}
}
}
pIter = pIter->m_pNext;
}
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
LOG_IF(FATAL, (count != pNode->m_PrimitiveCounts[type])) << "Internal data corrupted";
#endif
}
}
void
OctreeBasedCD::checkNonPointWithSubtree(OctreeNode* const pNode, OctreePrimitive* const pPrimitive,
const uint32_t geomIdx, const std::array<Real, 3>& lowerCorner, const std::array<Real, 3>& upperCorner,
const OctreePrimitiveType type)
{
if (!pNode->looselyOverlaps(lowerCorner, upperCorner))
{
return;
}
if (!pNode->isLeaf())
{
for (uint32_t childIdx = 0; childIdx < 8u; childIdx++)
{
OctreeNode* const pChildNode = &pNode->m_pChildren->m_Nodes[childIdx];
checkNonPointWithSubtree(pChildNode, pPrimitive, geomIdx, lowerCorner, upperCorner, type);
}
}
for (int i = 0; i < OctreePrimitiveType::NumPrimitiveTypes; ++i)
{
auto pIter = pNode->m_pPrimitiveListHeads[i];
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
uint32_t count = 0;
#endif
while (pIter)
{
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
++count;
#endif
if (pPrimitive != pIter)
{
// todo: this is necessary but not help
const auto geomIdxIter = pIter->m_GeomIdx;
const auto collisionPair = computeCollisionPairHash(geomIdx, geomIdxIter);
const auto collisionAssociatedData = getCollisionPairAssociatedData(collisionPair);
if (collisionAssociatedData.m_CollisionData != nullptr) // Has collision pair
{
const auto& lowerCornerIter = pIter->m_LowerCorner;
const auto& upperCornerIter = pIter->m_UpperCorner;
const Vec3r centerIter(
(lowerCorner[0] + upperCorner[0]) * 0.5,
(lowerCorner[1] + upperCorner[1]) * 0.5,
(lowerCorner[2] + upperCorner[2]) * 0.5);
if (CollisionUtils::testAABBToAABB(lowerCorner[0], upperCorner[0],
lowerCorner[1], upperCorner[1],
lowerCorner[2], upperCorner[2],
lowerCornerIter[0], upperCornerIter[0],
lowerCornerIter[1], upperCornerIter[1],
lowerCornerIter[2], upperCornerIter[2]))
{
checkNonPointWithPrimitive(pPrimitive, pIter, collisionAssociatedData);
//
//
// TODO: May not check here, but collect into an array
//
//
//
// });
}
}
}
pIter = pIter->m_pNext;
}
#if defined(DEBUG) || defined(_DEBUG) || !defined(NDEBUG)
LOG_IF(FATAL, (count != pNode->m_PrimitiveCounts[i])) << "Internal data corrupted";
#endif
}
}
void
OctreeBasedCD::checkPointWithPrimitive(OctreePrimitive* const pPrimitive1, OctreePrimitive* const pPrimitive2,
const CollisionPairAssociatedData& collisionAssociatedData)
{
const auto collisionType = collisionAssociatedData.m_Type;
const auto& collisionData = collisionAssociatedData.m_CollisionData;
const auto point = Vec3r(pPrimitive1->m_Position[0], pPrimitive1->m_Position[1], pPrimitive1->m_Position[2]);
const auto pointIdx = pPrimitive1->m_Idx;
switch (collisionType)
{
case CollisionDetection::Type::PointSetToSurfaceMesh:
if (!NarrowPhaseCD::pointToTriangle(point, pointIdx, pPrimitive2->m_Idx, pPrimitive2->m_pGeometry, collisionData))
{
setPointMeshCollisionInvalid(pointIdx, pPrimitive1->m_GeomIdx, pPrimitive2->m_GeomIdx);
}
break;
case CollisionDetection::Type::PointSetToSphere:
NarrowPhaseCD::pointToSphere(point, pointIdx, pPrimitive2->m_pGeometry, collisionData);
break;
case CollisionDetection::Type::PointSetToPlane:
NarrowPhaseCD::pointToPlane(point, pointIdx, pPrimitive2->m_pGeometry, collisionData);
break;
case CollisionDetection::Type::PointSetToCapsule:
NarrowPhaseCD::pointToCapsule(point, pointIdx, pPrimitive2->m_pGeometry, collisionData);
break;
case CollisionDetection::Type::PointSetToSpherePicking:
NarrowPhaseCD::pointToSpherePicking(point, pointIdx, pPrimitive2->m_pGeometry, collisionData);
break;
default:
LOG(FATAL) << "Unsupported collision type";
}
}
void
OctreeBasedCD::checkNonPointWithPrimitive(OctreePrimitive* const pPrimitive1, OctreePrimitive* const pPrimitive2,
const CollisionPairAssociatedData& collisionAssociatedData)
{
const auto collisionType = collisionAssociatedData.m_Type;
const auto& collisionData = collisionAssociatedData.m_CollisionData;
switch (collisionType)
{
// Mesh to mesh
case CollisionDetection::Type::SurfaceMeshToSurfaceMesh:
NarrowPhaseCD::triangleToTriangle(pPrimitive1->m_Idx, pPrimitive1->m_pGeometry,
pPrimitive2->m_Idx, pPrimitive2->m_pGeometry,
collisionData);
break;
// Analytical object to analytical object
case CollisionDetection::Type::UnidirectionalPlaneToSphere:
NarrowPhaseCD::unidirectionalPlaneToSphere(pPrimitive1->m_pGeometry, pPrimitive2->m_pGeometry, collisionData);
break;
case CollisionDetection::Type::BidirectionalPlaneToSphere:
NarrowPhaseCD::bidirectionalPlaneToSphere(pPrimitive1->m_pGeometry, pPrimitive2->m_pGeometry, collisionData);
break;
case CollisionDetection::Type::SphereToCylinder:
NarrowPhaseCD::sphereToCylinder(pPrimitive1->m_pGeometry, pPrimitive2->m_pGeometry, collisionData);
break;
case CollisionDetection::Type::SphereToSphere:
NarrowPhaseCD::sphereToSphere(pPrimitive1->m_pGeometry, pPrimitive2->m_pGeometry, collisionData);
break;
default:
LOG(FATAL) << "Unsupported collision type";
}
}
uint64_t
OctreeBasedCD::computeCollisionPairHash(const uint32_t objIdx1, const uint32_t objIdx2)
{
const uint64_t uint64Idx1 = static_cast<uint64_t>(objIdx1);
const uint64_t uint64Idx2 = static_cast<uint64_t>(objIdx2);
return (uint64Idx1 << 32) | uint64Idx2;
}
const OctreeBasedCD::CollisionPairAssociatedData&
OctreeBasedCD::getCollisionPairAssociatedData(const uint64_t collisionPair) const
{
static const auto invalidData = CollisionPairAssociatedData { CollisionDetection::Type::Custom, nullptr };
const auto it = m_mCollisionPair2AssociatedData.find(collisionPair);
return (it != m_mCollisionPair2AssociatedData.end()) ? it->second : invalidData;
}
bool
OctreeBasedCD::pointStillColliding(const uint32_t primitiveIdx, const uint32_t geometryIdx,
const uint32_t otherGeometryIdx)
{
const uint64_t uint64PrimitiveIdx = static_cast<uint64_t>(primitiveIdx);
const uint64_t uint64GeometryIdx = static_cast<uint64_t>(geometryIdx);
const uint64_t source = (uint64PrimitiveIdx << 32) | uint64GeometryIdx;
auto it = m_mInvalidPointMeshCollisions.find(source);
if (it == m_mInvalidPointMeshCollisions.end())
{
return true;
}
const auto& invalidTargets = it->second;
return invalidTargets.find(otherGeometryIdx) == invalidTargets.end();
}
void
OctreeBasedCD::setPointMeshCollisionInvalid(const uint32_t primitiveIdx, const uint32_t geometryIdx,
const uint32_t otherGeometryIdx)
{
const uint64_t uint64PrimitiveIdx = static_cast<uint64_t>(primitiveIdx);
const uint64_t uint64GeometryIdx = static_cast<uint64_t>(geometryIdx);
const uint64_t source = (uint64PrimitiveIdx << 32) | uint64GeometryIdx;
m_mInvalidPointMeshCollisions[source].insert(otherGeometryIdx);
}
} // end namespace imstk
/*=========================================================================
Library: iMSTK
Copyright (c) Kitware, Inc. & Center for Modeling, Simulation,
& Imaging in Medicine, Rensselaer Polytechnic Institute.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0.txt
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=========================================================================*/
#pragma once
#include "imstkLooseOctree.h"
#include "imstkCollisionDetection.h"
namespace imstk
{
struct CollisionData;
///
/// \brief Class OctreeBasedCD, a subclass of LooseOctree which supports collision detection between octree primitives
///
class OctreeBasedCD : public LooseOctree
{
friend class OctreeBasedCDTest;
public:
OctreeBasedCD(const OctreeBasedCD&) = delete;
OctreeBasedCD& operator=(const OctreeBasedCD&) = delete;
///
/// \brief OctreeBasedCD
/// \param center The center of the tree, which also is the center of the root node
/// \param width Width of the octree bounding box
/// \param minWidth Minimum allowed width of the tree nodes, valid only if there are only points primitives
/// \param minWidthRatio If there is primitive that is not a point, minWidth will be recomputed as minWidth = min(width of all non-point primitives) * minWidthRatio
/// \param name Name of the octree
///
explicit OctreeBasedCD(const Vec3r& center, const Real width, const Real minWidth,
const Real minWidthRatio = 1.0, const std::string name = "OctreeBasedCD") :
LooseOctree(center, width, minWidth, minWidthRatio, name) {}
///
/// \brief Clear all primitive and geometry and collision data, but still keep nodes data in memory pool
///
virtual void clear() override;
///
/// \brief Define a collision pair between two geometry objects
/// The collisionType parameter must be valid (no check), otherwise will result in undefined behaviors
///
///
void addCollisionPair(const std::shared_ptr<Geometry>& geom1, const std::shared_ptr<Geometry>& geom2,
const CollisionDetection::Type collisionType);
///
/// \brief Get pairs of geometries from the added collision pairs
/// \return List of pairs of geometries, each pair corresponds to an added collision pair
///
const std::vector<std::pair<Geometry*, Geometry*>> getCollidingGeometryPairs() const { return m_vCollidingGeomPairs; }
///
/// \brief Check for collision between pritimives in the tree, based on the provided collision pairs
///
void detectCollision();
///
/// \brief Get CollisionData for a collision pair between two geometries
/// (That collision pair must be added before accessing collision data by this function)
/// For performance reason, to avoid casting between geomery pointers, the function only accepts global indices of geometries
/// \param geomeIdx1 Global index of the first geometry
/// \param geomeIdx2 Global index of the second geometry
///
const std::shared_ptr<CollisionData>& getCollisionPairData(const uint32_t geomIdx1, const uint32_t geomIdx2);
private:
///
/// \brief The CollisionPairAssociatedData struct
/// For each collision pair, map it with a collision type and collision data
///
struct CollisionPairAssociatedData
{
CollisionDetection::Type m_Type;
std::shared_ptr<CollisionData> m_CollisionData;
};
///
/// \brief Check for collisions of the given point primitive with primitives in the subtree rooting at the given tree node
/// The collision checks are not brute-force, since the (loose) bounding boxes of the tree nodes are use to prune unnecessary checks
///
void checkPointWithSubtree(OctreeNode* const pNode, OctreePrimitive* const pPrimitive, const uint32_t geomIdx);
///
/// \brief Check for collisions of the given non-point primitive with primitives in the subtree rooting at the given tree node
/// The collision checks are not brute-force, since the (loose) bounding boxes of the tree nodes are use to prune unnecessary checks
///
void checkNonPointWithSubtree(OctreeNode* const pNode, OctreePrimitive* const pPrimitive, const uint32_t geomIdx,
const std::array<Real, 3>& lowerCorner, const std::array<Real, 3>& upperCorner,
const OctreePrimitiveType type);
///
/// \brief Check for narrow-phase collision between a point primitive with another primitive
///
void checkPointWithPrimitive(OctreePrimitive* const pPrimitive1, OctreePrimitive* const pPrimitive2,
const CollisionPairAssociatedData& collisionAssociatedData);
///
/// \brief Check for narrow-phase collision between a non-point primitive with another primitive
///
void checkNonPointWithPrimitive(OctreePrimitive* const pPrimitive1, OctreePrimitive* const pPrimitive2,
const CollisionPairAssociatedData& collisionAssociatedData);
///
/// \brief Compute the hash value for a collision pair between two geometry objects
/// The hash value is computed as concatenation of the two objects' global indices
/// \param objIdx1 Global index of the first geometry
/// \param objIdx2 Global index of the second geometry
///
uint64_t computeCollisionPairHash(const uint32_t objIdx1, const uint32_t objIdx2);
///
/// \brief Get associated data for a given collision pair
/// \param collisionPair The hash value of the given collision pair
///
const CollisionPairAssociatedData& getCollisionPairAssociatedData(const uint64_t collisionPair) const;
///
/// \brief Return true if any of the added collision pairs contains primitives of the given type
/// This is used to avoid unnecessary collision check
/// For example, if there was pointset(s) added but there is only collision pair between triangle meshes
/// then we can totally ignore all point primitives during collision detection
///
bool hasCollidingPrimitive(const int type) const { return (m_sCollidingPrimitiveTypes & (1 << type)) != 0; }
///
/// \brief Return true if the collision between a point and a triangle mesh is still valid
/// This is applied specifically for point primitive
/// When a point P is detected as above surface of a triangle ABC ( dot(P-A, n) > 0, where n is triangle normal pointing outward of the mesh)
/// then that point is obviously outside of the mesh containiing the triangle ABC
/// Thus, collisions between point P and all triangles of that mesh should be mark as invalid and discarded
/// \param primitiveIdx Index of the given point in the pointset
/// \param geometryIdx Global index of the parent pointset of the given point
/// \param otherGeometryIdx Global index of the surface (triangle) mesh that the given point is colliding
///
bool pointStillColliding(const uint32_t primitiveIdx, const uint32_t geometryIdx, const uint32_t otherGeometryIdx);
///
/// \brief Mark all the collisions between a point and triangles of a surface mesh is invalid
/// \param primitiveIdx Index of the given point in the pointset
/// \param geometryIdx Global index of the parent pointset of the given point
/// \param otherGeometryIdx Global index of the surface (triangle) mesh that the given point is colliding
///
void setPointMeshCollisionInvalid(const uint32_t primitiveIdx, const uint32_t geometryIdx, const uint32_t otherGeometryIdx);
/// For each collision pair, related primitives need to be marked as colliding
/// (for example, for pointset-surface mesh collision pair, 'point' and 'triangle' are now 'colliding primitives')
/// This variable is used to avoid unnecessary collision check