Skip to content
Snippets Groups Projects
Commit b961cbcd authored by Harald Scheirich's avatar Harald Scheirich
Browse files

ENH: Common Factory infrastructure

replaces the separate implementation for the factory
parent a516aeb0
No related branches found
No related tags found
No related merge requests found
......@@ -40,38 +40,34 @@ limitations under the License.
namespace imstk
{
// Manually register creation functions
std::unordered_map<std::string, CDObjectFactory::CDMakeFunc> CDObjectFactory::cdObjCreationMap =
std::unordered_map<std::string, CDObjectFactory::CDMakeFunc>();
REGISTER_COLLISION_DETECTION(BidirectionalPlaneToSphereCD);
REGISTER_COLLISION_DETECTION(ImplicitGeometryToPointSetCD);
REGISTER_COLLISION_DETECTION(ImplicitGeometryToPointSetCCD);
REGISTER_COLLISION_DETECTION(MeshToMeshBruteForceCD);
REGISTER_COLLISION_DETECTION(PointSetToCapsuleCD);
REGISTER_COLLISION_DETECTION(PointSetToPlaneCD);
REGISTER_COLLISION_DETECTION(PointSetToSphereCD);
REGISTER_COLLISION_DETECTION(PointSetToOrientedBoxCD);
REGISTER_COLLISION_DETECTION(SphereToCylinderCD);
REGISTER_COLLISION_DETECTION(SphereToSphereCD);
REGISTER_COLLISION_DETECTION(SurfaceMeshToSurfaceMeshCD);
REGISTER_COLLISION_DETECTION(SurfaceMeshToCapsuleCD);
REGISTER_COLLISION_DETECTION(SurfaceMeshToSphereCD);
REGISTER_COLLISION_DETECTION(TetraToPointSetCD);
REGISTER_COLLISION_DETECTION(TetraToLineMeshCD);
REGISTER_COLLISION_DETECTION(UnidirectionalPlaneToSphereCD);
IMSTK_REGISTER_COLLISION_DETECTION(BidirectionalPlaneToSphereCD);
IMSTK_REGISTER_COLLISION_DETECTION(ImplicitGeometryToPointSetCD);
IMSTK_REGISTER_COLLISION_DETECTION(ImplicitGeometryToPointSetCCD);
IMSTK_REGISTER_COLLISION_DETECTION(MeshToMeshBruteForceCD);
IMSTK_REGISTER_COLLISION_DETECTION(PointSetToCapsuleCD);
IMSTK_REGISTER_COLLISION_DETECTION(PointSetToPlaneCD);
IMSTK_REGISTER_COLLISION_DETECTION(PointSetToSphereCD);
IMSTK_REGISTER_COLLISION_DETECTION(PointSetToOrientedBoxCD);
IMSTK_REGISTER_COLLISION_DETECTION(SphereToCylinderCD);
IMSTK_REGISTER_COLLISION_DETECTION(SphereToSphereCD);
IMSTK_REGISTER_COLLISION_DETECTION(SurfaceMeshToSurfaceMeshCD);
IMSTK_REGISTER_COLLISION_DETECTION(SurfaceMeshToCapsuleCD);
IMSTK_REGISTER_COLLISION_DETECTION(SurfaceMeshToSphereCD);
IMSTK_REGISTER_COLLISION_DETECTION(TetraToPointSetCD);
IMSTK_REGISTER_COLLISION_DETECTION(TetraToLineMeshCD);
IMSTK_REGISTER_COLLISION_DETECTION(UnidirectionalPlaneToSphereCD);
std::shared_ptr<CollisionDetectionAlgorithm>
CDObjectFactory::makeCollisionDetection(const std::string collisionTypeName)
{
if (cdObjCreationMap.count(collisionTypeName) == 0)
if (!contains(collisionTypeName))
{
LOG(FATAL) << "No collision detection type named: " << collisionTypeName;
return nullptr;
}
else
{
return cdObjCreationMap.at(collisionTypeName)();
return create(collisionTypeName);
}
}
} // namespace imstk
\ No newline at end of file
......@@ -21,6 +21,8 @@ limitations under the License.
#pragma once
#include "imstkFactory.h"
#include <functional>
#include <memory>
#include <string>
......@@ -30,51 +32,15 @@ namespace imstk
{
class CollisionDetectionAlgorithm;
///
/// \class CDObjectFactory
///
/// \brief This is the factory class for CollisionDetectionAlgorithm. It may be
/// used to construct CollisionDetectionAlgorithm objects by name.
/// Note: Does not auto register CollisionDetectionAlgorithm's. If one creates
/// their own CollisionDetectionAlgorithm they must register themselves.
///
class CDObjectFactory
class CDObjectFactory : public ObjectFactory<std::shared_ptr<CollisionDetectionAlgorithm>>
{
public:
using CDMakeFunc = std::function<std::shared_ptr<CollisionDetectionAlgorithm>()>;
///
/// \brief Register the CollisionDetectionAlgorithm creation function given name
///
static void registerCD(std::string name, CDMakeFunc func)
{
cdObjCreationMap[name] = func;
}
///
/// \brief Creates a CollisionDetectionAlgorithm object by name if registered to factory
///
static std::shared_ptr<CollisionDetectionAlgorithm> makeCollisionDetection(const std::string collisionTypeName);
private:
static std::unordered_map<std::string, CDMakeFunc> cdObjCreationMap;
static std::shared_ptr<CollisionDetectionAlgorithm>
makeCollisionDetection(const std::string collisionTypeName);
};
///
/// \class CDObjectRegistrar
///
/// \brief Construction of this object will register to the CDObjectFactory. One could
/// construct this at the bottom of their CollisionDetectionAlgorithm when building
/// dynamic libraries or executables for static initialization.
///
template<typename T>
class CDObjectRegistrar
{
public:
CDObjectRegistrar(std::string name)
{
CDObjectFactory::registerCD(name, []() { return std::make_shared<T>(); });
}
};
#define REGISTER_COLLISION_DETECTION(cdType) CDObjectRegistrar<cdType> __register ## cdType(#cdType)
} // namespace imstk
\ No newline at end of file
using CDObjectRegistrar = SharedObjectRegistrar<CollisionDetectionAlgorithm, T>;
#define IMSTK_REGISTER_COLLISION_DETECTION(objType) CDObjectRegistrar<objType> _imstk_registercd ## objType(#objType)
} // namespace imstk
/*=========================================================================
Library: iMSTK
Copyright (c) Kitware, Inc. & Center for Modeling, Simulation,
& Imaging in Medicine, Rensselaer Polytechnic Institute.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0.txt
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=========================================================================*/
#include "gtest/gtest.h"
#include "imstkFactory.h"
using namespace imstk;
namespace
{
struct A
{
virtual int val() { return 1; };
virtual ~A() {};
};
struct B : public A
{
virtual ~B() {};
int val() override { return 2; };
};
struct C : public A
{
virtual ~C() {};
int val() override { return 3; };
};
struct AA
{
AA(int i) : m_i(i) {}
virtual int val() { return 1; };
virtual int vali() { return m_i; };
virtual ~AA() {};
protected:
int m_i;
};
struct BB : public AA
{
BB(int i) : AA(i) {}
virtual ~BB() {};
int val() override { return 2; };
};
struct CC : public AA
{
CC(int i) : AA(i) {}
virtual ~CC() {};
int val() override { return 3; };
};
} // namespace
TEST(FactoryTest, Instantiation)
{
ObjectFactory<A>::add("a", [] () { return A(); });
auto a = ObjectFactory<A>::create("a");
}
TEST(FactoryTest, DerivedClasses)
{
using TestFactory = ObjectFactory<std::shared_ptr<A>>;
TestFactory::add("b", []() { return std::make_shared<B>(); });
TestFactory::add("c", []() { return std::make_shared<C>(); });
auto b = TestFactory::create("b");
ASSERT_NE(nullptr, b);
EXPECT_EQ(2, b->val());
auto c = TestFactory::create("c");
EXPECT_EQ(3, c->val());
}
TEST(FactoryTest, RegistrarTest)
{
using TestFactory = ObjectFactory<std::shared_ptr<A>>;
{
auto a = SharedObjectRegistrar<A, B>("b");
auto b = SharedObjectRegistrar<A, C>("c");
}
auto b = TestFactory::create("b");
ASSERT_NE(nullptr, b);
EXPECT_EQ(2, b->val());
auto c = TestFactory::create("c");
EXPECT_EQ(3, c->val());
}
TEST(FactoryTest, RegistrarTestWParams)
{
using TestFactory = ObjectFactory<std::shared_ptr<AA>, int>;
{
auto a = SharedObjectRegistrar<AA, BB, int>("b");
auto b = SharedObjectRegistrar<AA, CC, int>("c");
}
auto b = TestFactory::create("b", 10);
ASSERT_NE(nullptr, b);
EXPECT_EQ(2, b->val());
EXPECT_EQ(10, b->vali());
auto c = TestFactory::create("c", 20);
EXPECT_EQ(3, c->val());
EXPECT_EQ(20, c->vali());
}
/*=========================================================================
Library: iMSTK
Copyright (c) Kitware, Inc. & Center for Modeling, Simulation,
& Imaging in Medicine, Rensselaer Polytechnic Institute.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0.txt
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=========================================================================*/
#include <unordered_map>
#include <functional>
#include <string>
#include <memory>
namespace imstk
{
/// \brief Generic Factory class that can take objects with constructor parameters
///
/// All the objects in the factory need to be convertible to a common base class
///
/// \tparam T The base class type, is also the return type of create()
/// \tparam Args Parameter pack argument for the list of params used to pass into create
template<typename T, class ... Args>
class ObjectFactory
{
public:
using BaseType = T;
using Creator = std::function<T(Args...)>; ///< Type of the function to generate a new object
/// \brief tries to construct the object give name, it will forward the given paramters
/// \param name Name to use for lookup
/// \param args parameters to pass into the creation function
/// \throws if name was not registered
static T create(const std::string& name, Args&& ... args)
{
// at will throw if name doesn't exist in the map
return registry().at(name)(std::forward<Args>(args)...);
}
/// \brief adds a new creation function to the factory
/// \param name Name to use, will overwrite an already defined name
/// \param c function to be called when create() is called with the given name
static void add(const std::string& name, typename Creator c)
{
registry()[name] = std::move(c);
}
/// \return true if the name can be found in the registry
static bool contains(const std::string& name)
{
return registry().find(name) != registry().cend();
}
private:
using Registry = std::unordered_map<std::string, Creator>;
/// \brief static registry, let's us use static functions with a static data member
static Registry& registry()
{
static Registry registry = {};
return registry;
}
};
/// \brief Templated class that can add to the object factory with objects that will
/// be generated via `std::make_shared`
///
/// \tparam T The base class type (see ObjectFactory)
/// \tparam U The class that should be generated here (needs to be a subclass of T)
/// \tparam Args constructor parameter types, these can then be pass in ObjectFactory::create
template<typename T, typename U, typename ... Args>
class SharedObjectRegistrar
{
public:
/// \brief The constructor can automatically register the given class in the Factory
/// For example it can be used in global scope in an implementation file
SharedObjectRegistrar(std::string name)
{
static_assert(std::is_base_of<T, U>::value,
"U must be a subclass of T");
ObjectFactory<std::shared_ptr<T>, Args ...>::add(name, [](Args&& ... args) { return std::make_shared<U>(std::forward<Args>(args)...); });
}
};
} // namespace imstk
......@@ -50,22 +50,21 @@ limitations under the License.
namespace imstk
{
std::unordered_map<std::string, RenderDelegateObjectFactory::DelegateMakeFunc> RenderDelegateObjectFactory::m_objCreationMap =
{
{ SurfaceMesh::getStaticTypeName(), makeFunc<VTKSurfaceMeshRenderDelegate>() },
{ TetrahedralMesh::getStaticTypeName(), makeFunc<VTKTetrahedralMeshRenderDelegate>() },
{ LineMesh::getStaticTypeName(), makeFunc<VTKLineMeshRenderDelegate>() },
{ HexahedralMesh::getStaticTypeName(), makeFunc<VTKHexahedralMeshRenderDelegate>() },
{ PointSet::getStaticTypeName(), makeFunc<VTKPointSetRenderDelegate>() },
{ Plane::getStaticTypeName(), makeFunc<VTKPlaneRenderDelegate>() },
{ Sphere::getStaticTypeName(), makeFunc<VTKSphereRenderDelegate>() },
{ Capsule::getStaticTypeName(), makeFunc<VTKCapsuleRenderDelegate>() },
{ OrientedBox::getStaticTypeName(), makeFunc<VTKOrientedCubeRenderDelegate>() },
{ Cylinder::getStaticTypeName(), makeFunc<VTKCylinderRenderDelegate>() },
{ ImageData::getStaticTypeName(), makeFunc<VTKImageDataRenderDelegate>() },
{ "Fluid", makeFunc<VTKFluidRenderDelegate>() },
{ "SurfaceNormals", makeFunc<VTKSurfaceNormalRenderDelegate>() }
};
IMSTK_REGISTER_RENDERDELEGATE(SurfaceMesh, VTKSurfaceMeshRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(Cylinder, VTKCylinderRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(TetrahedralMesh, VTKTetrahedralMeshRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(LineMesh, VTKLineMeshRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(HexahedralMesh, VTKHexahedralMeshRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(PointSet, VTKPointSetRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(Plane, VTKPlaneRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(Sphere, VTKSphereRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(Capsule, VTKCapsuleRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(OrientedBox, VTKOrientedCubeRenderDelegate)
IMSTK_REGISTER_RENDERDELEGATE(ImageData, VTKImageDataRenderDelegate)
// Custom algorithms
RenderDelegateRegistrar<VTKFluidRenderDelegate> _imstk_registerrenderdelegate_fluid("Fluid");
RenderDelegateRegistrar<VTKSurfaceNormalRenderDelegate> _imstk_registerrenderdelegate_surfacenormals("SurfaceNormals");
std::shared_ptr<VTKRenderDelegate>
RenderDelegateObjectFactory::makeRenderDelegate(std::shared_ptr<VisualModel> visualModel)
......@@ -78,15 +77,15 @@ RenderDelegateObjectFactory::makeRenderDelegate(std::shared_ptr<VisualModel> vis
<< visualModel->getName();
return nullptr;
}
auto found = m_objCreationMap.find(delegateHint);
if (found == m_objCreationMap.end())
if (!contains(delegateHint))
{
LOG(FATAL) << "RenderDelegate::makeDelegate error: can't find delegate with hint: "
<< delegateHint << " for visual model " << visualModel->getName();
return nullptr;
}
return found->second(visualModel);
// Still a bug, should be able to copy the visual model ptr as well
return create(delegateHint, std::move(visualModel));
}
} // namespace imstk
\ No newline at end of file
......@@ -21,6 +21,8 @@ limitations under the License.
#pragma once
#include "imstkCDObjectFactory.h"
#include <functional>
#include <memory>
#include <string>
......@@ -45,40 +47,16 @@ class VTKRenderDelegate;
/// Note: Does not auto register VTKRenderDelegate's. If one creates
/// their own VTKRenderDelegate they must register themselves.
///
class RenderDelegateObjectFactory
class VTKRenderDelegate;
class RenderDelegateObjectFactory : public ObjectFactory<std::shared_ptr<VTKRenderDelegate>, std::shared_ptr<VisualModel>>
{
public:
using DelegateMakeFunc = std::function<std::shared_ptr<VTKRenderDelegate>(std::shared_ptr<VisualModel>)>;
///
/// \brief Register the RenderDelegate creation function with
/// template type. Provide a delegateHint in the VisualModel to use
/// this creation function instead. Creation functions can be overridden
///
template<typename T>
static void registerDelegate(std::string name)
{
static_assert(std::is_base_of<VTKRenderDelegate, T>::value,
"T must be a subclass of VTKRenderDelegate");
m_objCreationMap[name] = makeFunc<T>();
}
template<typename T>
static DelegateMakeFunc makeFunc()
{
return [](std::shared_ptr<VisualModel> visualModel)
{
return std::make_shared<T>(visualModel);
};
}
///
/// \brief Creates a VTKRenderDelegate object by VisualModel if registered to factory
///
static std::shared_ptr<VTKRenderDelegate> makeRenderDelegate(std::shared_ptr<VisualModel> visualModel);
private:
static std::unordered_map<std::string, DelegateMakeFunc> m_objCreationMap;
};
#define REGISTER_RENDER_DELEGATE(delegateType) RenderDelegateObjectFactory::registerDelegate<delegateType>(#delegateType)
template<typename T>
using RenderDelegateRegistrar = SharedObjectRegistrar<VTKRenderDelegate, T, std::shared_ptr<VisualModel>>;
#define IMSTK_REGISTER_RENDERDELEGATE(geomType, objType) RenderDelegateRegistrar<objType> _imstk_registerrenderdelegate ## geomType(#geomType);
} // namespace imstk
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment