Commit c63f3635 authored by Utkarsh Ayachit's avatar Utkarsh Ayachit
Browse files

diy: pass operator instance to mpi_op<>::get()

This makes it possible to add custom MPI reduction operations without
having to add too complex logic to allocating and freeing MPI operation.
parent 42d5be31
......@@ -152,13 +152,13 @@ namespace mpi
}
}
static void reduce(const communicator& comm, const T& in, T& out, int root, const Op&)
static void reduce(const communicator& comm, const T& in, T& out, int root, const Op& op)
{
MPI_Reduce(Datatype::address(const_cast<T&>(in)),
Datatype::address(out),
Datatype::count(in),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
root, comm);
}
......@@ -168,38 +168,38 @@ namespace mpi
Datatype::address(const_cast<T&>(in)),
Datatype::count(in),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
root, comm);
}
static void all_reduce(const communicator& comm, const T& in, T& out, const Op&)
static void all_reduce(const communicator& comm, const T& in, T& out, const Op& op)
{
MPI_Allreduce(Datatype::address(const_cast<T&>(in)),
Datatype::address(out),
Datatype::count(in),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
comm);
}
static void all_reduce(const communicator& comm, const std::vector<T>& in, std::vector<T>& out, const Op&)
static void all_reduce(const communicator& comm, const std::vector<T>& in, std::vector<T>& out, const Op& op)
{
out.resize(in.size());
MPI_Allreduce(Datatype::address(const_cast<T&>(in[0])),
Datatype::address(out[0]),
in.size(),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
comm);
}
static void scan(const communicator& comm, const T& in, T& out, const Op&)
static void scan(const communicator& comm, const T& in, T& out, const Op& op)
{
MPI_Scan(Datatype::address(const_cast<T&>(in)),
Datatype::address(out),
Datatype::count(in),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
comm);
}
......
......@@ -14,13 +14,13 @@ namespace mpi
namespace detail
{
template<class T> struct mpi_op { static MPI_Op get(); };
template<class U> struct mpi_op< maximum<U> > { static MPI_Op get() { return MPI_MAX; } };
template<class U> struct mpi_op< minimum<U> > { static MPI_Op get() { return MPI_MIN; } };
template<class U> struct mpi_op< std::plus<U> > { static MPI_Op get() { return MPI_SUM; } };
template<class U> struct mpi_op< std::multiplies<U> > { static MPI_Op get() { return MPI_PROD; } };
template<class U> struct mpi_op< std::logical_and<U> > { static MPI_Op get() { return MPI_LAND; } };
template<class U> struct mpi_op< std::logical_or<U> > { static MPI_Op get() { return MPI_LOR; } };
template<class T> struct mpi_op { static MPI_Op get(const T&); };
template<class U> struct mpi_op< maximum<U> > { static MPI_Op get(const maximum<U>&) { return MPI_MAX; } };
template<class U> struct mpi_op< minimum<U> > { static MPI_Op get(const minimum<U>&) { return MPI_MIN; } };
template<class U> struct mpi_op< std::plus<U> > { static MPI_Op get(const std::plus<U>&) { return MPI_SUM; } };
template<class U> struct mpi_op< std::multiplies<U> > { static MPI_Op get(const std::multiplies<U>&) { return MPI_PROD; } };
template<class U> struct mpi_op< std::logical_and<U> > { static MPI_Op get(const std::logical_and<U>&) { return MPI_LAND; } };
template<class U> struct mpi_op< std::logical_or<U> > { static MPI_Op get(const std::logical_or<U>&) { return MPI_LOR; } };
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment